What is the idea behind using nn.Identity for residual learning?

So, I've read about half of the original ResNet paper, and am trying to figure out how to make my own version for tabular data.

I've read a few blog posts on how it works in PyTorch, and I see heavy use of nn.Identity(). Now, the paper also frequently uses the term identity mapping. However, it just refers to adding the input of a stack of layers to the output of that same stack in an element-wise fashion. If the input and output dimensions are different, the paper talks about padding the input with zeros or using a matrix W_s to project the input to the right dimension.
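In the paper's notation, that is y = F(x, {W_i}) + x for the identity shortcut, and y = F(x, {W_i}) + W_s x when a projection is needed to match dimensions.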

Here is an abstraction of a residual block I found in a blog post:


from torch import nn, tensor


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activation='relu'):
        super().__init__()
        self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
        self.blocks = nn.Identity()
        self.shortcut = nn.Identity()   
    
    def forward(self, x):
        residual = x
        if self.should_apply_shortcut: residual = self.shortcut(x)
        x = self.blocks(x)
        x += residual
        return x
    
    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.out_channels
    
block1 = ResidualBlock(4, 4)

And my own application to a dummy tensor:

x = tensor([1, 1, 2, 2])
block1 = ResidualBlock(4, 4)
block2 = ResidualBlock(4, 6)
x = block1(x)
print(x)
x = block2(x)
print(x)

>>> tensor([2, 2, 4, 4])
>>> tensor([4, 4, 8, 8])

So at the end of it, x = nn.Identity()(x), and I'm not sure of the point of its use except to mimic the math lingo found in the original paper. I'm sure that's not the case though, and that it has some hidden use that I'm just not seeing yet. What could it be?

EDIT: Here is another example of implementing residual learning, this time in Keras. It does just what I suggested above and simply keeps a copy of the input for adding to the output:

from tensorflow import Tensor
from tensorflow.keras.layers import Add, Conv2D

# relu_bn is a helper defined elsewhere in the same blog post
# (a ReLU activation followed by batch normalization)
def residual_block(x: Tensor, downsample: bool, filters: int,
                   kernel_size: int = 3) -> Tensor:
    y = Conv2D(kernel_size=kernel_size,
               strides= (1 if not downsample else 2),
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)

    if downsample:
        x = Conv2D(kernel_size=1,
                   strides=2,
                   filters=filters,
                   padding="same")(x)
    out = Add()([x, y])
    out = relu_bn(out)
    return out
Fanfaron answered 6/10, 2020 at 16:8 Comment(1)
I was shown the answer here: github.com/pytorch/pytorch/issues/9160. I'll expand it into an answer in a couple days if nobody does it first. – Fanfaron

What is the idea behind using nn.Identity for residual learning?

There is none (almost; see the end of the post). All nn.Identity does is forward the input given to it (basically a no-op).
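
For instance, it accepts (and ignores) any constructor arguments and simply returns whatever you pass to it:

import torch
from torch import nn

identity = nn.Identity(54, unused_kwarg=0.1)  # constructor arguments are accepted and discarded
x = torch.randn(2, 3)
print(torch.equal(identity(x), x))  # True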

As shown in the PyTorch repo issue you linked in a comment, this idea was first rejected, then later merged into PyTorch due to other uses (see the rationale in this PR). That rationale is not connected to the ResNet block itself; see the end of the answer.

ResNet implementation

The easiest generic version I can think of, with projection, would be something along these lines:

import torch


class Residual(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, projection: torch.nn.Module = None):
        super().__init__()
        self.module = module
        self.projection = projection

    def forward(self, inputs):
        output = self.module(inputs)
        # Project the input only when a projection module was given
        # (i.e. when the input and output shapes differ)
        if self.projection is not None:
            inputs = self.projection(inputs)
        return output + inputs

As module you can pass things like two stacked convolutions, and as projection module a 1x1 convolution (with an appropriate stride or padding) that makes the shapes match.
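
For example, something along these lines (the channel counts below are just placeholders):

conv_residual = Residual(
    module=torch.nn.Sequential(
        torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(128, 128, kernel_size=3, padding=1),
    ),
    # 1x1 convolution so the shortcut matches the 128 output channels
    projection=torch.nn.Conv2d(64, 128, kernel_size=1),
)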

For tabular data you could use this as module (assuming your input has 50 features):

torch.nn.Sequential(
    torch.nn.Linear(50, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 50),
)
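
Wrapped in the Residual module from above, a quick sanity check could look like this (the batch size here is arbitrary):

block = Residual(
    torch.nn.Sequential(
        torch.nn.Linear(50, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 50),
    )
)

x = torch.randn(32, 50)  # a batch of 32 samples with 50 features each
print(block(x).shape)    # torch.Size([32, 50])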

Basically, all you have to do is add the input of some module to its output, and that is it.

Rationale behind nn.Identity

It might make it easier to construct neural networks (and to read them afterwards); here is an example for batch norm (taken from the aforementioned PR):

batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = nn.Identity

Now you can use it with nn.Sequential easily:

nn.Sequential(
    ...
    batch_norm(N, momentum=0.05),
    ...
)

And when printing the network, it always has the same number of submodules (with either BatchNorm or Identity), which also makes the whole thing a little smoother IMO.

Another use case, mentioned here, might be removing parts of existing neural networks:

import torchvision as tv
from torch import nn

net = tv.models.alexnet(pretrained=True)
# Assume net has two parts:
# features and classifier
net.classifier = nn.Identity()

Now, instead of running net.features(input), you can run net(input), which might be easier for others to read as well.
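
A minimal sketch of what that gives you (the 224x224 input size is just AlexNet's usual one):

import torch

x = torch.randn(1, 3, 224, 224)
features = net(x)  # flattened features instead of class logits, since classifier is Identity
print(features.shape)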

Frizz answered 6/10, 2020 at 17:17 Comment(1)
Perfect, this is the exact conclusion I came to after finding that GitHub issue. Thanks for explaining it so well. – Fanfaron

One very good use of nn.Identity() is during JIT scripting. In very modular models, scripting compiles every branch of every if statement in forward, even if the attribute controlling that branch was set to False during initialization.

import torch
from torch import nn


class MyModule(nn.Module):
    def __init__(self, extra=False):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3)
        self.extra = extra
        if extra:
            self.extra_layer = nn.Conv2d(3, 3, kernel_size=3)

    def forward(self, x):
        x = self.conv(x)
        if self.extra:
            x = self.extra_layer(x)
        return x

This module cannot be scripted when extra=False, because forward references self.extra_layer even though that attribute was never created. But you can do something like:

class MyModule(nn.Module):
    def __init__(self, extra=False):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3)
        self.extra = extra

        # extra_layer always exists, so scripting can compile both branches of forward
        self.extra_layer = nn.Conv2d(3, 3, kernel_size=3) if extra else nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        if self.extra:
            x = self.extra_layer(x)
        return x
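
Now scripting succeeds for either setting, because extra_layer always exists. A quick check, continuing from the class above (the input size is arbitrary):

scripted = torch.jit.script(MyModule(extra=False))  # extra_layer is nn.Identity here
out = scripted(torch.randn(1, 3, 8, 8))
print(out.shape)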
Forename answered 22/11, 2023 at 22:0 Comment(0)
