How to change the activation layer in a PyTorch pretrained module?
How do I change the activation layer of a PyTorch pretrained network? Here is my code:

print("All modules")
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
        child=nn.SELU()
        print(child)
print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

Here is my output:

All modules
ReLU(inplace=True)
Before changing activation
ReLU(inplace=True)
SELU()
after changing activation
ReLU(inplace=True)
Kaunas answered 9/10, 2019 at 4:44 Comment(2)
You should check this for a more general solution that works for any layer: discuss.pytorch.org/t/how-to-modify-a-pretrained-model/60509/… – Grandchild
Does this answer your question? How to remove the last FC layer from a ResNet model in PyTorch? – Carbolated

._modules solves the problem for me.

for name, child in net.named_children():
    if isinstance(child, nn.ReLU) or isinstance(child, nn.SELU):
        net._modules[name] = nn.SELU()
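
For illustration, here is a minimal sketch of the same idea (a toy model of my own, not from the original post); note that named_children() / _modules only cover direct children, so activations nested inside sub-blocks are not reached by this loop.

import torch.nn as nn

# Hypothetical toy network: the ReLU is a direct child, so the loop above replaces it.
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.ReLU(inplace=True),
)

for name, child in net.named_children():
    if isinstance(child, (nn.ReLU, nn.SELU)):
        net._modules[name] = nn.SELU()

print(net)  # the top-level ReLU is now SELU()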
Kaunas answered 9/10, 2019 at 5:2 Comment(0)

Here is a general function for replacing any layer:

def replace_layers(model, old, new):
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            ## compound module, go inside it
            replace_layers(module, old, new)
            
        if isinstance(module, old):
            ## simple module
            setattr(model, n, new)

replace_layers(model, nn.ReLU, nn.ReLU6())
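
As a quick sanity check, here is a minimal usage sketch (my addition; it assumes torchvision is installed and uses resnet18 only as an example model with nested blocks):

import torch.nn as nn
from torchvision import models

model = models.resnet18()  # example model, no pretrained weights needed here

print(any(isinstance(m, nn.ReLU) for m in model.modules()))   # True: ReLUs present
replace_layers(model, nn.ReLU, nn.ReLU6())
print(any(isinstance(m, nn.ReLU) for m in model.modules()))   # False: all replaced

One thing to keep in mind: every replaced site ends up pointing at the same new instance you pass in. For a stateless activation like ReLU6 that is usually fine, but for layers with parameters you would want a fresh instance per site.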

I struggled with this for a few days, so I did some digging and wrote a Kaggle notebook explaining how different types of layers/modules are accessed in PyTorch.

Obadiah answered 27/5, 2021 at 6:54 Comment(0)

This works fine for me with the default PyTorch API:

def replace_layer(module: nn.Module, old: nn.Module, new: nn.Module, full_name=""):
    for name, m in module.named_children():
        full_name = f"{full_name}.{name}"

        if isinstance(m, old):
            setattr(module, name, new)
            print(f"replaced {full_name}: {old}->{new}")
        elif len(list(m.children())) > 0:
            replace_layer(m, old, new, full_name)

model.apply(lambda m: replace_layer(m, nn.ReLU, nn.Hardswish(True)))
This will replace the layers and print a "trace":

replaced ._model.norm_layer.0.1.2: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.4.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
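
A side note (my observation, not part of the original answer): replace_layer already recurses through named_children(), so wrapping it in model.apply() revisits every submodule; calling it once on the model root is enough:

replace_layer(model, nn.ReLU, nn.Hardswish(True))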
Gemma answered 26/4, 2022 at 20:15 Comment(1)
Thanks, a short and tidy answer! – Susannasusannah

I'm assuming you use the module interface nn.ReLU to create the activation layer, rather than the functional interface F.relu. If so, setattr works for me.
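
For context, a minimal sketch (my addition, with hypothetical class names) of the difference between the two interfaces:

import torch.nn as nn
import torch.nn.functional as F

class ModuleStyle(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()   # registered as a child module, so it can be swapped out

    def forward(self, x):
        return self.act(x)

class FunctionalStyle(nn.Module):
    def forward(self, x):
        return F.relu(x)       # a bare function call; there is no submodule to replace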

import torch
import torch.nn as nn

# This function will recursively replace all relu module to selu module. 
def replace_relu_to_selu(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.SELU())
        else:
            replace_relu_to_selu(child)

########## A toy example ##########
net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True)
          )

########## Test ##########
print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# Before changing activation
# ReLU(inplace=True)
# ReLU(inplace=True)


replace_relu_to_selu(net)

print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# after changing activation
# SELU()
# SELU()
Cyder answered 9/10, 2019 at 5:9 Comment(1)
Print the network architecture; in my case, it didn't work. – Kaunas

I will provide a more general solution that works for any layer (and avoids other issues, such as modifying a dictionary while looping over it, or nn.Modules nested recursively inside each other).

def replace_bn(module, name):
    '''
    Recursively put desired batch norm in nn.module module.

    set module = net to start code.
    '''
    # go through all attributes of module nn.module (e.g. network or layer) and put batch norms if present
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)

    # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, name)

replace_bn(model, 'model')

The crux is that you need to recursively keep changing the layers (mainly because sometimes you will encounter attributes that are modules themselves). I think better code than the above would add another if statement (after the batch norm check) detecting whether you have to recurse, and recursing if so. The above works too, but it first changes the batch norms at the outer level (the first loop) and then, with another loop, makes sure no other object that should be recursed into is missed (and then recurses).
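Here is a minimal sketch of that suggested variant (my interpretation, untested against the original post); it keeps the same dir()-based traversal, but recurses from inside the same loop:

import torch

def replace_bn_recursive(module):
    # same idea as replace_bn above, but the recursion happens inside the dir() loop
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps,
                                          target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)
        elif isinstance(target_attr, torch.nn.Module) and target_attr is not module:
            # the attribute is itself a module, so recurse into it
            replace_bn_recursive(target_attr)

replace_bn_recursive(model)

One caveat, as far as I can tell: children stored under purely numeric names (e.g. the entries of an nn.Sequential) do not show up in dir(), so batch norms sitting directly inside an nn.Sequential are missed by both dir()-based versions; the named_children()/setattr approaches in the other answers do handle that case.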

Original post: https://discuss.pytorch.org/t/how-to-modify-a-pretrained-model/60509/10

credits: https://discuss.pytorch.org/t/replacing-convs-modules-with-custom-convs-then-notimplementederror/17736/3?u=brando_miranda

Grandchild answered 1/10, 2020 at 19:0 Comment(0)
