Batch matrix multiplication in PyTorch - confused about how the output's dimensions are handled

I have two arrays (tensors), A and B.

Array A contains a batch of RGB images, with shape:

[batch, Width, Height, 3]

Array B contains the coefficients needed for a "transformation-like" operation on the images, with shape:

[batch, 4, 4, 3]

To put it simply, the operation for a single image is a multiplication that outputs an environment map (normalMap * Coefficients).

The output should have shape:

[batch, Width, Height, 3]

I tried torch.bmm, but it failed. Is this possible somehow?
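One possible reading of this operation, sketched under assumptions the question does not confirm: A holds per-pixel normals (x, y, z), and each slice B[b, :, :, c] is a 4x4 matrix M_c evaluated as the quadratic form n^T M_c n on the homogeneous normal n = (x, y, z, 1). That is the form commonly used for spherical-harmonic irradiance environment maps with one 4x4 matrix per colour channel. Under that assumption no bmm or reshaping is needed: append a 1 to the last axis and let torch.einsum contract over it. The helper name shade_with_coefficients is hypothetical.

```python
import torch


def shade_with_coefficients(normals: torch.Tensor, coeffs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (assumed semantics): per-pixel quadratic form n^T M_c n.

    normals: [batch, W, H, 3]  assumed to hold x, y, z normal components
    coeffs:  [batch, 4, 4, 3]  one 4x4 matrix M_c per colour channel c
    returns: [batch, W, H, 3]
    """
    # Homogeneous normal n = (x, y, z, 1), shape [batch, W, H, 4]
    ones = torch.ones_like(normals[..., :1])
    n = torch.cat([normals, ones], dim=-1)
    # out[b, w, h, c] = sum_{i, j} n[b, w, h, i] * M[b, i, j, c] * n[b, w, h, j]
    return torch.einsum("bwhi,bijc,bwhj->bwhc", n, coeffs, n)


batch, W, H = 2, 64, 48
A = torch.randn(batch, W, H, 3)
B = torch.randn(batch, 4, 4, 3)
out = shade_with_coefficients(A, B)
print(out.shape)  # torch.Size([2, 64, 48, 3])
```

einsum keeps the batch and pixel axes as free indices, so the output shape matches the requested [batch, Width, Height, 3] without any permuting.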

Prediction asked 11/6, 2019 at 12:37
I do not understand the dimensions of the matrix multiplication. Does the multiplication need to work on the channel axis? Maybe check out torch.nn.functional.conv2d? - Bleier
@Danos Yes: I want each image in the batch from tensor A to be multiplied by the corresponding 4x4 matrix from tensor B, along the channel axis. - Prediction
According to the documentation of torch.bmm, the matrix dimensions must agree (i.e. Height must equal 4 if you compute A*B). If that is not the case, it makes sense that the operation failed. If you want element-wise multiplication, check out torch.mul; in that case I think you need to make sure that B is broadcastable. - Bleier
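The shape rules mentioned in this comment can be checked directly. torch.bmm expects two 3-D tensors of shapes (b, n, m) and (b, m, p) and returns (b, n, p), so a 4-D tensor like A is rejected. torch.mul (the `*` operator) broadcasts instead, which only works when each dimension of B either matches A's or is 1. The sizes below are illustrative, not taken from the question.

```python
import torch

batch, W, H = 2, 8, 6
A = torch.randn(batch, W, H, 3)   # [batch, Width, Height, 3]
B = torch.randn(batch, 4, 4, 3)   # [batch, 4, 4, 3]

# bmm: (b, n, m) @ (b, m, p) -> (b, n, p); the inner sizes m must agree.
x = torch.randn(batch, 5, 4)
y = torch.randn(batch, 4, 7)
print(torch.bmm(x, y).shape)      # torch.Size([2, 5, 7])

# bmm rejects 4-D inputs such as A and B.
try:
    torch.bmm(A, B)
except RuntimeError as err:
    print("bmm failed:", err)

# mul broadcasts: a per-channel factor of shape [batch, 1, 1, 3] works ...
scale = torch.randn(batch, 1, 1, 3)
print(torch.mul(A, scale).shape)  # torch.Size([2, 8, 6, 3])

# ... but [batch, 4, 4, 3] does not broadcast against [batch, 8, 6, 3].
try:
    torch.mul(A, B)
except RuntimeError as err:
    print("mul failed:", err)
```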

I think you need to take into account that PyTorch (its convolution layers, for example) works with images in

BxCxHxW : number of mini-batches, channels, height, width

format. You should also use matmul, since bmm only works with 3-D tensors (ndim/dim/rank = 3).
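To move the question's channels-last tensor A ([batch, Width, Height, 3]) into BxCxHxW order, a permute is enough. This sketch assumes Width maps to W and Height to H. permute only changes strides, so .contiguous() is added in case a later operation needs contiguous memory.

```python
import torch

batch, W, H = 2, 64, 48
A = torch.randn(batch, W, H, 3)        # [batch, Width, Height, C] (channels-last)

# Reorder axes to [batch, C, Height, Width] (BxCxHxW).
A_bchw = A.permute(0, 3, 2, 1).contiguous()
print(A_bchw.shape)                    # torch.Size([2, 3, 48, 64])

# Swapping axes 1 and 3 is its own inverse, so the same permute restores
# the original [batch, Width, Height, C] layout.
A_back = A_bchw.permute(0, 3, 2, 1)
print(torch.equal(A_back, A))          # True
```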

I know you can find this online, but just in case:

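import torch

batch1 = torch.randn(10, 3, 20, 10)  # leading dims (10, 3) are treated as batch dims
batch2 = torch.randn(10, 3, 10, 30)
res = torch.matmul(batch1, batch2)  # multiplies the last two dims: (20x10) @ (10x30)
print(res.size())  # torch.Size([10, 3, 20, 30])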
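The difference between matmul and bmm in this example can be checked by merging the two leading batch dimensions into one, calling bmm on the resulting 3-D tensors, and reshaping back. Both routes should give the same values, which illustrates that matmul treats every dimension before the last two as a batch dimension.

```python
import torch

batch1 = torch.randn(10, 3, 20, 10)
batch2 = torch.randn(10, 3, 10, 30)

# matmul: batched over the leading (10, 3) dims.
res = torch.matmul(batch1, batch2)                 # [10, 3, 20, 30]

# bmm only accepts 3-D tensors, so flatten the batch dims first.
res_bmm = torch.bmm(batch1.reshape(-1, 20, 10),    # [30, 20, 10]
                    batch2.reshape(-1, 10, 30))    # [30, 10, 30]
res_bmm = res_bmm.reshape(10, 3, 20, 30)

print(torch.allclose(res, res_bmm, atol=1e-5))     # True
```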
Boyhood answered 12/6, 2019 at 20:46

© 2022 - 2024 — McMap. All rights reserved.