A Numba Implementation
As @Jérôme Richard already mentioned, you can also write a pure Numba implementation. I partially used this code-generation function on both einsums, with some manual code editing.
Please be aware that from Numba version 0.53 to 0.56 there is a bug/feature (see the issue linked in the code below) which usually has a high performance impact. I would recommend changing this option on versions 0.53 through 0.56, if the small benefit in compilation time doesn't matter. Beginning with 0.57, this option seems to be slower than the default.
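The timings below label the two settings (default) and (opt=2). Assuming the option in question is Numba's optimization level, a minimal sketch of changing it could look like this (NUMBA_OPT is read when Numba is imported, so it has to be set first):

import os

# Assumption: the option discussed above is the optimization level.
# Numba reads NUMBA_OPT at import time (the default is 3), so set it
# before the first import of Numba, or restart the interpreter.
os.environ["NUMBA_OPT"] = "2"

import numba as nb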
Pros/Cons
- Much faster than the accepted solution (and likely the Matlab solution)
- Very small temporary arrays, if memory usage is a problem
- Scales well with the number of cores used (there may be problems with newer big/little Intel CPUs, but still around 600 ms on a new notebook)
- The code is hard to understand at a glance; comments are necessary to follow what's happening
Implementation
import numpy as np
import numba as nb

#set cache to False to test the behaviour of
#https://github.com/numba/numba/issues/8172#issuecomment-1160474583
#and of course restart the interpreter
@nb.njit(fastmath=True,parallel=True,cache=False)
def einsum(X,w):
    #For loop unrolling
    assert X.shape[1] == 4
    assert w.shape[2] == 1
    #For safety
    assert X.shape[0] == w.shape[0]
    assert X.shape[2] == w.shape[1]

    i_s = X.shape[0]
    x_s = X.shape[1]
    j_s = X.shape[2]
    l_s = X.shape[0]
    m_s = X.shape[2]
    k_s = w.shape[2]
    n_s = w.shape[2]

    res = np.empty((i_s,l_s))
    for i in nb.prange(i_s):
        for l in range(l_s):
            #TMP_0 is thread-local; the allocation is optimized out of the loop by Numba in parallel mode
            #np.einsum('xm,xj->jm', X,X) -> TMP_0
            TMP_0 = np.zeros((j_s,m_s))
            for x in range(x_s):
                for j in range(j_s):
                    for m in range(m_s):
                        TMP_0[j,m] += X[l,x,m] * X[i,x,j]
            #exp in-place
            for j in range(j_s):
                for m in range(m_s):
                    TMP_0[j,m] = np.exp(TMP_0[j,m])
            #TMP_1 is thread-local; the allocation is optimized out of the loop by Numba in parallel mode
            #np.einsum('jm,jk->m', TMP_0,w[i]) -> TMP_1
            TMP_1 = np.zeros(m_s)
            for j in range(j_s):
                for m in range(m_s):
                    for k in range(k_s):
                        TMP_1[m] += TMP_0[j,m] * w[i,j,k]
            #np.einsum('m,mn->', TMP_1,w[l]) -> res
            acc = 0.
            for m in range(m_s):
                for n in range(n_s):
                    acc += TMP_1[m] * w[l,m,n]
            res[i,l] = acc
    return res
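A quick sanity check of the kernel against the original NumPy expression, on small made-up shapes that satisfy the assertions (the question's actual data appears to be around X = (100, 4, 200) and w = (100, 200, 1), judging from the comments below):

import numpy as np

# Small made-up test shapes obeying the assertions in einsum()
rng = np.random.default_rng(0)
X = rng.random((10, 4, 20))
w = rng.random((10, 20, 1))

res = einsum(X, w)
ref = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ixj,lxm->iljm', X, X)), w)
assert np.allclose(res, ref)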
Timings on Ryzen 5 5600G (6C/12T)
Original implementation (with unique subscript characters):
%timeit G3 = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ixj,lxm->iljm',X,X)), w)
4.45 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Jérôme Richard's implementation:
1.43 s ± 102 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
My implementation on unmodified Numba above v0.53, which has to be modified if performance is the main goal, as it usually is when you use Numba :-(
665 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
My implementation on Numba below v0.53, or on a modified newer Numba:
142 ms ± 3.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Updated timings
The previous timings were with Numba 0.55. Starting with 0.57, Numba seems to show different behaviour: the runtime is now faster with the default settings, but still a bit slower than version 0.56 with opt=2:
%timeit G2 = einsum(X,w)
#0.56, windows installed via pip (default)
#706 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#0.56, windows installed via pip (opt=2)
#153 ms ± 2.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#0.57, windows installed via pip (default)
#173 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#0.57, windows installed via pip (opt=2)
#247 ms ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
For comparable timings, check whether SVML has been used
This should be the default on Anaconda Python, but may not be the case on a standard Python installation.
def check_SVML(func):
    if 'intel_svmlcc' in func.inspect_llvm(func.signatures[0]):
        print("found")
    else:
        print("not found")

check_SVML(einsum)
#found
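If SVML is not found, it can usually be enabled by installing Intel's runtime library and restarting the interpreter; on pip-based installations this should be the intel-cmplr-lib-rt wheel (on older Numba versions icc_rt, e.g. via conda install -c numba icc_rt). Whether Numba picked it up can also be checked with numba -s (look for the SVML section).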
Comments
- `np.matmul(A, B)` can be rewritten `A @ B`, which is maybe easier to read. Also `A.transpose()` is equivalent to `A.T`, which is also easier to read. – Maragaret
- `for i in range(M): G[:, i] = ...` is a bit suspicious; there should be a way to write that directly in numpy as `G = ...` without the for-loop. – Maragaret
- `np.exp(np.einsum('ijk,ljn->ilkn',X,X))` produces a (100,100,200,200) shape, and is I think the slowest step. My timeit run killed the ipython session, so I don't think I'll explore more. – Alethaalethea
- `'ijk,ljn->ilkn'` and `'ijk,iljm,lmn->il'`: there isn't an obvious 'batch' dimension for the `matmul` operations. – Alethaalethea
- `for i=1:M` is not much of a time penalty. I don't know what `pagemtimes` does; I suspect it's an extension of its matrix multiplication to handle more than 2-dimensional matrices. In any case, my testing with `einsum` indicates that this is a big task. The 4d `exp` term is 3 Gb. – Alethaalethea
- […] `opt_einsum`, but it turns out results are rather disappointing on this use-case: I got slightly slower results (1.5 s) on my machine with this module. – Swish
- […] `opt_einsum`, though possibly only when working with much larger arrays. – Rustler
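The commenters don't show their opt_einsum code; a minimal sketch of what such an attempt might look like (an assumption on my part; requires pip install opt_einsum, and the shapes are scaled down from the question's apparent ones to keep the 4d intermediate small):

import numpy as np
import opt_einsum as oe

# Shapes scaled down from the apparent (100, 4, 200) / (100, 200, 1)
X = np.random.rand(20, 4, 40)
w = np.random.rand(20, 40, 1)

# Same two contractions as the reference expression, with opt_einsum
# choosing the contraction order
inner = oe.contract('ixj,lxm->iljm', X, X)
res = oe.contract('ijk,iljm,lmn->il', w, np.exp(inner), w)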