Multi dimensional inputs in pytorch Linear method?

3

16

When building a simple perceptron neural network, we usually pass a 2D input matrix of shape (batch_size, features) to a 2D weight matrix, similar to this simple neural network in numpy. I always assumed that a Perceptron/Dense/Linear layer of a neural network only accepts a 2D input and produces a 2D output. But recently I came across this PyTorch model in which a Linear layer accepts a 3D input tensor and outputs another 3D tensor (o1 = self.a1(x)).

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a1 = nn.Linear(4,4)
        self.a2 = nn.Linear(4,4)
        self.a3 = nn.Linear(9,1)
    def forward(self,x):
        o1 = self.a1(x)
        o2 = self.a2(x).transpose(1,2)
        output = torch.bmm(o1,o2)
        output = output.view(len(x),9)
        output = self.a3(output)
        return output

x = torch.randn(10,3,4)
y = torch.ones(10,1)

net = Net()

criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters())

for i in range(10):
    net.zero_grad()
    output = net(x)
    loss = criterion(output,y)
    loss.backward()
    optimizer.step()
    print(loss.item())

These are the questions I have:

  1. Is the above neural network a valid one? That is, will the model train correctly?
  2. Even after passing a 3D input x = torch.randn(10,3,4), why does PyTorch's nn.Linear not raise an error, and why does it give a 3D output?
Lulita answered 28/10, 2019 at 7:28 Comment(0)
27

Newer versions of PyTorch allow nn.Linear to accept an N-D input tensor. The only constraint is that the last dimension of the input tensor must equal in_features of the linear layer. The linear transformation is then applied along the last dimension of the tensor.
For instance, if in_features=5 and out_features=10 and the input tensor x has shape (2, 3, 5), then the output tensor will have shape (2, 3, 10).
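A minimal sketch of the shape rule described above, assuming a recent PyTorch install. The variable names (layer, x, y, mismatch_raises) are only illustrative; the second half checks that an input whose last dimension is not in_features is rejected.

```python
import torch
import torch.nn as nn

layer = nn.Linear(in_features=5, out_features=10)

# Last dimension matches in_features, leading dimensions are arbitrary.
x = torch.randn(2, 3, 5)
y = layer(x)
print(y.shape)  # torch.Size([2, 3, 10])

# Last dimension is 3 instead of 5, so the layer refuses the input.
try:
    layer(torch.randn(2, 5, 3))
    mismatch_raises = False
except RuntimeError:
    mismatch_raises = True
print(mismatch_raises)  # True
```

Only the last dimension is checked; the leading dimensions (2, 3) are carried through to the output unchanged.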

Capel answered 28/10, 2019 at 13:14 Comment(0)
14

If you have a look at the documentation, you will find that indeed the Linear layer accepts tensors of arbitrary shape, where only the last dimension must match with the in_features argument you specified in the constructor.

The output will have the same shape as the input, except for the last dimension, which changes to whatever you specified as out_features in the constructor.

It works in a way that the same layer (with the same weights) is applied on each of the (possibly) multiple inputs. In your example you have an input shape of (10, 3, 4) which is basically a set of 10 * 3 == 30 4-dimensional vectors. So, your layers a1 and a2 are applied on all of these 30 vectors to generate another 10 * 3 == 30 4D vectors as the output (because you specified out_features=4 in the constructor).
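The "30 vectors" view can be checked directly. The sketch below, which assumes nothing beyond standard PyTorch calls, flattens the (10, 3, 4) input into a (30, 4) matrix, applies the same layer, and reshapes the result back. The names out_3d, flat, out_flat and same are made up for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
a1 = nn.Linear(4, 4)
x = torch.randn(10, 3, 4)

# Apply the layer to the 3D tensor directly.
out_3d = a1(x)                          # shape (10, 3, 4)

# Treat the input as 10 * 3 == 30 separate 4-dimensional vectors.
flat = x.reshape(-1, 4)                 # shape (30, 4)
out_flat = a1(flat).reshape(10, 3, 4)   # back to (10, 3, 4)

# Both routes use the same weights, so the results agree up to float rounding.
same = torch.allclose(out_3d, out_flat, atol=1e-5)
print(flat.shape, same)
```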

So, to answer your questions:

Is the above neural network a valid one? That is, will the model train correctly?

Yes, it is valid and it will be trained "correctly" from a technical point of view. But, as with any other network, whether it will actually solve your problem is another question.
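One way to see the "technically valid" part is to run a single backward pass and confirm that every parameter receives a gradient of its own shape. This is only a sanity-check sketch that repeats the Net class from the question; grads_ok is a name introduced here, and a passing check says nothing about whether the model suits a given task.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a1 = nn.Linear(4, 4)
        self.a2 = nn.Linear(4, 4)
        self.a3 = nn.Linear(9, 1)

    def forward(self, x):
        o1 = self.a1(x)                    # (batch, 3, 4)
        o2 = self.a2(x).transpose(1, 2)    # (batch, 4, 3)
        output = torch.bmm(o1, o2)         # (batch, 3, 3)
        output = output.view(len(x), 9)    # (batch, 9)
        return self.a3(output)             # (batch, 1)

net = Net()
x = torch.randn(10, 3, 4)
y = torch.ones(10, 1)

loss = nn.MSELoss()(net(x), y)
loss.backward()

# Every weight and bias should now hold a gradient matching its shape.
grads_ok = all(
    p.grad is not None and p.grad.shape == p.shape
    for p in net.parameters()
)
for name, p in net.named_parameters():
    print(name, tuple(p.grad.shape))
print(grads_ok)
```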

Even after passing a 3D input x = torch.randn(10,3,4), why does PyTorch's nn.Linear not raise an error, and why does it give a 3D output?

Well, because it is defined to work this way.

Mistletoe answered 28/10, 2019 at 13:16 Comment(3)
Interesting! Could you please explain the part where you mention "input shape of (10, 3, 4) which is basically a set of 10 * 3 == 30 4-dimensional vectors"? I'm having difficulty visualizing this. @Mistletoe - Paymaster
@SohamBhaumik There are many ways to think about this: e.g. a shape of (10, 3) can be seen as 10 3d-vectors. It can also be seen as a 10x3 matrix of numbers. Combine these two ways of thinking and now a shape of (10, 3, 4) can be seen as a 10x3 matrix, where each entry is not a number but a 4d-vector. So, you have 30 4d-vectors. - Mistletoe
Ah! The reshape function has never made so much sense! I'm starting out with Machine Learning and am having a super tough time visualizing dimensionalities and shapes, and how interpretations of PyTorch tensors are all use-case based. Thank you for this! - Paymaster
0

I have the same question too. Thank you, everyone. I tested it.

First, do this...

import numpy as np
import torch

torch.manual_seed(0)

a = np.arange(12).reshape(3, 2, 2).astype(np.float32)
a = torch.from_numpy(a)

b = torch.nn.Linear(2, 3, bias=False)
torch.nn.init.normal_(b.weight, mean=0, std=1.0)
for param in b.parameters():
    param.requires_grad = False

c = b(a)

print('a =\n', a)
print()
print('b.weight =\n', b.weight)
print()
print('c =\n', c)

a =
 tensor([[[ 0.,  1.],
         [ 2.,  3.]],

        [[ 4.,  5.],
         [ 6.,  7.]],

        [[ 8.,  9.],
         [10., 11.]]])

b.weight =
 Parameter containing:
tensor([[ 1.2645, -0.6874],
        [ 0.1604, -0.6065],
        [-0.7831,  1.0622]])

c =
 tensor([[[-0.6874, -0.6065,  1.0622],
         [ 0.4666, -1.4986,  1.6204]],

        [[ 1.6207, -2.3907,  2.1786],
         [ 2.7748, -3.2829,  2.7368]],

        [[ 3.9288, -4.1750,  3.2950],
         [ 5.0829, -5.0672,  3.8532]]])

Next, do this...

b_ = b.weight.clone().detach()

c_ = []

for a_ in a:
    c_.append(a_ @ b_.T)

c_ = torch.stack(c_)

print('a =\n', a)
print()
print('b_ =\n', b_)
print()
print('c_ =\n', c_)

a =
 tensor([[[ 0.,  1.],
         [ 2.,  3.]],

        [[ 4.,  5.],
         [ 6.,  7.]],

        [[ 8.,  9.],
         [10., 11.]]])

b_ = 
 tensor([[ 1.2645, -0.6874],
        [ 0.1604, -0.6065],
        [-0.7831,  1.0622]])

c_ = 
 tensor([[[-0.6874, -0.6065,  1.0622],
         [ 0.4666, -1.4986,  1.6204]],

        [[ 1.6207, -2.3907,  2.1786],
         [ 2.7748, -3.2829,  2.7368]],

        [[ 3.9288, -4.1750,  3.2950],
         [ 5.0829, -5.0672,  3.8532]]])

Finally, check...

print('diff =', torch.mean(torch.abs(c - c_)))

diff = tensor(0.)

That's all.

Even if you change the 3D input a = np.arange(12).reshape(3, 2, 2).astype(np.float32) to a 4D input a = np.arange(60).reshape(5, 3, 2, 2).astype(np.float32), diff will still be zero.

Interesting! :-D
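The 4D case can be sketched the same way. The version below builds the input with torch.arange instead of NumPy (a substitution made here to keep the snippet self-contained) and compares with torch.allclose rather than an exact zero difference, since bit-identical results are not guaranteed on every backend.

```python
import torch

torch.manual_seed(0)

# 4D input: 5 * 3 * 2 == 30 two-dimensional vectors.
a = torch.arange(60, dtype=torch.float32).reshape(5, 3, 2, 2)
b = torch.nn.Linear(2, 3, bias=False)

with torch.no_grad():
    # Linear layer applied to the whole 4D tensor at once.
    c = b(a)
    # Manual version: multiply each slice by the transposed weight.
    c_ = torch.stack([a_ @ b.weight.T for a_ in a])

print(c.shape)  # torch.Size([5, 3, 2, 3])
match = torch.allclose(c, c_, atol=1e-4)
print(match)
```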

Foreglimpse answered 13/2 at 23:56 Comment(0)
