Train a multi-output regression model in PyTorch
Asked Answered
X

1

8

I'd like to have a model with 3 regression outputs, such as the dummy example below:

import torch

class MultiOutputRegression(torch.nn.Module):
    """Stack of three linear layers mapping 1 input feature to 3 outputs.

    Layer widths are 1 -> 10 -> 10 -> 3. No activation is applied between
    layers, so the whole network is an affine map of its input, which is
    enough for linear targets such as [x, 2x, 3x].
    """

    def __init__(self):
        super(MultiOutputRegression, self).__init__()
        self.linear1 = torch.nn.Linear(1, 10)
        self.linear2 = torch.nn.Linear(10, 10)
        # in_features must equal linear2's out_features (10); a width of 3
        # here would fail with a shape mismatch on the first forward pass.
        self.linear3 = torch.nn.Linear(10, 3)

    def forward(self, x):
        """Map ``x`` with last dimension 1 to predictions with last dimension 3."""
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x

Suppose I want to train it on a dummy task: given the input x, return [x, 2x, 3x].

After defining the criterion (loss function) and the optimizer, we can train it with the following data:

# Each iteration uses a batch of two consecutive integers, shape (2, 1).
for i in range(1, 100, 2):
    x_train = torch.tensor([i, i + 1]).reshape(2, 1).float()
    # Targets [x, 2x, 3x] via broadcasting: (2, 1) * (3,) -> (2, 3),
    # matching the three outputs of the model.
    y_train = x_train * torch.tensor([1.0, 2.0, 3.0])
    y_pred = model(x_train)
    # todo: perform training iteration

Sample data at the first iteration would be:

x_train
tensor([[1.],
        [2.]])
y_train
tensor([[1., 2., 3.],
        [2., 4., 6.]])

How can I define a proper loss criterion and optimizer to train the neural network?

Xyster answered 17/6, 2021 at 1:1 Comment(0)
R
7
class MultiOutputRegression(torch.nn.Module):
    """Three chained fully connected layers: 1 -> 10 -> 10 -> 3 features.

    The layers are composed without any non-linearity, so the network is a
    single affine transform of its input.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the order linear1, linear2, linear3 so random
        # initialisation draws from the RNG in that sequence.
        self.linear1 = torch.nn.Linear(1, 10)
        self.linear2 = torch.nn.Linear(10, 10)
        self.linear3 = torch.nn.Linear(10, 3)

    def forward(self, x):
        """Apply linear1, linear2 and linear3 in turn and return the result."""
        hidden = self.linear2(self.linear1(x))
        return self.linear3(hidden)

model = MultiOutputRegression()

# MSELoss (default reduction='mean') averages the squared error over every
# element of the (batch, 3) output, so all three regression targets feed a
# single scalar loss.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# NOTE(review): inputs are unscaled (up to 100) and Adam uses its default
# learning rate, so 5 epochs may not be enough to converge; consider
# normalizing inputs/targets or raising the learning rate.
for epoch in range(5):
    for i in range(1, 100, 2):
        x_train = torch.tensor([i, i + 1]).reshape(2, 1).float()
        # Targets [x, 2x, 3x] by broadcasting: (2, 1) * (3,) -> (2, 3).
        # This guarantees y_train has exactly the same shape as y_pred.
        y_train = x_train * torch.tensor([1.0, 2.0, 3.0])

        optimizer.zero_grad()
        y_pred = model(x_train)
        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

        # .item() extracts the scalar loss as a Python float.
        print(loss.item())
Requiem answered 23/3, 2022 at 16:38 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.