A loopless 3D matrix multiplication in python

I am looking to do the following operation in python (numpy).

Matrix A is M x N x R
Matrix B is N x 1 x R

Matrix multiply AB = C, where C is an M x 1 x R matrix. Essentially, each M x N layer of A (R of them) is matrix-multiplied independently by the corresponding N x 1 vector in B. I am sure this is a one-liner. I have been trying to use tensordot(), but it seems to be giving me answers that I don't expect.
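To pin down the intended operation, here is a plain triple-loop reference sketch: for every layer r, C[:, :, r] = A[:, :, r] @ B[:, :, r]. The function name `layered_matvec_loop` is hypothetical and only serves as a slow baseline to check the vectorized answers against.

```python
import numpy as np

def layered_matvec_loop(A, B):
    """Hypothetical reference: C[m, 0, r] = sum_n A[m, n, r] * B[n, 0, r]."""
    M, N, R = A.shape
    assert B.shape == (N, 1, R), "B must be N x 1 x R"
    C = np.zeros((M, 1, R), dtype=np.result_type(A, B))
    for r in range(R):          # each of the R layers is independent
        for m in range(M):      # rows of layer r of A
            for n in range(N):  # the shared (contracted) dimension
                C[m, 0, r] += A[m, n, r] * B[n, 0, r]
    return C

A = np.arange(24).reshape(2, 3, 4)   # M x N x R
B = np.arange(12).reshape(3, 1, 4)   # N x 1 x R
C = layered_matvec_loop(A, B)
print(C.shape)  # (2, 1, 4)
```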

I have been programming in Igor Pro for nearly 10 years, and I am now trying to convert pages of it over to python.

Okoka answered 17/3, 2011 at 20:28

Sorry for the necromancy, but the earlier answer can be substantially improved upon by using the invaluable np.einsum.

import numpy as np

D,M,N,R = 1,2,3,4
A = np.random.rand(M,N,R)
B = np.random.rand(N,D,R)

print(np.einsum('mnr,ndr->mdr', A, B).shape)  # (2, 1, 4)

Note that it has several advantages. First of all, it's fast: np.einsum is well optimized in general, and moreover it is smart enough to avoid creating an M x N x R temporary array, performing the contraction over N directly.
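The memory point can be made concrete with a small sketch (sizes chosen arbitrarily): the naive broadcast-multiply-then-sum route materializes the full M x N x R product before summing, while einsum hands back only the M x 1 x R result. This only compares the sizes of the arrays involved; it does not inspect einsum's internals.

```python
import numpy as np

M, N, R = 200, 300, 50
A = np.random.rand(M, N, R)
B = np.random.rand(N, 1, R)

# Naive route: broadcast to an (M, N, 1, R) product, then sum over N.
temp = A[:, :, None, :] * B          # full M*N*R temporary
C_naive = temp.sum(axis=1)           # (M, 1, R)

# einsum route: only the (M, 1, R) result is returned.
C_einsum = np.einsum('mnr,ndr->mdr', A, B)

print(temp.nbytes, C_einsum.nbytes)   # 24000000 80000
print(np.allclose(C_naive, C_einsum)) # True
```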

But perhaps more importantly, it's very readable. There is no doubt that this code is correct, and you could make it a lot more complicated without any trouble.

Note that the dummy 'D' axis can simply be dropped from B and the einsum statement if you wish.
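A minimal sketch of the version without the dummy axis, assuming B is stored as a plain N x R array (the name `B2` is just for illustration). The subscripts lose the `d`, and the result is M x R instead of M x 1 x R.

```python
import numpy as np

M, N, R = 2, 3, 4
A = np.random.rand(M, N, R)
B2 = np.random.rand(N, R)            # B without the singleton D axis

C2 = np.einsum('mnr,nr->mr', A, B2)  # result is M x R
print(C2.shape)  # (2, 4)

# Same numbers as keeping the dummy axis and dropping it afterwards
C3 = np.einsum('mnr,ndr->mdr', A, B2[:, None, :])
print(np.allclose(C2, C3[:, 0, :]))  # True
```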

Desirous answered 28/1, 2014 at 9:1
Comments:
I saw that np.dot() can also do some multi-dimensional operations, but it works under some strange rules. Do you have any knowledge of it? (Rebekah)
My knowledge can be summarized best as 'use einsum instead'. It may be a little more verbose, but 'explicit is better than implicit' never applied more, in my opinion. (Desirous)
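The "strange rules" of np.dot on N-D arrays are documented as dot(a, b)[i, j, k, m] = sum(a[i, j, :] * b[k, :, m]): it contracts the last axis of a with the second-to-last axis of b and keeps every other axis, so every layer of A is paired with every layer of B. np.matmul instead treats leading axes as a batch and pairs layers one-to-one. The sketch below assumes the layers have already been moved to the front (R x M x N and R x N x 1).

```python
import numpy as np

M, N, R = 2, 3, 4
At = np.random.rand(R, M, N)   # R stacked M x N matrices
Bt = np.random.rand(R, N, 1)   # R stacked N x 1 vectors

# np.dot: all R x R layer pairings are computed
D = np.dot(At, Bt)
print(D.shape)   # (4, 2, 4, 1)

# np.matmul: layer r of At is multiplied only with layer r of Bt
Mm = np.matmul(At, Bt)
print(Mm.shape)  # (4, 2, 1)

# The matmul result is the diagonal of the dot result over the two R axes;
# np.diagonal appends that diagonal as the last axis, giving (M, 1, R).
diag = np.diagonal(D, axis1=0, axis2=2)
print(np.allclose(diag, Mm.transpose(1, 2, 0)))  # True
```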

numpy.tensordot() is the right way to do it:

import numpy

a = numpy.arange(24).reshape(2, 3, 4)
b = numpy.arange(12).reshape(3, 1, 4)
c = numpy.tensordot(a, b, axes=[1, 0]).diagonal(axis1=1, axis2=3)
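To see what `axes` and `.diagonal` are doing, it helps to print the intermediate shape. `axes=[1, 0]` contracts axis 1 of a (N) with axis 0 of b (N) and keeps all remaining axes, so the result pairs every layer of a with every layer of b; `.diagonal` then keeps only the matching layer pairs. This also shows the R-fold waste: the intermediate has R times as many entries as the final result.

```python
import numpy

a = numpy.arange(24).reshape(2, 3, 4)   # M x N x R
b = numpy.arange(12).reshape(3, 1, 4)   # N x 1 x R

# Remaining axes are concatenated: (M, R) from a, then (1, R) from b.
full = numpy.tensordot(a, b, axes=[1, 0])
print(full.shape)  # (2, 4, 1, 4)

# full[m, r, 0, s] pairs layer r of a with layer s of b; only r == s is wanted.
# diagonal() drops axes 1 and 3 and appends the diagonal as the last axis.
c = full.diagonal(axis1=1, axis2=3)
print(c.shape)     # (2, 1, 4)
print(full.size, c.size)  # 32 8
```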

Edit: The first version of this was faulty, and this version computes more than it should and throws away most of it. Maybe a Python loop over the last axis is the better way to do it.
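A minimal sketch of the loop-over-the-last-axis idea: one ordinary 2-D matrix product per layer, stacked back along the last axis. It performs only the needed work, at the cost of a Python-level loop over R.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)   # M x N x R
b = np.arange(12).reshape(3, 1, 4)   # N x 1 x R

# (M, N) @ (N, 1) -> (M, 1) for each layer r, then stack along a new last axis
c = np.stack([a[:, :, r] @ b[:, :, r] for r in range(a.shape[2])], axis=-1)
print(c.shape)  # (2, 1, 4)
```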

Another Edit: I've come to the conclusion that numpy.tensordot() is not the best solution here.

c = (a[:,:,None] * b).sum(axis=1)

will be more efficient (though even harder to grasp).
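The broadcasting one-liner is easier to grasp when the shapes are printed step by step: `a[:, :, None]` inserts a length-1 axis, b is broadcast against it, and the sum runs over the shared N axis.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)   # M x N x R
b = np.arange(12).reshape(3, 1, 4)   # N x 1 x R

a4 = a[:, :, None]        # insert a length-1 axis: (M, N, 1, R)
print(a4.shape)           # (2, 3, 1, 4)

# b (N, 1, R) is broadcast as (1, N, 1, R), so it lines up with a4 along N and R
prod = a4 * b
print(prod.shape)         # (2, 3, 1, 4)

c = prod.sum(axis=1)      # sum over the shared N axis
print(c.shape)            # (2, 1, 4)
```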

Viper answered 17/3, 2011 at 20:39
Comments:
Thank you for getting back to me so quickly. At the very least, it will get me started. The majority of the code that I am going to try to write is matrix-operation oriented, so I should really try to understand what is going on here. That being said, there are two parts of the code that confuse me. The first is the "axes" argument to tensordot. I am confused about what it actually does: I would expect a x b to just give you c (as described above) without declaring anything special. Maybe once I understand that, I will see why it is necessary to use .diagonal. (Okoka)
That is clever... I don't know how long it would have taken me to come up with something like this (it looks like you create a new axis for the multiplication-sum, then essentially recombine later). I really appreciate your time, thanks much! (Okoka)
It's easier to think of this if the first dimension is the list of matrices (R) and the second one is the common matrix dimension (N), i.e. the shape of a is (4, 3, 2) and that of b is (4, 3, 1). The multiplication then becomes (a * b).sum(axis=1). (a * b) multiplies the matching row and column elements of each matrix, and summing over the common dimension gives the final matrix, just as you multiply matrices by hand. (Tragopan)
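A sketch of the layout described in the last comment, under the interpretation that a holds R stacked N x M blocks (the transpose of the question's M x N layers) and b holds R stacked N x 1 vectors. The result is R x M, and it matches a per-layer `a[r].T @ b[r]`.

```python
import numpy as np

R, N, M = 4, 3, 2
a = np.random.rand(R, N, M)   # R stacked N x M blocks
b = np.random.rand(R, N, 1)   # R stacked N x 1 vectors

c = (a * b).sum(axis=1)       # broadcast over M, sum over N -> (R, M)
print(c.shape)                # (4, 2)

# Same numbers as the per-layer product a[r].T @ b[r]
ref = np.stack([a[r].T @ b[r] for r in range(R)])[:, :, 0]
print(np.allclose(c, ref))    # True
```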

Another way to do it (easier for those not familiar with Einstein notation, like me) is np.matmul(). The important thing is to have the matching dimensions ((M, N) x (N, 1)) in the last two axes; np.transpose() can move them there. Example:

import numpy as np

M, N, R = 4, 3, 10
A = np.ones((M, N, R))
B = np.ones((N, 1, R))

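# move the R axis to the front so the matching (M, N) and (N, 1) dimensions are the last two
C = np.matmul(np.transpose(A, (2, 0, 1)), np.transpose(B, (2, 0, 1)))
# C is now (R, M, 1); move R back to the end to get (M, 1, R)
C = np.transpose(C, (1, 2, 0))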

print(A.shape)
# (4, 3, 10)
print(B.shape)
# (3, 1, 10)
print(C.shape)
# (4, 1, 10)
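The same idea can be written a little more compactly with np.moveaxis and the `@` operator (which calls np.matmul). This is a stylistic alternative, not a different algorithm; the einsum comparison at the end is only a sanity check.

```python
import numpy as np

M, N, R = 4, 3, 10
A = np.random.rand(M, N, R)
B = np.random.rand(N, 1, R)

# Move R to the front, multiply the (R, M, N) and (R, N, 1) stacks,
# then move R back to the end.
C = np.moveaxis(np.moveaxis(A, -1, 0) @ np.moveaxis(B, -1, 0), 0, -1)
print(C.shape)  # (4, 1, 10)

print(np.allclose(C, np.einsum('mnr,ndr->mdr', A, B)))  # True
```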
Knapweed answered 2/5, 2020 at 12:20
