So, I've read about half of the original ResNet paper, and I'm trying to figure out how to make my own version for tabular data.

I've read a few blog posts on how it works in PyTorch, and I see heavy use of `nn.Identity()`. Now, the paper also frequently uses the term *identity mapping*. However, that just refers to adding the input of a stack of layers to the output of that same stack, element-wise. If the input and output dimensions differ, the paper talks about padding the input with zeros or using a matrix `W_s` to project the input to the output dimension.
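For what it's worth, my reading of the `W_s` option for the tabular case is just a linear projection of the input. A minimal sketch (my own assumption, using `nn.Linear` as `W_s`; not code from the paper or any post):

```python
import torch
from torch import nn

shortcut = nn.Linear(4, 6, bias=False)  # W_s: project 4 input features to 6

x = torch.randn(2, 4)   # a batch of 2 samples with 4 features each
residual = shortcut(x)  # shape (2, 6); this is what would get added
                        # element-wise to the stack's output
```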
Here is an abstraction of a residual block I found in a blog post:
```python
import torch
from torch import nn, tensor


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activation='relu'):
        super().__init__()
        self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
        # Both the main path and the shortcut start out as no-op placeholders.
        self.blocks = nn.Identity()
        self.shortcut = nn.Identity()

    def forward(self, x):
        residual = x
        if self.should_apply_shortcut:
            residual = self.shortcut(x)
        x = self.blocks(x)
        x += residual
        return x

    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.out_channels
```
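As far as I can tell, the abstraction is meant to be subclassed, with the two `nn.Identity()` placeholders swapped out for real layers. A minimal sketch of what that might look like for tabular data (my own illustration, not the blog's code):

```python
class LinearResidualBlock(ResidualBlock):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels)
        # Main path: replaces the nn.Identity() placeholder.
        self.blocks = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )
        # The paper's W_s projection; only applied when dimensions differ.
        self.shortcut = nn.Linear(in_channels, out_channels, bias=False)
```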
And my own application to a dummy tensor:
```python
x = tensor([1, 1, 2, 2])
block1 = ResidualBlock(4, 4)
block2 = ResidualBlock(4, 6)
x = block1(x)
print(x)
x = block2(x)
print(x)
```

which prints:

```
tensor([2, 2, 4, 4])
tensor([4, 4, 8, 8])
```
So at the end of it, `self.blocks(x)` and `self.shortcut(x)` are both just `x` itself, since `nn.Identity()` returns its input unchanged; notice that `block2` doesn't even change the dimension despite `out_channels=6`. I don't see the point of using `nn.Identity()` except to mimic the math lingo found in the original paper. I'm sure that's not the case though, and that it has some hidden use that I'm just not seeing yet. What could it be?
EDIT: Here is another example of implementing residual learning, this time in Keras. It does just what I suggested above and simply keeps a copy of the input for adding to the output:

```python
from tensorflow import Tensor
from tensorflow.keras.layers import Add, BatchNormalization, Conv2D, ReLU


def relu_bn(inputs: Tensor) -> Tensor:
    # Helper from the same post; I'm assuming it's ReLU followed by BatchNorm.
    return BatchNormalization()(ReLU()(inputs))


def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    # Main path: two convolutions; the first one strides if downsampling.
    y = Conv2D(kernel_size=kernel_size,
               strides=(1 if not downsample else 2),
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)
    # Shortcut path: a strided 1x1 convolution projects the input
    # (the paper's W_s) when the shapes no longer match.
    if downsample:
        x = Conv2D(kernel_size=1,
                   strides=2,
                   filters=filters,
                   padding="same")(x)
    # Element-wise addition of the shortcut and the main path.
    out = Add()([x, y])
    out = relu_bn(out)
    return out
```
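To sanity-check it, a quick usage sketch (my own, with made-up shapes):

```python
from tensorflow.keras import Input, Model

inputs = Input(shape=(32, 32, 3))                         # made-up input size
out = residual_block(inputs, downsample=True, filters=16)
model = Model(inputs, out)
model.summary()  # shortcut and main path both end up at (16, 16, 16)
```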