Numpy matmul and einsum 6 to 7 times slower than MATLAB

I am trying to port some code from MATLAB to Python and I am getting much slower performance from Python. I am not very good at Python coding, so any advice to speed these up will be much appreciated.

I tried an einsum one-liner (takes 7.5 seconds on my machine):

import numpy as np

n = 4
N = 200
M = 100
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)

G = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ijk,ljn->ilkn',X,X)), w)

I also tried a matmul implementation (takes 6 seconds on my machine):

G = np.zeros((M, M))
for i in range(M):
    G[:, i] = np.squeeze(w[i,...].T @ (np.exp(X[i, :, :].T @ X) @ w))
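As a quick sanity check, the two Python versions above can be compared directly. This is a sketch that assumes reduced sizes so the (M, M, N, N) intermediate stays tiny. G turns out to be symmetric: swapping i and l together with the summation indices j and m leaves each term unchanged. That symmetry is why the loop, which fills columns G[:, i], agrees with the einsum, which produces rows G[i, l]. The function names gram_einsum and gram_loop are hypothetical.

```python
import numpy as np

def gram_einsum(X, w):
    # One-liner from the question: builds the full (M, M, N, N) exp term
    return np.einsum('ijk,iljm,lmn->il', w,
                     np.exp(np.einsum('ijk,ljn->ilkn', X, X)), w)

def gram_loop(X, w):
    # Loop version from the question: one (M, N, N) batch per column
    M = X.shape[0]
    G = np.zeros((M, M))
    for i in range(M):
        G[:, i] = np.squeeze(w[i, ...].T @ (np.exp(X[i, :, :].T @ X) @ w))
    return G

# Reduced sizes (an assumption) so the 4-D intermediate stays small
rng = np.random.default_rng(0)
n, N, M = 4, 20, 10
X = 0.1 * rng.random((M, n, N))
w = 0.1 * rng.random((M, N, 1))

G1 = gram_einsum(X, w)
G2 = gram_loop(X, w)
print(np.allclose(G1, G2), np.allclose(G1, G1.T))
```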

But my original MATLAB code is way faster (takes 1 second on my machine):

n = 4;
N = 200;
M = 100;
X = 0.1*rand(n, N, M);
w = 0.1*rand(N, 1, M);

G=zeros(M);
for i=1:M
    G(:,i) = squeeze(pagemtimes(pagemtimes(w(:,1,i).', exp(pagemtimes(X(:,:,i),'transpose',X,'none'))) ,w));
end

I was expecting both Python implementations to be comparable in speed to MATLAB, but they are not. Any ideas why the Python implementations are this slow, or any suggestions to speed them up?

Quarrel asked 27/6, 2023 at 18:41 Comment(12)
Note that np.matmul(A, B) can be rewritten as A @ B, which is perhaps easier to read. Also, A.transpose() is equivalent to A.T, which is also easier to read. - Maragaret
Edited the question; yes, that is much easier to read :-) - Quarrel
Your loop for i in range(M): G[:, i] = ... is a bit suspicious; there should be a way to write that directly in NumPy as G = ... without the for-loop. - Maragaret
np.exp(np.einsum('ijk,ljn->ilkn',X,X)) produces an array of shape (100,100,200,200) and is, I think, the slowest step. My timeit run killed the IPython session, so I don't think I'll explore further. - Alethaalethea
@Stef, with 'ijk,ljn->ilkn' and 'ijk,iljm,lmn->il' there isn't an obvious 'batch' dimension for the matmul operations. - Alethaalethea
MATLAB does a lot of JIT compiling, so your for i=1:M is not much of a time penalty. I don't know what pagemtimes does; I suspect it's an extension of its matrix multiplication to handle arrays with more than 2 dimensions. In any case, my testing with einsum indicates that this is a big task. The 4d exp term is 3 GB. - Alethaalethea
If MATLAB is using JIT, could Numba be used to speed this up and give an accurate comparison? - Rinderpest
Also, on my machine, the second Python implementation is much faster than the first. The first takes about 4.9 s while the second takes 2.0 s, i.e. more than a 2x speed-up. That said, MATLAB takes 0.55 s on my machine, so it is still almost 4x faster than the fastest Python implementation that you show. - Rinderpest
@Rinderpest The JIT does not matter much here. Indeed, MATLAB and NumPy should spend most of their time in BLAS routines. That being said, NumPy's einsum implementation is currently not very efficient in such a case when combined with OpenBLAS. I expect MATLAB to use an efficient implementation for such a basic operation (it is designed for matrix manipulations like this). - Swish
Just a suggestion: you should try opt_einsum. It implements some optimisations and can be significantly faster than default NumPy. Also, I think an older version of this package was used as the basis for NumPy's einsum optimiser. - Rustler
@Rustler I was excited to test opt_einsum, but it turns out the results are rather disappointing on this use-case: I got slightly slower results (1.5 s) on my machine with this module. - Swish
@JérômeRichard Interesting! It is significantly slower. But I'd keep it in mind for future use: I've seen huge improvements with opt_einsum, though possibly only when working with much larger arrays. - Rustler
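The comments about removing the for-loop and about the size of the exp term can be illustrated together. The sketch below computes G with broadcasting matmuls and no Python loop. gram_broadcast and intermediate_bytes are hypothetical helper names. The result should match the einsum one-liner, but the sketch materializes the same (M, M, N, N) array, so at the question's sizes it needs about 3.2 GB in float64. This is mainly why the loop version, which only holds an (M, N, N) slice at a time, is the lighter one.

```python
import numpy as np

def gram_broadcast(X, w):
    # E[i, l, j, m] = exp(sum_x X[i, x, j] * X[l, x, m]), shape (M, M, N, N)
    E = np.exp(np.swapaxes(X, 1, 2)[:, None] @ X[None])
    # v[i, l, j] = sum_m E[i, l, j, m] * w[l, m]  -> shape (M, M, N, 1)
    v = E @ w[None]
    # G[i, l] = sum_j w[i, j] * v[i, l, j]        -> shape (M, M, 1, 1)
    G = np.swapaxes(w, 1, 2)[:, None] @ v
    return G.reshape(X.shape[0], X.shape[0])

def intermediate_bytes(M, N, itemsize=8):
    # Size of the (M, M, N, N) exp term; itemsize 8 corresponds to float64
    return M * M * N * N * itemsize

print(intermediate_bytes(100, 200))   # 3200000000 bytes, i.e. about 3.2 GB
```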

First of all, np.einsum has an optimize parameter, which is False by default. This is mainly because the optimization can be more expensive than the computation in some cases, and it is generally better to pre-compute the optimal path in a separate call first. You can use optimize=True to significantly speed up np.einsum. It finds the optimal path in this case, though the internal implementation is not optimal. Note that pagemtimes in MATLAB is more specific than np.einsum, so there is no need for such a parameter (i.e. it is fast by default in this case).
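Here is a sketch of the "pre-compute the path in a separate call" approach. np.einsum_path returns a contraction path that can be passed back to np.einsum through optimize, so the path search is paid once per set of shapes. The reduced sizes and the seed below are assumptions for a quick run.

```python
import numpy as np

# Reduced sizes (an assumption) to keep the example light
rng = np.random.default_rng(0)
n, N, M = 4, 20, 10
X = 0.1 * rng.random((M, n, N))
w = 0.1 * rng.random((M, N, 1))
T = np.exp(np.einsum('ijk,ljn->ilkn', X, X, optimize=True))

# Search for a contraction order once; it depends only on the operand shapes
path, report = np.einsum_path('ijk,iljm,lmn->il', w, T, w, optimize='optimal')
print(path)    # a list whose first element is 'einsum_path'
print(report)  # human-readable summary of the chosen contractions and costs

# Reuse the stored path in later calls with the same shapes
G = np.einsum('ijk,iljm,lmn->il', w, T, w, optimize=path)
```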

Moreover, NumPy functions like np.exp create a new array by default. Computing arrays in-place is generally faster (and it also consumes less memory). This can be done with the out parameter.
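Here is a minimal illustration of the out parameter. Ufuncs such as np.exp write into the given array instead of allocating a new one, and they return that same array.

```python
import numpy as np

a = np.linspace(0.0, 1.0, 5)
expected = np.exp(a)       # allocates and returns a new array; a is unchanged

b = a.copy()
r = np.exp(b, out=b)       # writes the result into b's existing memory
print(r is b)              # True: the ufunc returns the out array itself
```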

np.exp is fairly expensive on most machines because it runs serially (like most NumPy functions), and it is often not very optimized internally either. Using a fast math library such as Intel's helps. I suspect MATLAB uses this kind of fast math library internally. Alternatively, one can use multiple threads to compute it faster. This is easy to do with the numexpr package.
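Here is an alternative to numexpr that uses only the standard library. It is a different technique from the one in this answer and only a sketch: the array is split into contiguous chunks, and np.exp is applied in place to each chunk from a thread pool. This relies on NumPy ufuncs releasing the GIL for numeric dtypes. Whether it beats numexpr or a serial np.exp depends on the machine. exp_inplace_threaded is a hypothetical name.

```python
import os
from concurrent.futures import ThreadPoolExecutor
import numpy as np

def exp_inplace_threaded(a, n_threads=None):
    """Apply exp to a C-contiguous float array in place, one chunk per thread."""
    if not a.flags.c_contiguous:
        # reshape(-1) would return a copy, so the in-place update would be lost
        raise ValueError("expected a C-contiguous array")
    n_threads = n_threads or os.cpu_count() or 1
    flat = a.reshape(-1)                       # a view, since a is contiguous
    chunks = np.array_split(flat, n_threads)   # views into flat
    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        # NumPy ufuncs release the GIL, so the chunks can run concurrently;
        # list() also re-raises any exception from a worker
        list(pool.map(lambda c: np.exp(c, out=c), chunks))
    return a
```

With the answer's variables, exp_inplace_threaded(tmp) could stand in for the ne.evaluate line.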

Here is the resulting more optimized Numpy code:

import numpy as np
import numexpr as ne

# Same initialization as in the question
n = 4
N = 200
M = 100
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)

tmp = np.einsum('ijk,ljn->ilkn',X,X, optimize=True)
ne.evaluate('exp(tmp)', out=tmp)
G = np.einsum('ijk,iljm,lmn->il', w, tmp, w, optimize=True)

Performance results

Here are the results on my machine (with an i5-9600KF CPU, 32 GiB of RAM, on Windows):

Naive einsums:        6.62 s
CPython loops:        3.37 s
This answer:          1.27 s   <----

max9111 solution:     0.47 s   (using an unmodified Numba v0.57)
max9111 solution:     0.54 s   (using a modified Numba v0.57)

The optimized code is about 5.2 times faster than the initial einsum code and 2.7 times faster than the fastest initial implementation!


Note about performances and possible optimizations

The first einsum takes a significant fraction of the runtime in the optimized implementation on my machine. This is mainly because einsum performs many small matrix multiplications internally, in a way that is not very efficient. Indeed, each matrix multiplication is done in parallel by a BLAS library (like OpenBLAS, which is the default one on most machines, including mine). The problem is that OpenBLAS is not efficient at computing small matrix products in parallel. In fact, parallelizing each small matrix product on its own is not efficient. A more efficient solution is to compute all the matrix multiplications in parallel, with each thread performing several serial matrix multiplications. This is most likely what MATLAB does, and why it can be a bit faster. This can be done with parallel Numba code (or with Cython) while disabling the parallel execution of BLAS routines. Note that this can have performance side effects on a larger script if it is done globally.
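The "each thread performs several serial matrix multiplications" idea can be sketched without Numba. A thread pool computes whole columns of G while BLAS is kept single-threaded. This is a swapped-in technique (Python threads instead of parallel Numba code), and no timings are claimed. It assumes that NumPy's matmul and exp release the GIL, and that the installed BLAS honors the OMP_NUM_THREADS / OPENBLAS_NUM_THREADS environment variables. These variables only take effect if they are set before NumPy is first imported. gram_column and gram_threaded are hypothetical names.

```python
import os
# Assumption: ask the BLAS backend for one thread per call. These variables
# only take effect if set before NumPy is first imported in the process.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")

from concurrent.futures import ThreadPoolExecutor
import numpy as np

def gram_column(X, w, i):
    # Column i: a batch of M small (N x N) products, run serially by BLAS
    E = X[i].T @ X                           # shape (M, N, N)
    np.exp(E, out=E)                         # exp in place, no extra array
    return (w[i].T @ (E @ w)).reshape(-1)    # shape (M,)

def gram_threaded(X, w, workers=None):
    M = X.shape[0]
    G = np.empty((M, M))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Each worker computes whole columns; NumPy releases the GIL inside
        # matmul and exp, so different columns can overlap in time
        cols = pool.map(lambda i: gram_column(X, w, i), range(M))
        for i, col in enumerate(cols):
            G[:, i] = col
    return G
```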

Another possible optimization is to do all the operations at once in Numba using multiple threads. This solution can certainly reduce the memory footprint even further and improve performance. However, writing an optimized implementation is far from easy, and the resulting code is significantly harder to maintain. This is what max9111's code does.

Inductile answered 27/6, 2023 at 20:8 Comment(4)
Adding my own times: Using OMP_NUM_THREADS=1, the timings are 2.16 s, 2.01 s, and 1.99 s for the einsum with optimize=True, OP's second method, and your numexpr method, respectively. Using OMP_NUM_THREADS=4, the timings become 2.07 s, 2.42 s, and 0.93 s, respectively. That's compared to 0.55 s for MATLAB. - Rinderpest
I have added a Numba implementation (below 150 ms) as an example. Whether someone wants to go this way is a different story. - Curch
I added benchmark results accordingly. Thank you! - Swish
Can you confirm that Numba 0.56 shows a different behaviour? Beginning with 0.57, my timings with the default are still a bit slower than 0.56 with opt=2, but 0.57 with opt=2 is even slower, like your results. A second thing: maybe you don't have SVML running, which could explain the generally slower timings you get on a quite similar CPU (6C/12T) with AVX2. - Curch

A Numba Implementation

As @Jérôme Richard already mentioned, you can also write a pure Numba implementation. I partially used this code generation function on both einsums, with some manual code editing.

Please be aware that from Numba version 0.53 to 0.56 there is a bug/feature (see the issue linked in the code below) that usually has a high performance impact. I would recommend changing this behaviour in versions 0.53 until 0.57 if the small benefit in compilation time doesn't matter. Beginning with 0.57, this option seems to be slower than the default.

Pros/Cons

  • Much faster than the accepted solution (and likely the MATLAB solution)
  • Very small temporary arrays, if memory usage is a problem
  • Scales well with the number of cores you use (there may be problems with newer big/little Intel CPUs, but it still runs in around 600 ms on a new notebook)
  • The code is hard to understand quickly; comments are necessary to understand what's happening

Implementation

import numba as nb
import numpy as np

# Set cache=False to test the behaviour described in
# https://github.com/numba/numba/issues/8172#issuecomment-1160474583
# (and, of course, restart the interpreter)
@nb.njit(fastmath=True, parallel=True, cache=False)
def einsum(X,w):
    #For loop unrolling
    assert X.shape[1] ==4
    assert w.shape[2] ==1

    #For safety
    assert X.shape[0] == w.shape[0]
    assert X.shape[2] == w.shape[1]

    i_s = X.shape[0]
    x_s = X.shape[1]
    j_s = X.shape[2]
    l_s = X.shape[0]
    m_s = X.shape[2]
    k_s = w.shape[2]
    n_s = w.shape[2]

    res = np.empty((i_s,l_s))

    for i in nb.prange(i_s):
        for l in range(l_s):
            #TMP_0 is thread local; it will be optimized out of the loop by Numba in parallel mode
            #np.einsum('xm,xj->jm', X,X) -> TMP_0
            TMP_0 = np.zeros((j_s,m_s))
            for x in range(x_s):
                for j in range(j_s):
                    for m in range(m_s):
                        TMP_0[j,m]+=X[l,x,m] *X[i,x,j]

            #EXP in-place
            for j in range(j_s):
                for m in range(m_s):
                    TMP_0[j,m] = np.exp(TMP_0[j,m])

            #TMP_1 is thread local; it will be optimized out of the loop by Numba in parallel mode
            #np.einsum('jm,jk->m', TMP_0,w[i]) -> TMP_1
            TMP_1 = np.zeros((m_s))
            for j in range(j_s):
                for m in range(m_s):
                    for k in range(k_s):
                        TMP_1[m]+=TMP_0[j,m] *w[i,j,k]

            #np.einsum('m,mn->', TMP_1,w[l]) -> res
            acc=0.0
            for m in range(m_s):
                for n in range(n_s):
                    acc+=TMP_1[m] *w[l,m,n]
            res[i,l]=acc

    return res
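The same three steps per (i, l) pair can be mirrored with NumPy. This is useful for checking what each loop nest in the kernel computes, or for validating its output on small inputs without parallel Numba. It is a slow reference sketch (M*M Python iterations), not a replacement, and einsum_reference is a hypothetical name. The kernel asserts X.shape[1] == 4, so test inputs should keep n = 4.

```python
import numpy as np

def einsum_reference(X, w):
    # Slow reference mirroring the Numba kernel's steps for each (i, l) pair,
    # assuming the same shapes: X is (M, n, N) and w is (M, N, 1)
    M = X.shape[0]
    res = np.empty((M, M))
    for i in range(M):
        for l in range(M):
            tmp0 = X[i].T @ X[l]           # TMP_0[j, m] = sum_x X[i,x,j] * X[l,x,m]
            np.exp(tmp0, out=tmp0)         # exp in place
            tmp1 = tmp0.T @ w[i, :, 0]     # TMP_1[m] = sum_j TMP_0[j, m] * w[i, j]
            res[i, l] = tmp1 @ w[l, :, 0]  # sum_m TMP_1[m] * w[l, m]
    return res
```

For example, np.allclose(einsum(X, w), einsum_reference(X, w)) on reduced sizes such as M = 10, N = 20.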

Timings on Ryzen 5 5600G (6C/12T)

Original implementation (unique characters):

%timeit G3 = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ixj,lxm->iljm',X,X)), w)
4.45 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Jérôme Richard's implementation:

1.43 s ± 102 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

My implementation on unmodified Numba above v0.53. Numba has to be modified if performance is the main goal, which is usually the case if you use Numba :-(

665 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

My implementation below v0.53, or modified newer Numba:

142 ms ± 3.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Updated timings

The previous timings were with Numba 0.55. Starting with 0.57, Numba seems to behave differently: the runtime is now faster with the default, but still a bit slower than version 0.56 with opt=2:

%timeit G2 = einsum(X,w)

#0.56, windows installed via pip (default)
#706 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#0.56, windows installed via pip (opt=2)
#153 ms ± 2.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

#0.57, windows installed via pip (default)
#173 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#0.57, windows installed via pip (opt=2)
#247 ms ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

For comparable timings, check whether SVML has been used

This should be the default on Anaconda Python, but may not be the case on a standard Python installation.

def check_SVML(func):
    if 'intel_svmlcc' in func.inspect_llvm(func.signatures[0]):
        print("found")
    else:
        print("not found")

check_SVML(einsum)
#found
Curch answered 29/6, 2023 at 18:24 Comment(6)
Good solution. It looks like Numba 0.57 does not have the issue anymore, or at least not on my Windows machine for this use-case. The thing is, the last message of the issue states that it should be solved later in v0.58, so I am confused... opt=2 and opt=3 give slightly slower results than opt=0 on my machine (repeated twice). I can't wait for v0.58 :D ! - Swish
Thank you for the detailed response, this looks very promising. I am much more familiar with MATLAB, where explicit for loops, especially nested ones, are very slow, so I try to avoid them as much as possible. Clearly not the case here! Also, what is the significance of "unique characters" in your answer? I thought the two calls to einsum were completely separate, so it didn't matter if I reused characters. Is that not the case? - Quarrel
@JérômeRichard If it works, you should see results of approx. 150 ms on your machine. Don't forget to restart the kernel and delete/invalidate the cache. I had problems reproducing a consistently fast result on newer Intel CPUs. - Curch
@Quarrel I was just thinking about expanding/rewriting the code generation to something like einsum_gen('ijk,exp(ixj,lxm),lmn->iljm',w,X,X,w) or einsum_gen('ijk,exp(ixj,lxm),lmn->iljm',w,(,X,X),w) to directly get a working implementation without manual edits. Spare time is too short... ;) - Curch
@Curch Your implementation makes my CPU fan sound like it is about to take off. None of the other implementations do. I am totally using that as a sniff test for code optimization from now on :-) - Quarrel
@Curch I restarted the CPython process from scratch between tests and used exactly your code with cache=False. The different timings show that the code is not cached anyway and that the optimization level impacts performance. Maybe the use of a newer LLVM version has an impact on this issue... - Swish

Alternatively, you can just use PyTorch, with minimal changes to your code. Running your exact code on my computer, I get 0.6 s for MATLAB (including making G) and 4.4 s for NumPy. Running the following with torch, I get 0.3 s, and this is all on my CPU.

import time
import torch

n = 4
N = 200
M = 100
# torch.rand returns float32 by default, whereas np.random.rand returns float64
X = 0.1*torch.rand(M, n, N)
w = 0.1*torch.rand(M, N, 1)

start = time.perf_counter()
G = torch.einsum('ijk,iljm,lmn->il', w, torch.exp(torch.einsum('ijk,ljn->ilkn',X,X)), w)
end = time.perf_counter()
print('Time: ', end-start)
Melosa answered 28/8 at 14:20 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.