Works fine for me with the default PyTorch API:
import torch.nn as nn

def replace_layer(module: nn.Module, old: type, new: nn.Module, full_name: str = ""):
    # Recursively walk the children and swap every instance of `old` for `new`.
    for name, child in module.named_children():
        child_name = f"{full_name}.{name}" if full_name else name
        if isinstance(child, old):
            setattr(module, name, new)
            print(f"replaced {child_name}: {old}->{new}")
        elif len(list(child.children())) > 0:
            # Only recurse into containers that have children of their own.
            replace_layer(child, old, new, child_name)
# replace_layer already recurses, so one call on the root model is enough
# (no need to wrap it in model.apply).
replace_layer(model, nn.ReLU, nn.Hardswish(inplace=True))
It will replace the layers and print a "trace":
replaced ._model.norm_layer.0.1.2: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.4.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
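For reference, here is a minimal self-contained sketch of the same idea on a toy model. The toy model and its layer names are made up purely for illustration; only the replace_layer function above is assumed:

import torch
import torch.nn as nn

# Hypothetical toy model, just to demonstrate the swap end to end.
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(inplace=True),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(inplace=True)),
)

replace_layer(toy, nn.ReLU, nn.Hardswish(inplace=True))
# should report the ReLU children at names "1" and "2.1"

# Sanity check: no ReLU is left and the forward pass still runs.
assert not any(isinstance(m, nn.ReLU) for m in toy.modules())
out = toy(torch.randn(1, 3, 16, 16))
print(out.shape)  # torch.Size([1, 8, 16, 16])

Note that a single `new` module instance is reused for every match, which is fine for stateless layers like activations; for layers with parameters you would pass a factory (e.g. a lambda returning a fresh module) instead.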