How can I use Cython well to solve a differential equation faster?

I would like to reduce the time SciPy's odeint takes to solve a differential equation.

To practice, I used the example covered in Python in scientific computations as a template. Because odeint takes a function f as an argument, I wrote this function as a statically typed Cython version and hoped the running time of odeint would decrease significantly.

The function f is contained in a file called ode.pyx as follows:

import numpy as np
cimport numpy as np
from libc.math cimport sin, cos

def f(y, t, params):
    # statically typed locals, but still calling numpy's sin/cos
    cdef double theta = y[0], omega = y[1]
    cdef double Q = params[0], d = params[1], Omega = params[2]
    cdef double derivs[2]
    derivs[0] = omega
    derivs[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)
    return derivs

def fCMath(y, double t, params):
    # C math sin/cos and a typed t (DavidW's suggestions)
    cdef double theta = y[0], omega = y[1]
    cdef double Q = params[0], d = params[1], Omega = params[2]
    cdef double derivs[2]
    derivs[0] = omega
    derivs[1] = -omega/Q + sin(theta) + d*cos(Omega*t)
    return derivs

I then create a file setup.py to compile the function:

from distutils.core import setup
from Cython.Build import cythonize

setup(ext_modules=cythonize('ode.pyx'))
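A note in case the build fails: because ode.pyx cimports numpy, some environments also need NumPy's C headers at compile time. If the build step later stops with a missing numpy/arrayobject.h, a variant of setup.py along these lines should fix it (the include_dirs addition is an assumption about the environment, not something the original build required):

from distutils.core import setup
from Cython.Build import cythonize
import numpy  # provides the C headers pulled in by "cimport numpy"

setup(
    ext_modules=cythonize('ode.pyx'),
    include_dirs=[numpy.get_include()],  # location of numpy/arrayobject.h
)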

The script solving the differential equation (also containing the Python version of f) is called solveODE.py and looks as follows:

import ode
import numpy as np
from scipy.integrate import odeint
import time

def f(y, t, params):
    theta, omega = y
    Q, d, Omega = params
    derivs = [omega,
              -omega/Q + np.sin(theta) + d*np.cos(Omega*t)]
    return derivs

params = np.array([2.0, 1.5, 0.65])
y0 = np.array([0.0, 0.0])
t = np.arange(0., 200., 0.05)

start_time = time.time()
odeint(f, y0, t, args=(params,))
print("The Python Code took: %.6s seconds" % (time.time() - start_time))

start_time = time.time()
odeint(ode.f, y0, t, args=(params,))
print("The Cython Code took: %.6s seconds ---" % (time.time() - start_time))

start_time = time.time()
odeint(ode.fCMath, y0, t, args=(params,))
print("The Cython Code incorpoarting two of DavidW_s suggestions took: %.6s seconds ---" % (time.time() - start_time))

I then run:

python setup.py build_ext --inplace
python solveODE.py 

in the terminal.

The Python version takes approximately 0.055 seconds, whilst the Cython version takes roughly 0.04 seconds.

Does somebody have a recommendation to improve my attempt at solving the differential equation with Cython, preferably without tinkering with the odeint routine itself?

Edit

I incorporated DavidW's suggestions in the two files ode.pyx and solveODE.py. With these suggestions, the code takes only roughly 0.015 seconds to run.

Ungodly answered 16/3, 2017 at 15:18 Comment(11)
You should post this to codereview instead – Coble
I might try numba instead of cython, but any difference is likely to be small. Most of the computation time is likely the context switching taking place when odeint calls your function. You may honestly see the best gains from writing your own numerical integration function (again with cython or numba) to avoid the context switching – Indefatigable
@Coble I kind of agree, but experience suggests that people get better answers to "speed up Cython" questions here than on codereview, so I'm never sure if this is good advice or not – Feminacy
@Indefatigable Thank you for your suggestion! The documentation for odeint says odeint "solve[s] a system of ordinary differential equations using lsoda from the FORTRAN library odepack". I thought odepack already is compiled code. Is that correct, and if so, do you nevertheless expect significant performance improvements with my own compiled integration function? – Ungodly
@Ungodly I have not read through the source code itself, but your function f and ode.f are Python objects that require a context switch at least once per call (4000 calls for 0-200 in steps of 0.05); otherwise odeint wouldn't be able to take any old custom user function. I've gotten a 4x speedup with numba, but I'm working right now to get more... – Indefatigable
@Indefatigable Thank you for your second remark as well! After incorporating two of DavidW's suggestions, the Cython version took roughly 1/3 of the time the Python version took (see the edit of my question). I'm curious to see your suggestion with numba. Do you have an explanation for the differences between the numba and Cython implementations? – Ungodly
@Wrzlprmft Thank you for your suggestion! Could you maybe elaborate a bit more on how I can implement your module? I'd like to try it but cannot really figure out how to do that. – Ungodly
@Ungodly For numba you literally just do import numba and put @numba.jit on the line before your function definition. Use your Python version unchanged. Be aware that the first time you run it will be slow (while it compiles), so it's worth timing it twice. – Feminacy
@fabian: Yes, but that would be beyond the scope of this question. Please make a Stack Overflow Chat account and invite me to a new room. – Mcleroy
@Coble Don't suggest CodeReview just because they want to make the code faster. Pay attention to the tag's popularity on the respective boards. CR is great if you want to refine C++ or Java code, but not nearly as good when dealing with specialized packages like Cython. – Wrapping
I would just like to point out that your code is running in Python mode. It is statically compiled, but you may want to look into cpdef. I believe, however (as others pointed out), that the heavy lifting will be done by the ODE solver. – Aragats

The easiest change to make (and one which will probably gain you a lot) is to use the C math library's sin and cos for operations on single numbers instead of numpy's. The call to numpy and the time spent working out that the argument isn't an array is fairly costly.

from libc.math cimport sin, cos

    # later
    -omega/Q + sin(theta) + d*cos(Omega*t)

I'd be tempted to assign a type to the input t (none of the other inputs are easily typed without changing the interface):

def f(y, double t, params):

I think I'd also just return a list like you do in your Python version. I don't think you gain a lot by using a C array.
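Putting both suggestions together, the function might look like this (essentially the fCMath version from the question's edit, just returning a Python list instead of the C array):

from libc.math cimport sin, cos

def f(y, double t, params):
    cdef double theta = y[0], omega = y[1]
    cdef double Q = params[0], d = params[1], Omega = params[2]
    # C math sin/cos on scalars, typed t, and a plain list as return value
    return [omega, -omega/Q + sin(theta) + d*cos(Omega*t)]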

Feminacy answered 16/3, 2017 at 15:42 Comment(1)
Thank you for your suggestion! Indeed, by using the C math library the code improves roughly 40% relative to my version and is in total approximately twice as fast as the Python code. Typing t as you suggested further improves the code by a few percentage points. – Ungodly

tl;dr: use numba.jit for a 3x speedup...

I don't have much experience with Cython, but my machine seems to get similar computation times for your strictly Python version, so we should be able to compare roughly apples to apples. I used numba to compile the function f (which I rewrote slightly to make it play nicer with the compiler).

import numba
import numpy as np

def f(y, t, params):
    return np.array([y[1], -y[1]/params[0] + np.sin(y[0]) + params[1]*np.cos(params[2]*t)])

numba_f = numba.jit(f)

Dropping numba_f in place of your ode.f gives me this output:

The Python Code took: 0.0468 seconds
The Numba Code took: 0.0155 seconds

I then wondered if I could duplicate odeint and also compile it with numba to speed things up even further... (I could not.)

Here is my Runge-Kutta numerical differential equation integrator:

# function f is provided inline (not as an arg)
def runge_kutta(y0, steps, dt, args=()):
    # improvement on Euler's method; note: time span given as number of steps and dt
    Y = np.empty([steps, y0.shape[0]])
    Y[0] = y0
    t = 0
    for n in range(steps - 1):
        # calculate coefficients
        k1 = f(Y[n], t, args)                           # Euler's-method coefficient: beginning of interval
        k2 = f(Y[n] + (dt * k1 / 2), t + (dt/2), args)  # interval midpoint A
        k3 = f(Y[n] + (dt * k2 / 2), t + (dt/2), args)  # interval midpoint B
        k4 = f(Y[n] + dt * k3, t + dt, args)            # interval end point

        Y[n + 1] = Y[n] + (dt/6) * (k1 + 2*k2 + 2*k3 + k4)  # calculate Y(n+1)
        t += dt                                             # calculate t(n+1)
    return Y

Naive looping functions are typically the fastest once compiled, although this could probably be restructured for a little better speed. I should note that this gives a different answer than odeint, deviating by as much as 0.001 after around 2000 steps and being completely different after 3000. For the numba version of the function, I simply replaced f with numba_f and added the compilation with @numba.jit as a decorator (a sketch of this follows the timings below). In this case, as expected, the pure Python version is very slow, but the numba version is not any faster than numba with odeint (again, ymmv).

using custom integrator
The Python Code took: 0.2340 seconds
The Numba Code took: 0.0156 seconds
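For reference, the jitted variant described above would look roughly like this (a sketch, not re-tested here; runge_kutta_numba is a made-up name, and numba_f is the compiled function from earlier in this answer):

@numba.jit
def runge_kutta_numba(y0, steps, dt, args=()):
    # same algorithm as runge_kutta above, with numba_f in place of f
    Y = np.empty((steps, y0.shape[0]))
    Y[0] = y0
    t = 0.0
    for n in range(steps - 1):
        k1 = numba_f(Y[n], t, args)
        k2 = numba_f(Y[n] + (dt * k1 / 2), t + (dt/2), args)
        k3 = numba_f(Y[n] + (dt * k2 / 2), t + (dt/2), args)
        k4 = numba_f(Y[n] + dt * k3, t + dt, args)
        Y[n + 1] = Y[n] + (dt/6) * (k1 + 2*k2 + 2*k3 + k4)
        t += dt
    return Y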

Here's an example of compiling ahead of time. I don't have the necessary toolchain on this computer, and I don't have admin rights to install it, so this gives me an error that I don't have the required compiler, but it should work otherwise.

import numpy as np
from numba.pycc import CC

cc = CC('diffeq')

@cc.export('func', 'f8[:](f8[:], f8, f8[:])')
def func(y, t, params):
    return np.array([y[1], -y[1]/params[0] + np.sin(y[0]) + params[1]*np.cos(params[2]*t)])

cc.compile()
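If the compilation succeeds, cc.compile() writes an extension module named diffeq next to the script; it can then be imported and handed to odeint like any other callable (a usage sketch, untested for the same toolchain reason):

import numpy as np
from scipy.integrate import odeint
import diffeq  # the module produced by cc.compile()

params = np.array([2.0, 1.5, 0.65])
y0 = np.array([0.0, 0.0])
t = np.arange(0., 200., 0.05)

solution = odeint(diffeq.func, y0, t, args=(params,))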
Indefatigable answered 16/3, 2017 at 17:27 Comment(7)
Thank you very much for your detailed answer! I learned a lot from it. Unfortunately, my code was severely slowed down when I used the @jit declaration; in fact, the code took roughly 2 seconds. My Python version was shipped with Anaconda 4.3.1, which supplies numba 0.30.1. Do you have an idea for such slow results? – Ungodly
@Ungodly The initial compilation takes time, but every subsequent run should be fast; someone already mentioned this in the main comment thread. It's like doing the Cython compilation at runtime instead of beforehand. numba does have pre-compilation support as well if you look through their documentation, but I've never used it. – Indefatigable
Ah yes, you're right! Thanks! When I enter python solveODE.py the code is always compiled afresh. Is there a way to compile the code once via the terminal and subsequently run the compiled code in the terminal? – Ungodly
Here's the relevant documentation to do that. You basically make a library and compile it into an extension lib. – Indefatigable
@Ungodly Or possibly use cache=True for jit (numba.pydata.org/numba-doc/0.30.1/reference/…) – Feminacy
@Ungodly I added an example of compilation. I don't actually have the compiler on this computer, so I can't test it right now (missing vcvarsall.bat error). – Indefatigable
@Indefatigable I cannot thank you enough for all your suggestions – the example of compilation works just like a charm. Thanks! – Ungodly

If others answer this question using other modules, I might as well chime in:

I am the author of JiTCODE, which accepts an ODE written in SymPy symbols and then converts this ODE to C code for a Python module, compiles this C code, loads the result and uses this as a derivative for SciPy’s ODE. Your example translated to JiTCODE looks like this:

from jitcode import jitcode, provide_basic_symbols
import numpy as np
from sympy import sin, cos
import time

Q = 2.0
d = 1.5
Ω = 0.65

t, y = provide_basic_symbols()

f = [
    y(1),
    -y(1)/Q + sin(y(0)) + d*cos(Ω*t)
    ]

initial_state = np.array([0.0,0.0])

ODE = jitcode(f)
ODE.set_integrator("lsoda")
ODE.set_initial_value(initial_state,0.0)

start_time = time.time()
data = np.vstack([ODE.integrate(T) for T in np.arange(0.05, 200., 0.05)])
end_time = time.time()
print("JiTCODE took: %.6s seconds" % (end_time - start_time))

This takes 0.11 seconds, which is horribly slow compared to the solutions based on odeint, but the slowness is not due to the actual integration; it comes from the way the results are handled: while odeint creates the output array efficiently internally, here this is done via Python. Depending on what you do, this may be a crucial disadvantage, but it quickly becomes irrelevant for a coarser sampling or larger differential equations.

So, let’s remove the data collection and just look at the integration, by replacing the last lines with the following:

ODE = jitcode(f)
ODE.set_integrator("lsoda", max_step=0.05, nsteps=1e10)
ODE.set_initial_value(initial_state,0.0)

start_time = time.time()
ODE.integrate(200.0)
end_time = time.time()
print("JiTCODE took: %.6s seconds" % (end_time - start_time))

Note that I set max_step=0.05 to force the integrator to make at least as many steps as in your example and to ensure that the only difference is that the results of the integration are not stored in some array. This runs in 0.010 seconds.

Mcleroy answered 17/3, 2017 at 10:32 Comment(0)

NumbaLSODA takes 0.00088 seconds (17x faster than Cython).

from NumbaLSODA import lsoda_sig, lsoda
import numba as nb
import numpy as np
import time

@nb.cfunc(lsoda_sig)
def f(t, y_, dy, p_):
    p = nb.carray(p_, (3,))
    y = nb.carray(y_, (2,))
    theta, omega = y
    Q, d, Omega = p
    dy[0] = omega
    dy[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)

funcptr = f.address # address to ODE function
y0 = np.array([0.0, 0.0])
data = np.array([2.0, 1.5, 0.65])
t = np.arange(0., 200., 0.05)

start_time = time.time()
usol, success = lsoda(funcptr, y0, t, data=data)
print("NumbaLSODA took: %.8s seconds ---" % (time.time() - start_time))

Result:

NumbaLSODA took: 0.000880 seconds ---
Lepto answered 7/7, 2021 at 4:21 Comment(0)
