How to speed up symbolic derivatives of long functions using SymPy?

I am writing a program in Python to solve the Schrödinger equation using the Free ICI Method (well, SICI method right now... but Free ICI is what it will turn into). If this does not sound familiar, that is because there is very little information out there on the subject, and absolutely no sample code to work from.

This process involves iteratively arriving at a solution to the partial differential equation. In doing this, there are a lot of symbolic derivatives that need to be performed. The problem is, as the program runs, the functions that need to be differentiated continue to get larger and larger so that by the fifth iteration it takes a very large amount of time to compute the symbolic derivatives.
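To see the growth concretely, here is a minimal sketch in plain SymPy (not part of the original program). It applies the update from the loop below a few times and measures the expression size with sympy.count_ops. The constant values for E and C are hypothetical placeholders, since in the real loop they come from the integrals and the quadratic. Only the structure of the update matters here.

```python
import sympy as sp

r = sp.symbols('r', positive=True)

def H(fun):
    # Same radial operator as in the question
    return -sp.diff(fun, r, 2)/2 - sp.diff(fun, r)/r - fun/r

def expression_sizes(iterations=3):
    """Return count_ops of psi after each unsimplified update step."""
    E = sp.Float(-0.4)   # hypothetical placeholder for E0
    C = sp.Float(0.1)    # hypothetical placeholder for C
    psi = sp.exp(-sp.Rational(3, 2)*r)
    sizes = [sp.count_ops(psi)]
    for _ in range(iterations):
        f2 = r*(H(psi) - E*psi)   # same structure as f2 in the question
        psi = psi + C*f2          # psi0 = f1 + C*f2, with no simplification
        sizes.append(sp.count_ops(psi))
    return sizes

print(expression_sizes())
```

Each new psi contains the previous one plus its first and second derivatives, so the operation count keeps climbing. That is why the symbolic differentiation gets slower with every iteration.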

I need to speed this up because I'd like to be able to achieve at least 30 iterations, and I'd like to have it do that before I retire.

I've gone through and removed unnecessary repeats of calculations (or at least the ones I know of), which has helped quite a bit. Beyond this, I have absolutely no clue how to speed things up.

Here is the code containing the function that computes the derivatives. The inf_integrate function is just the composite Simpson’s rule, because it is way faster than SymPy’s integrate and doesn’t raise errors for oscillatory functions:

from sympy import *


def inf_integrate(fun, n, a, b):
    f = lambdify(r, fun)
    h = (b-a)/n
    XI0 = f(a) + f(b)
    XI1 = 0
    XI2 = 0

    for i in range(1, n):
        X = a + i*h

        if i % 2 == 0:
            XI2 = XI2 + f(X)
        else:
            XI1 = XI1 + f(X)

    XI = h*(XI0 + 2*XI2 + 4*XI1)/3

    return XI


r = symbols('r')

def H(fun):
    return (-1/2)*diff(fun, r, 2) - (1/r)*diff(fun, r) - (1/r)*fun

E1 = symbols('E1')
low = 10**(-5)
high = 40
n = 5000

g = Lambda(r, r)


psi0 = Lambda(r, exp(-1.5*r))

I1 = inf_integrate(4*pi*(r**2)*psi0(r)*H(psi0(r)), n, low, high)
I2 = inf_integrate(4*pi*(r**2)*psi0(r)*psi0(r), n, low, high)

E0 = I1/I2
print(E0)

for x in range(10):

    f1 = Lambda(r, psi0(r))
    f2 = Lambda(r, g(r)*(H(psi0(r)) - E0*psi0(r)))
    Hf1 = Lambda(r, H(f1(r)))
    Hf2 = Lambda(r, H(f2(r)))

    H11 = inf_integrate(4*pi*(r**2)*f1(r)*Hf1(r), n, low, high)
    H12 = inf_integrate(4*pi*(r**2)*f1(r)*Hf2(r), n, low, high)
    H21 = inf_integrate(4*pi*(r**2)*f2(r)*Hf1(r), n, low, high)
    H22 = inf_integrate(4*pi*(r**2)*f2(r)*Hf2(r), n, low, high)

    S11 = inf_integrate(4*pi*(r**2)*f1(r)*f1(r), n, low, high)
    S12 = inf_integrate(4*pi*(r**2)*f1(r)*f2(r), n, low, high)
    S21 = S12
    S22 = inf_integrate(4*pi*(r**2)*f2(r)*f2(r), n, low, high)

    eqn = Lambda(E1, (H11 - E1*S11)*(H22 - E1*S22) - (H12 - E1*S12)*(H21 - E1*S21))

    roots = solve(eqn(E1), E1)

    E0 = roots[0]

    C = -(H11 - E0*S11)/(H12 - E0*S12)

    psi0 = Lambda(r, f1(r) + C*f2(r))

    print(E0)

The program works and converges to exactly the expected result, but it is way too slow. Any help on speeding this up is very much appreciated.

Jehanna asked 16/5, 2019 at 4:22
Ok. Thank you for the clarity. I'm starting to see the problem now: thousands of terms growing in a combinatorial explosion. What is this being used for, and how precise does this need to be (for optimization purposes)? – Disproportionation

There are several things you can do here:

  1. If you profile your code, you will notice that you spend most of the time in the integration function inf_integrate, mostly because you are using manual Python loops. This can be amended by turning the integrand into a vectorised function and using SciPy’s integration routines (which are compiled and thus fast).

  2. When you are using nested symbolic expressions, it may be worthwhile checking whether an occasional explicit simplification can help to keep the exploding complexity in check. This appears to be the case here.

  3. None of the Lambda functions you defined are needed; you can simply work with expressions. I haven’t checked whether this actually affects the runtime, but it certainly helps with the next step (since SymEngine does not have Lambda yet).

  4. Use SymEngine instead of SymPy. SymPy (as of now) is purely Python-based and hence slow. SymEngine is its compiled core in the making and can be considerably faster. It has almost all the functionalities you need.

  5. With every step, you solve an equation that does not change its nature: It’s always the same quadratic equation, only the coefficients change. By solving this once in general, you save a lot of time, in particular by SymPy not having to deal with complicated coefficients.
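Point 1 can also be done without switching to the trapezoidal rule. Here is a minimal sketch that keeps the question's composite Simpson's rule. It replaces the Python loop with one vectorised evaluation of the lambdified integrand on all nodes, then applies the 1-4-2-4-...-4-1 weights with NumPy slicing. The function name simpson_vectorised is hypothetical, and the sketch uses plain SymPy and NumPy rather than SymEngine.

```python
import numpy as np
import sympy as sp

r = sp.symbols('r')

def simpson_vectorised(fun, n, a, b):
    """Composite Simpson's rule for the SymPy expression fun(r) on [a, b].

    n is the number of subintervals and must be even.
    """
    if n % 2:
        raise ValueError("n must be even")
    f = sp.lambdify(r, fun, modules="numpy")
    x = np.linspace(a, b, n + 1)          # nodes x_0, ..., x_n
    y = np.broadcast_to(f(x), x.shape)    # also handles constant integrands
    h = (b - a) / n
    # Endpoints get weight 1, odd interior nodes 4, even interior nodes 2
    return h/3 * (y[0] + y[-1] + 4*y[1:-1:2].sum() + 2*y[2:-1:2].sum())

print(simpson_vectorised(4*sp.pi*r**2*sp.exp(-3*r), 5000, 1e-5, 40))
```

The weights are the same as in the original loop (XI0, 4*XI1, 2*XI2). The only difference is that the function is called once on an array instead of 5000 times on scalars.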

Taking it all together, I arrive at the following:

from symengine import *
import sympy
import numpy as np
# trapz was renamed to trapezoid in SciPy 1.6; the old name has since been removed
from scipy.integrate import trapezoid

r, E1 = symbols('r, E1')
H11, H22, H12, H21 = symbols("H11, H22, H12, H21")
S11, S22, S12, S21 = symbols("S11, S22, S12, S21")
low = 1e-5
high = 40
n = 5000

# Solve the quadratic once, symbolically, with SymPy;
# each iteration then only substitutes numbers into the general solution.
quadratic_expression = (H11-E1*S11)*(H22-E1*S22)-(H12-E1*S12)*(H21-E1*S21)
general_solution = sympify( sympy.solve(quadratic_expression,E1)[0] )
def solve_quadratic(**kwargs):
    return general_solution.subs(kwargs)

# Evaluate the integrand on a fixed grid in one vectorised call
# and integrate it with SciPy's trapezoidal rule.
sampling_points = np.linspace(low,high,n)
def inf_integrate(fun):
    f = lambdify([r],[fun])
    values = f(sampling_points)
    return trapezoid(values,sampling_points)

def H(fun):
    return -fun.diff(r,2)/2 - fun.diff(r)/r - fun/r

psi0 = exp(-3*r/2)
I1 = inf_integrate(4*pi*(r**2)*psi0*H(psi0))
I2 = inf_integrate(4*pi*(r**2)*psi0**2)
E0 = I1/I2
print(E0)

for x in range(30):
    f1 = psi0
    f2 = r * (H(psi0)-E0*psi0)
    Hf1 = H(f1)
    Hf2 = H(f2)

    H11 = inf_integrate( 4*pi*(r**2)*f1*Hf1 )
    H12 = inf_integrate( 4*pi*(r**2)*f1*Hf2 )
    H21 = inf_integrate( 4*pi*(r**2)*f2*Hf1 )
    H22 = inf_integrate( 4*pi*(r**2)*f2*Hf2 )

    S11 = inf_integrate( 4*pi*(r**2)*f1**2 )
    S12 = inf_integrate( 4*pi*(r**2)*f1*f2 )
    S21 = S12
    S22 = inf_integrate( 4*pi*(r**2)*f2**2 )

    E0 = solve_quadratic(
            H11=H11, H22=H22, H12=H12, H21=H21,
            S11=S11, S22=S22, S12=S12, S21=S21,
        )
    print(E0)

    C = -( H11 - E0*S11 )/( H12 - E0*S12 )
    psi0 = (f1 + C*f2).simplify()

This converges to −½ in a few seconds on my machine.
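Point 5 can be taken one step further. The quadratic is exactly the determinant condition det(H − E1·S) = 0 for the 2×2 matrices H = [[H11, H12], [H21, H22]] and S = [[S11, S12], [S21, S22]], which makes it a small generalized eigenvalue problem. The sketch below is an alternative to the symbolic general solution used above, not what that code does. It hands the roots to NumPy, assuming S is invertible, and picks the lowest root, which is the variational upper bound for the ground-state energy. The code above instead takes whichever root sympy.solve lists first. The helper name lowest_root is hypothetical, and SymEngine numbers would need converting with float() first.

```python
import numpy as np

def lowest_root(H11, H12, H21, H22, S11, S12, S21, S22):
    """Smallest E1 with (H11-E1*S11)*(H22-E1*S22) - (H12-E1*S12)*(H21-E1*S21) = 0."""
    Hm = np.array([[H11, H12], [H21, H22]], dtype=float)
    Sm = np.array([[S11, S12], [S21, S22]], dtype=float)
    # det(H - E*S) = 0  <=>  E is an eigenvalue of S^{-1} H (S invertible)
    eigenvalues = np.linalg.eigvals(np.linalg.solve(Sm, Hm))
    # For symmetric H and positive-definite S the roots are real
    return float(np.min(eigenvalues.real))

E0 = lowest_root(H11=-0.3, H12=0.1, H21=0.1, H22=0.5,
                 S11=1.0, S12=0.2, S21=0.2, S22=1.0)
print(E0)
```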

Afeard answered 16/5, 2019 at 8:17
Thank you! That was a very timely response, and it works incredibly well now. Do you know how I can get it to display more digits of the E0 values? Also, the integrals are supposed to run from 0 to oo, but those bounds have been causing some errors. Setting the upper bound to 40 is, for the most part, "oo", and setting the lower bound to 1e-5 is, for the most part, 0. This works well for now, but I have a feeling that when I try to improve the precision it will introduce too much error. – Jehanna
You can always use decimal to increase the precision and show more digits: docs.python.org/3.7/library/decimal.html. It may slow things down a little, but probably not severely. – Disproportionation
@Shrodinger149: There are certainly better-suited methods for numerically estimating integrals to infinity, in particular adaptive ones, since you can freely evaluate your function. If you can provide some specific details of your problem (how quickly do your functions decay; do they oscillate; etc.?), this might be a question for Computational Science. But it may also be worth digging a bit into the topic yourself. – Afeard
@Shrodinger149: As for the precision, are you really sure that you need more than double-precision floating-point accuracy? Anyway, it is the accuracy of the integral that is limiting your precision right now, so there is no point in going to a higher numerical precision. – Afeard
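One way to address both the digits and the 0 to oo bounds raised in these comments is mpmath, which SymPy already depends on. It fits the adaptive approach suggested above, though it is my suggestion rather than something proposed in the thread. mpmath.quad accepts an infinite upper limit directly and works at whatever precision mpmath.mp.dps is set to. Here is a minimal sketch for the starting energy E0 = I1/I2 from the question, assuming plain SymPy rather than SymEngine. The helper name quad_0_to_inf is hypothetical. For psi0 = exp(−3r/2) the exact value is −3/8, which serves as a check.

```python
import mpmath
import sympy as sp

mpmath.mp.dps = 30  # work with 30 significant digits

r = sp.symbols('r', positive=True)

def H(fun):
    # Same radial operator as in the question
    return -sp.diff(fun, r, 2)/2 - sp.diff(fun, r)/r - fun/r

def quad_0_to_inf(fun):
    """Integrate the SymPy expression fun(r) over [0, oo) with mpmath.quad."""
    f = sp.lambdify(r, fun, modules="mpmath")
    return mpmath.quad(f, [0, mpmath.inf])

psi0 = sp.exp(-sp.Rational(3, 2)*r)
I1 = quad_0_to_inf(4*sp.pi*r**2*psi0*H(psi0))
I2 = quad_0_to_inf(4*sp.pi*r**2*psi0**2)
E0 = I1/I2
print(E0)   # close to -3/8, the exact value for this trial function
```

As the last comment points out, the extra digits only help if the integration itself is that accurate. Here both the quadrature and the arithmetic run at the same 30-digit precision.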

Wrzlprmft's answer was great. I've gone ahead and cleaned things up, and replaced the clunky integration function with SymPy's integrate. This did not work on my original code, but it works perfectly after Wrzlprmft's corrections and additions. The program is a little slower (though still orders of magnitude faster than my original), but it no longer has the integration error that was limiting the precision. Here is the final code:

from symengine import *
from sympy import *
import sympy

r, E1 = symbols('r, E1')
H11, H22, H12, H21 = symbols("H11, H22, H12, H21")
S11, S22, S12, S21 = symbols("S11, S22, S12, S21")
low = 0
high = oo
n = 100000

quadratic_expression = (H11-E1*S11)*(H22-E1*S22)-(H12-E1*S12)*(H21-E1*S21)
general_solution = sympify(sympy.solve(quadratic_expression, E1)[0])


def solve_quadratic(**kwargs):
    return general_solution.subs(kwargs)


def H(fun):
    return -fun.diff(r, 2)/2 - fun.diff(r)/r - fun/r


psi0 = exp(-3*r/2)
I1 = N(integrate(4*pi*(r**2)*psi0*H(psi0), (r, low, high)))
I2 = N(integrate(4*pi*(r**2)*psi0**2, (r, low, high)))
E0 = I1/I2
print(E0)

for x in range(100):
    f1 = psi0
    f2 = r * (H(psi0)-E0*psi0)
    Hf1 = H(f1)
    Hf2 = H(f2)

    H11 = integrate(4*pi*(r**2)*f1*Hf1, (r, low, high))
    H12 = integrate(4*pi*(r**2)*f1*Hf2, (r, low, high))
    H21 = integrate(4*pi*(r**2)*f2*Hf1, (r, low, high))
    H22 = integrate(4*pi*(r**2)*f2*Hf2, (r, low, high))

    S11 = integrate(4*pi*(r**2)*f1**2, (r, low, high))
    S12 = integrate(4*pi*(r**2)*f1*f2, (r, low, high))
    S21 = S12
    S22 = integrate(4*pi*(r**2)*f2**2, (r, low, high))

    E0 = solve_quadratic(
            H11=H11, H22=H22, H12=H12, H21=H21,
            S11=S11, S22=S22, S12=S12, S21=S21,
        )
    print(E0)

    C = -(H11 - E0*S11)/(H12 - E0*S12)
    psi0 = (f1 + C*f2).simplify()

Jehanna answered 17/5, 2019 at 3:8
You can speed this up again by changing the first lines to from symengine import *; import sympy; integrate = lambda *args: sympy.N(sympy.integrate(*args)). This addresses two things: 1) You are wasting a considerable amount of time by explicitly carrying around exact terms like 1+sqrt(pi), which probably don't add much to your accuracy and also accumulate. 2) from sympy import * completely overrides from symengine import * and deprives you of the speed advantages of SymEngine. – Afeard
Accepted and upvoted (my upvote doesn't count yet because my reputation is too low). Thank you very much for all the help. For the integrate = lambda *args: sympy.N(sympy.integrate(*args)) part, where is that supposed to go? Also, what is args? – Jehanna
Where is that supposed to go? – You can put it directly after the imports. Also, what is args? – It's a tuple containing all the positional arguments. The name args is the standard convention when you write a simple wrapper that just passes all positional arguments on to another function unchanged. – Afeard
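Here is a short sketch of the wrapper from these comments. It is written as a def, which is equivalent to the lambda, and uses plain SymPy only, leaving out the SymEngine import on the assumption that it may not be installed. It shows that args arrives as a tuple, that *args forwards it unchanged to sympy.integrate, and that sympy.N turns the exact result into a floating-point number. The helper show_args is hypothetical and only illustrates the tuple.

```python
import sympy

def integrate(*args):
    # args is a tuple of all positional arguments, e.g. (expr, (r, 0, oo));
    # *args unpacks it again and passes it on to sympy.integrate unchanged.
    return sympy.N(sympy.integrate(*args))

def show_args(*args):
    # Hypothetical helper: just returns the tuple it received
    return args

r = sympy.symbols('r', positive=True)
print(show_args(r**2, (r, 0, sympy.oo)))   # (r**2, (r, 0, oo))
print(integrate(4*sympy.pi*r**2*sympy.exp(-3*r), (r, 0, sympy.oo)))  # numeric value of 8*pi/27
```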
