How to do fully connected batch norm in PyTorch?

torch.nn has the classes BatchNorm1d, BatchNorm2d and BatchNorm3d, but it doesn't seem to have a BatchNorm class for fully connected layers. What is the standard way of doing normal batch norm in PyTorch?

Gnawing answered 9/11, 2017 at 9:13 Comment(1)
What makes you think these layers are not fully connected? – Needham

Ok, I figured it out. BatchNorm1d also accepts rank-2 tensors of shape (N, C), i.e. (batch, features), and in that case it normalizes each feature over the batch. So BatchNorm1d is the layer to use for the normal fully connected case.

So for example:

import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space

        self.linear1 = nn.Linear(num_inputs, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, num_outputs)
        # BatchNorm1d on (batch, features) tensors: one mean/variance per hidden unit
        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, inputs):
        x = inputs
        # Linear -> ReLU -> BatchNorm
        x = self.bn1(F.relu(self.linear1(x)))
        x = self.bn2(F.relu(self.linear2(x)))
        out = self.linear3(x)
        return out
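To check that BatchNorm1d on an (N, C) input really is ordinary fully connected batch norm, here is a minimal sketch (assuming a reasonably recent PyTorch; the sizes are arbitrary) that compares the layer's training-mode output with normalizing each feature column over the batch by hand. It relies on the layer's defaults (eps=1e-5, affine weight initialized to 1 and bias to 0) and on the fact that training mode normalizes with the biased batch variance. It also shows the usual train/eval caveat for fully connected layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_features = 4
x = torch.randn(8, num_features)  # batch of 8 samples, shape (N, C) as produced by nn.Linear

bn = nn.BatchNorm1d(num_features)  # affine weight starts at 1 and bias at 0
bn.train()
y = bn(x)

# Manual batch norm: per-feature mean and biased variance over the batch dimension (dim 0).
mean = x.mean(dim=0, keepdim=True)
var = x.var(dim=0, unbiased=False, keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + bn.eps)
print(torch.allclose(y, y_manual, atol=1e-5))  # True

# Training mode needs batch statistics, so a batch of a single sample is rejected.
try:
    bn(torch.randn(1, num_features))
except ValueError as err:
    print("training mode, batch of 1:", err)

# Eval mode normalizes with the running mean/variance instead, so any batch size works.
bn.eval()
print(bn(torch.randn(1, num_features)).shape)  # torch.Size([1, 4])
```

In practice this means calling model.train() while training and model.eval() before inference, so that single samples are normalized with the running statistics.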
Gnawing answered 9/11, 2017 at 12:44 Comment(2)
This may not be related to machine learning, but shouldn't the super call be super(Policy, self).__init__() instead of super(Policy2, self).__init__()? In Python 3 it can even be simplified to just super().__init__(). – Majordomo
Shouldn't it be F.relu(self.bn1(self.linear1(x)))? – Caddis

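BatchNorm1d normally comes before the ReLU, and the bias of the preceding Linear layer is redundant (batch norm subtracts the per-feature mean and has its own learnable shift), so: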

import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space

        # No bias on layers followed by batch norm: BatchNorm1d's own shift replaces it
        self.linear1 = nn.Linear(num_inputs, hidden_size1, bias=False)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2, bias=False)
        self.linear3 = nn.Linear(hidden_size2, num_outputs)
        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, inputs):
        x = inputs
        # Linear -> BatchNorm -> ReLU
        x = F.relu(self.bn1(self.linear1(x)))
        x = F.relu(self.bn2(self.linear2(x)))
        out = self.linear3(x)

        return out
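Why the bias is redundant: in training mode batch norm subtracts the per-feature batch mean, so a constant added by the Linear bias cancels out, and the BatchNorm1d layer's own learnable bias takes over that role. A minimal sketch, with arbitrary sizes, that checks this by giving two Linear layers identical weights, one with a bias and one without, and feeding each into a fresh BatchNorm1d:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

in_features, out_features = 5, 3
x = torch.randn(16, in_features)

with_bias = nn.Linear(in_features, out_features, bias=True)
without_bias = nn.Linear(in_features, out_features, bias=False)
with torch.no_grad():
    without_bias.weight.copy_(with_bias.weight)  # identical weights, only the bias differs

bn_a = nn.BatchNorm1d(out_features)  # modules start in training mode
bn_b = nn.BatchNorm1d(out_features)

out_a = bn_a(with_bias(x))
out_b = bn_b(without_bias(x))

# The bias shifts every sample of a feature by the same constant, which the batch mean removes.
print(torch.allclose(out_a, out_b, atol=1e-5))  # True
```

This only holds because the normalization sits directly after the Linear layer. With a ReLU in between (Linear, ReLU, BatchNorm), the bias changes which units get clipped, so it no longer cancels.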
Nibelungenlied answered 14/1, 2020 at 15:8 Comment(1)
Batch norm should come after the ReLU, according to this study: github.com/ducha-aiki/caffenet-benchmark/blob/master/… – Supine

© 2022 - 2024 — McMap. All rights reserved.