How to remove the last FC layer from a ResNet model in PyTorch?

I am using a ResNet152 model from PyTorch. I'd like to strip off the last FC layer from the model. Here's my code:

from torchvision import datasets, transforms, models
model = models.resnet152(pretrained=True)
print(model)

When I print the model, the last few lines look like this:

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

I want to remove that last fc layer from the model.

I found an answer here on SO (How to convert pretrained FC layers to CONV layers in Pytorch), where mexmex seems to provide the answer I'm looking for:

list(model.modules()) # to inspect the modules of your model
my_model = nn.Sequential(*list(model.modules())[:-1]) # strips off last linear layer

So I added those lines to my code like this:

model = models.resnet152(pretrained=True)
list(model.modules()) # to inspect the modules of your model
my_model = nn.Sequential(*list(model.modules())[:-1]) # strips off last linear layer
print(my_model)

But this code doesn't work as advertised -- at least not for me. The rest of this post is a detailed explanation of why that answer doesn't work, so this question doesn't get closed as a duplicate.

First, the printed model is nearly 5x larger than before. I see the same model as before, followed by what appears to be a repeat of the model, but flattened.

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
(1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(5): Sequential(
  . . . this goes on for ~1600 more lines . . .
  (415): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (416): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (417): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (418): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (419): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (420): ReLU(inplace)
  (421): AvgPool2d(kernel_size=7, stride=1, padding=0)
)

Second, the fc layer is still there -- and the Conv2D layer after it looks just like the first layer of ResNet152.

Third, if I try to invoke my_model.forward(), PyTorch complains about a size mismatch. It expects size [1, 3, 224, 224], but the input was [1, 1000]. So it looks like a copy of the entire model (minus the fc layer) is getting appended to the original model.
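
For what it's worth, the duplication seems consistent with how modules() behaves: it walks the module tree recursively and yields the model itself as its first entry. A rough sketch of a quick check (exact counts will vary with the torchvision version):

import torch.nn as nn
from torchvision import models

model = models.resnet152(pretrained=True)
mods = list(model.modules())

print(len(mods))         # hundreds of entries for ResNet152, not ~10 top-level blocks
print(mods[0] is model)  # True -- the first entry is the entire model, fc included

# So nn.Sequential(*mods[:-1]) chains the complete ResNet first (which still
# ends in fc), followed by every submodule again in flattened order -- which
# matches the duplicated printout and the [1, 1000] size-mismatch error above.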

Bottom line, the only answer I found on SO doesn't actually work.

Emanative answered 28/9, 2018 at 4:13 Comment(1)
Not really sure, but a basic deletion should work here: del(model['fc']). Can you try? – Katzman

For the ResNet model, you can use the children attribute to access the layers, since the ResNet model in PyTorch is built from nn modules. (Tested on PyTorch 0.4.1.)

import torch
from torchvision import models

model = models.resnet152(pretrained=True)
newmodel = torch.nn.Sequential(*(list(model.children())[:-1]))
print(newmodel)
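
As a quick sanity check (a minimal sketch; 224x224 is the standard ImageNet input size), the truncated model now returns pooled features instead of class logits:

import torch

dummy = torch.randn(1, 3, 224, 224)       # one ImageNet-sized image
with torch.no_grad():
    features = newmodel(dummy)

print(features.shape)                      # torch.Size([1, 2048, 1, 1])
print(torch.flatten(features, 1).shape)    # torch.Size([1, 2048]) -- flat feature vector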

Update: Although there is no universal answer that works on every PyTorch model, this should work on all well-structured ones. The existing layers you add to your model (such as torch.nn.Linear, torch.nn.Conv2d, torch.nn.BatchNorm2d...) are all based on the torch.nn.Module class, and if you implement a custom layer and add it to your network, you should inherit it from torch.nn.Module as well. As written in the documentation, the children attribute lets you access the modules of your class/model/network.

def children(self):
    r"""Returns an iterator over immediate children modules."""

Update: It is important to note that children() returns only the "immediate" child modules, which means that if the last module of your network is an nn.Sequential, slicing with [:-1] will strip that whole Sequential block, not just its last layer.
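
A toy illustration of that caveat (a made-up two-block model, not from the original answer): when the classifier lives inside an inner Sequential, slicing children() with [:-1] drops that whole block, Flatten and Linear alike.

import torch.nn as nn

toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Sequential(            # imagine this is the "head" of the network
        nn.Flatten(),
        nn.Linear(8, 10),
    ),
)

print(list(toy.children()))
# [Conv2d(3, 8, ...), Sequential(Flatten, Linear)] -- only two immediate children,
# so list(toy.children())[:-1] keeps just the Conv2d and removes the entire head.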

Animatism answered 28/9, 2018 at 4:41 Comment(4)
Updated my answer with documentation references; please remind me if something is missing. – Animatism
To keep the layer names, you can do newmodel = torch.nn.Sequential(OrderedDict([*(list(model.named_children())[:-1])])). – Maguire
You're effectively replacing the forward function with a simple daisy chain when using nn.Sequential. Wouldn't that destroy the skip-layer connections of ResNets? – Harberd
@TimKuipers Actually, if you print the model itself, you will see that it consists of blocks of the Bottleneck class rather than a plain sequential chain of conv/relu/batchnorm layers. Each Bottleneck block is a residual block; you can look at the forward function of Bottleneck here: github.com/pytorch/vision/blob/main/torchvision/models/…. The result of the final relu is calculated from the sum of the original input (x) and the intermediate block output. – Animatism

You can do it simply with:

model.fc = nn.Sequential()

or, alternatively, you can create an Identity layer:

import torch.nn as nn

class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

and replace the fc layer with it:

model.fc = Identity()
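
With the fc layer replaced this way, a forward pass returns the 2048-dimensional pooled features instead of the 1000 class scores (a quick sketch, assuming the pretrained ResNet152 from the question):

import torch
from torchvision import models

model = models.resnet152(pretrained=True)
model.fc = Identity()     # or nn.Sequential(), or the built-in nn.Identity()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)          # torch.Size([1, 2048])
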
Chef answered 12/7, 2019 at 15:11 Comment(3)
PyTorch now has a built-in identity module: pytorch.org/docs/stable/nn.html#identity – Pyrrho
The correct way to call it is torch.nn.Identity()... saves you the time of opening the link to find the correct call. – Despinadespise
I love you! I love you! This same technique can be used to change a built-in PyTorch net like ResNet50 from one output to multiple outputs! – Barton

If you are looking not just to strip the last FC layer off the model, but to replace it with your own, thereby taking advantage of transfer learning, you can do it this way:

import torch.nn as nn
from collections import OrderedDict
from torchvision import models

model = models.resnet152(pretrained=True)   # the pretrained model from the question
n_inputs = model.fc.in_features             # 2048 for ResNet152

# add more layers as required
classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(n_inputs, 512))
]))

model.fc = classifier
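
A common companion step, not part of the original answer (a hedged sketch; the learning rate is just a placeholder), is to freeze the pretrained backbone so only the new head is trained:

import torch.optim as optim

# freeze everything except the newly attached head
for name, param in model.named_parameters():
    if not name.startswith("fc."):
        param.requires_grad = False

# optimize only the parameters of the new classifier
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
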
Rozalin answered 9/1, 2019 at 16:18 Comment(0)

From the PyTorch tutorial "Finetuning TorchVision Models":

Here we use Resnet18, as our dataset is small and only has two classes. When we print the model, we see that the last layer is a fully connected layer as shown below:

(fc): Linear(in_features=512, out_features=1000, bias=True)

Thus, we must reinitialize model.fc to be a Linear layer with 512 input features and 2 output features with:

model.fc = nn.Linear(512, num_classes)
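
Put together for the two-class case described in that tutorial (a minimal sketch with num_classes set to 2):

import torch.nn as nn
from torchvision import models

num_classes = 2
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)   # in_features is 512 for ResNet18
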
Hyacinthhyacintha answered 13/4, 2021 at 16:7 Comment(0)

Alternatively, change the fc layer to an Identity. nn.Identity() simply forwards its input to its output:

import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

# load a pretrained resnet50 model
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.fc = nn.Identity()
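
Used as a feature extractor, the modified model then outputs one 2048-dimensional vector per image (a short sketch; 224x224 matches the default ResNet50 preprocessing):

import torch

model.eval()                                         # inference mode for batchnorm/dropout
with torch.no_grad():
    features = model(torch.randn(4, 3, 224, 224))    # dummy batch of 4 images

print(features.shape)                                # torch.Size([4, 2048])
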
Represent answered 14/5 at 19:41 Comment(0)
