Python script to solve ODEs

Here is a simple Python script that solves a system of N ordinary differential equations, given N initial conditions. It plots the result using matplotlib (so you have to install that Python module). It uses a classical fourth-order Runge-Kutta (RK4) stepper (inspired by the course that I taught at USD and by Numerical Recipes). The “driver” that uses the stepper is not fancy (i.e. the step size is fixed; there is no adaptive step-size control).
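
For reference, the update that step() in the script performs can be written out as below. Here h stands for stepsize and f for derivs; both symbols are shorthand introduced for this note, not names used in the script.

```latex
\begin{aligned}
k_1 &= h\, f(x,\; y) \\
k_2 &= h\, f\!\left(x + \tfrac{h}{2},\; y + \tfrac{1}{2}k_1\right) \\
k_3 &= h\, f\!\left(x + \tfrac{h}{2},\; y + \tfrac{1}{2}k_2\right) \\
k_4 &= h\, f(x + h,\; y + k_3) \\
y_{\text{out}} &= y + \tfrac{1}{6}k_1 + \tfrac{1}{3}k_2 + \tfrac{1}{3}k_3 + \tfrac{1}{6}k_4
\end{aligned}
```

Each k is a list with one entry per equation, so the same formulas apply component by component to a system of N equations.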

The specific problem is implemented by filling in the derivs() function, which must return an array containing the derivatives of all the functions being solved for, evaluated at x for the current values yin. E.g. dy1/dx = a1, dy2/dx = a2, where the a's are the elements of the returned array and can depend on the y's and on x.
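
As an illustration of how a different problem might be plugged in, here is a minimal sketch of a derivs-style function for a damped oscillator. The function name damped_derivs and the constants K and C are hypothetical and chosen only for this example; they do not appear in the script.

```python
# Hypothetical example (not part of the original script): a damped oscillator
#   dy0/dx = y1
#   dy1/dx = -K*y0 - C*y1
K = 0.5  # spring constant (assumed value)
C = 0.1  # damping coefficient (assumed value)

def damped_derivs(x, yin):
    """Return [dy0/dx, dy1/dx] evaluated at x for the current values yin."""
    position, velocity = yin
    return [velocity, -K * position - C * velocity]

# Evaluate the derivatives at the starting point (position 1, velocity 0)
print(damped_derivs(0.0, [1.0, 0.0]))
```

A function with this signature could be passed to solve() in place of derivs, with initialConditions holding one value per equation.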

The example below implements a simple harmonic oscillator (dx/dt=v, dv/dt=-kx, with k=0.5 in the code), where position and velocity are the two variables (y[0] and y[1]) to solve for. The initial conditions are a position of x=1 and a velocity of 0 at t=0.

#! /usr/bin/python

#from ROOT import TGraph, TCanvas
import matplotlib.pyplot as plt

#RK4 stepper (classical fourth-order Runge-Kutta, fixed step size)
def step(stepsize, yin, x, derivs):
    dy=derivs(x,yin)
    k1=[]
    y2=[]
    for i in range(0,len(yin)):
        k1.append(stepsize*dy[i])
        y2.append(yin[i]+0.5*k1[i])

    dy=derivs(x+0.5*stepsize,y2)
    k2=[]
    y3=[]
    for i in range(0,len(yin)):
        k2.append(stepsize*dy[i])
        y3.append(yin[i]+0.5*k2[i])

    dy=derivs(x+0.5*stepsize,y3)
    k3=[]
    y4=[]
    for i in range(0,len(yin)):
        k3.append(stepsize*dy[i])
        y4.append(yin[i]+k3[i])

    dy=derivs(x+stepsize,y4)
    k4=[]
    yout=[]
    for i in range(0,len(yin)):
        k4.append(stepsize*dy[i])
        yout.append(yin[i]+k1[i]/6.0+k2[i]/3.0+k3[i]/3.0+k4[i]/6.0)

    return yout

def solve(stepsize,x0,nSteps,derivs,initialConditions):
    y=[]
    y.append([x0,initialConditions])
    print("Initial ", y)
    for i in range (0,nSteps):
      x=x0+i*stepsize
      #advance from x to x+stepsize and store the result at the new x
      y.append([x+stepsize,step(stepsize,y[-1][1],x,derivs)])
    return y

def plot(yx):
    ny=len(yx[0][1]) #number of dependent variables (each entry of yx is [x,[y0,y1,...]])
    for i in range(0,ny):
        y=[]
        x=[]
        for j in range (0,len(yx)):
            x.append(yx[j][0])
            y.append(yx[j][1][i])
        plt.plot(x,y)
        plt.show()

#problem-specific derivatives
def derivs(x,yin):
    print("derivatives at ",x," ",yin)
    dy=[yin[1],-0.5*yin[0]]
    return dy

########################################
########################################
#initial conditions for each function
initialConditions=[1,0]
yx=solve(0.1,0.0,100,derivs,initialConditions)

#Do the drawing
plot(yx)
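
One way to sanity-check the result is to compare it with the exact solution. For the initial conditions above (position 1, velocity 0) and k=0.5, the position is cos(sqrt(0.5)*t). The sketch below is a standalone, compact rewrite of the same RK4 update (the names rk4_step, sho and max_error are introduced here for illustration and are not part of the script). It measures the largest position error for two step sizes; for a fourth-order method, halving the step should shrink the error by a factor of roughly 16.

```python
import math

def rk4_step(h, y, x, f):
    """One classical RK4 step from x to x+h (same scheme as step() above)."""
    k1 = [h * d for d in f(x, y)]
    k2 = [h * d for d in f(x + 0.5 * h, [yi + 0.5 * ki for yi, ki in zip(y, k1)])]
    k3 = [h * d for d in f(x + 0.5 * h, [yi + 0.5 * ki for yi, ki in zip(y, k2)])]
    k4 = [h * d for d in f(x + h, [yi + ki for yi, ki in zip(y, k3)])]
    return [yi + (a + 2 * b + 2 * c + d) / 6.0
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]

def sho(x, y):
    """Simple harmonic oscillator with k=0.5, as in derivs() above."""
    return [y[1], -0.5 * y[0]]

def max_error(h, n_steps):
    """Largest |numerical - exact| position error over n_steps steps of size h."""
    omega = math.sqrt(0.5)
    y = [1.0, 0.0]  # position 1, velocity 0 at t=0
    worst = 0.0
    for i in range(n_steps):
        y = rk4_step(h, y, i * h, sho)
        t = (i + 1) * h  # the new values belong to the end of the step
        worst = max(worst, abs(y[0] - math.cos(omega * t)))
    return worst

err_coarse = max_error(0.1, 100)   # same step size and range as the script
err_fine = max_error(0.05, 200)    # half the step, same range
print(err_coarse, err_fine, err_coarse / err_fine)
```

If the printed ratio is far from 16, that usually points to a mistake in the stepper or in how x is advanced between steps.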