numpy einsum to get axes permutation

What I understood from the documentation of `np.einsum` is that a permutation string gives a permutation of the axes of an array. This seems to be confirmed by the following experiment:

>>> M = np.arange(24).reshape(2,3,4)
>>> M.shape
(2, 3, 4)
>>> np.einsum('ijk', M).shape
(2, 3, 4)
>>> np.einsum('ikj', M).shape
(2, 4, 3)
>>> np.einsum('jik', M).shape
(3, 2, 4)

But this I cannot understand:

>>> np.einsum('kij', M).shape
(3, 4, 2)

I would expect (4, 2, 3) instead... What's wrong with my understanding?

Divorcee answered 30/1, 2015 at 9:24

When the output signature is not specified (i.e. there's no '->' in the subscripts string), einsum creates one by taking the labels that appear exactly once in the input and arranging them in alphabetical order.

This means that

np.einsum('kij', M)

is actually equivalent to

np.einsum('kij->ijk', M)

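So writing 'kij' labels the axes of the input array, not the output array: axis 0 of M (length 2) is named k, axis 1 (length 3) is named i, and axis 2 (length 4) is named j. The output is then ordered as ijk, which gives the shape (3, 4, 2) that you observed.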
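A quick way to check this equivalence is to compare the implicit form against the explicit one and against a plain transpose. This is only a minimal sketch; the variable names `implicit`, `explicit` and `transposed` are mine, not anything from NumPy:

```python
import numpy as np

M = np.arange(24).reshape(2, 3, 4)

# Implicit mode: 'kij' names the input axes; the output is ordered alphabetically (ijk).
implicit = np.einsum('kij', M)

# Explicit mode spelling out the same operation.
explicit = np.einsum('kij->ijk', M)

# The same permutation with transpose: output axes are input axes (1, 2, 0).
transposed = M.transpose(1, 2, 0)

print(implicit.shape)                        # (3, 4, 2)
print(np.array_equal(implicit, explicit))    # True
print(np.array_equal(implicit, transposed))  # True
```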

This point isn't made explicit in the documentation, but can be seen commented in the C source code for einsum:

/*
 * If there is no output signature, create one using each label
 * that appeared once, in alphabetical order
 */

To ensure the axes of M are permuted in the intended order, give einsum the labels for both the input and the output array:

>>> np.einsum('ijk->kij', M).shape
(4, 2, 3)
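
If all you need is the axis permutation, the explicit einsum form should agree with an ordinary transpose. The sketch below assumes the same M as above; `via_einsum` and `via_transpose` are illustrative names:

```python
import numpy as np

M = np.arange(24).reshape(2, 3, 4)

# Explicit mode: input axes are labelled i, j, k and the output order is k, i, j.
via_einsum = np.einsum('ijk->kij', M)

# Equivalent transpose: output axis 0 is input axis 2, then input axes 0 and 1.
via_transpose = M.transpose(2, 0, 1)

print(via_einsum.shape)                          # (4, 2, 3)
print(np.array_equal(via_einsum, via_transpose)) # True
```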
Jarlen answered 30/1, 2015 at 10:7
