Reset parameters of a neural network in pytorch
Asked Answered
V

3

7

I have a neural network with the following structure:

class myNetwork(nn.Module):
    """Bidirectional GRU followed by two fully connected layers."""

    def __init__(self):
        super().__init__()

        # Recurrent encoder: 2 input features, 100 hidden units per direction.
        self.bigru = nn.GRU(
            input_size=2,
            hidden_size=100,
            batch_first=True,
            bidirectional=True,
        )

        # 200 inputs = 2 directions * 100 hidden units.
        # Each linear layer is built and then immediately re-initialised with
        # Xavier-uniform weights; the order of these statements is kept so the
        # sequence of random draws is unchanged.
        self.fc1 = nn.Linear(200, 32)
        nn.init.xavier_uniform_(self.fc1.weight)

        self.fc2 = nn.Linear(32, 2)
        nn.init.xavier_uniform_(self.fc2.weight)

I need to reinstate the model to an unlearned state by resetting the parameters of the neural network. I can do so for nn.Linear layers by using the method below:

def reset_weights(self):
    """Re-draw the weights of ``fc1`` and ``fc2`` with Xavier-uniform init.

    Only the two linear layers' weight matrices are touched; biases and any
    other submodules are left as they are.
    """
    for linear in (self.fc1, self.fc2):
        torch.nn.init.xavier_uniform_(linear.weight)

But, to reset the weight of the nn.GRU layer, I could not find any such snippet.

My question is how does one reset the nn.GRU layer? Any other way of resetting the network is also fine. Any help is appreciated.

Voguish answered 28/8, 2020 at 5:32 Comment(1)
see the official pytorch forum: discuss.pytorch.org/t/…Elias
S
12

You can call the reset_parameters method on each layer, as shown here:

# Re-initialise every direct child of the model that exposes
# reset_parameters(); children without that method are skipped.
resettable_children = (
    child for child in model.children() if hasattr(child, 'reset_parameters')
)
for child in resettable_children:
    child.reset_parameters()

Another way is to save the model's state first and reload it later, using torch.save and torch.load; see the docs, or Saving and Loading Models, for more.

Silverts answered 28/8, 2020 at 6:11 Comment(2)
It does not walk into nested nn.ModuleList(). The best way is to reorganize the code to recreate the whole network instance. – Underpart
With recursion, this works for more complex models as well: pastebin.com/QBuNyYucIncomprehension
E
4

Here is the code with an example that runs:

def lp_norm(mdl: nn.Module, p: int = 2) -> Tensor:
    """Return the sum of the p-norms of all parameters of ``mdl``.

    Args:
        mdl: module whose parameters (including those of submodules) are measured.
        p: order of the norm applied to each parameter tensor.

    Returns:
        A 0-dim tensor holding the sum of the per-parameter norms. For a module
        with no parameters a zero tensor is returned.
    """
    norms = [w.norm(p) for w in mdl.parameters()]
    if not norms:
        # sum([]) would be the Python int 0, which breaks the Tensor contract
        # (e.g. callers using .item() or .backward()).
        import torch

        return torch.zeros(())
    return sum(norms)

def reset_all_weights(model: nn.Module) -> None:
    """
    Re-initialise, in place, every submodule of ``model`` that provides a
    callable ``reset_parameters`` attribute. Modules without one are skipped.

    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://mcmap.net/q/1419521/-reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    def _reinitialise(module: nn.Module) -> None:
        # - skip modules that do not define a callable reset_parameters
        reset_fn = getattr(module, "reset_parameters", None)
        if not callable(reset_fn):
            return
        with torch.no_grad():
            reset_fn()

    # nn.Module.apply visits every submodule recursively, then the model itself,
    # see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(fn=_reinitialise)


def reset_all_linear_layer_weights(model: nn.Module) -> nn.Module:
    """
    Fills the weight matrix of every linear layer in ``model`` with 1.0, recursively.

    Only modules whose type is exactly ``nn.Linear`` are touched (subclasses are
    not matched), and only their ``weight``; biases are left unchanged.

    Args:
        model: the module to modify in place.

    Returns:
        The same ``model`` instance, so the declared return type holds and the
        call can be chained.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if type(m) is nn.Linear:
            m.weight.fill_(1.0)

    # Applies fn recursively to every submodule and returns the module itself,
    # see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    return model.apply(init_weights)


def reset_all_weights_with_specific_layer_type(model: nn.Module, modules_type2reset) -> nn.Module:
    """
    Calls ``reset_parameters()`` on every submodule of the requested type(s), recursively.

    Args:
        model: the module to modify in place.
        modules_type2reset: a module class (e.g. ``nn.Linear``) or a tuple of
            module classes. A submodule is reset only when its type is exactly
            one of these classes (subclasses are not matched). The matched
            classes must define ``reset_parameters``.

    Returns:
        The same ``model`` instance, so the declared return type holds and the
        call can be chained.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """
    # Normalise to a tuple so a single class and several classes share one code path.
    if isinstance(modules_type2reset, tuple):
        wanted_types = modules_type2reset
    else:
        wanted_types = (modules_type2reset,)

    @torch.no_grad()
    def init_weights(m):
        if type(m) in wanted_types:
            m.reset_parameters()

    # Applies fn recursively to every submodule and returns the module itself,
    # see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    return model.apply(init_weights)


# -- tests

def reset_params_test() -> None:
    """Manual check: print parameter-norm sums of ResNet-18 before and after a reset.

    Builds one ResNet-18 with ``pretrained=True`` and one with
    ``pretrained=False``, prints ``lp_norm`` for both, then calls
    ``reset_all_weights`` on the pretrained one and prints its norm again.
    Nothing is asserted; the printed values are meant to be compared by eye.

    Requires ``torchvision`` and ``uutils``. NOTE(review): ``pretrained=True``
    presumably fetches weights, so this may need network access -- confirm.
    """
    import torchvision.models as models
    from uutils.torch_uu import lp_norm

    resnet18 = models.resnet18(pretrained=True)
    resnet18_random = models.resnet18(pretrained=False)

    print(f'{lp_norm(resnet18)=}')
    print(f'{lp_norm(resnet18_random)=}')
    # Printed a second time on purpose: the value should match the first print,
    # showing that nothing changes the parameters before the reset below.
    print(f'{lp_norm(resnet18)=}')
    reset_all_weights(resnet18)
    # After the reset the norm is expected to differ from the two prints above.
    print(f'{lp_norm(resnet18)=}')


# Run the manual check only when this file is executed as a script.
if __name__ == '__main__':
    reset_params_test()
    # '\a' is the ASCII bell character, emitted together with the final message.
    print('Done! \a\n')

output:

lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18_random)=tensor(668.3687, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(476.0836, grad_fn=<AddBackward0>)
Done!

I am assuming this works because I calculated the norm twice for the pre-trained net and it was the same both times before calling reset.

I must admit I was unhappy that it wasn't closer to the norm of the random net, but I think this is good enough.

same: https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/11

Elias answered 9/11, 2021 at 22:5 Comment(0)
S
0

New to pytorch, I wonder if this could be a solution :)

Suppose Model inherits from torch.nn.Module,

to reset it to zeros:

# Set every tensor in the model's state to zero.
# zero_() is used rather than `*= 0` because NaN * 0 and inf * 0 are NaN, so a
# model whose weights have diverged would otherwise keep its NaN entries.
state = Model.state_dict()
for tensor in state.values():
    tensor.zero_()
Model.load_state_dict(state)
del state

to reset it randomly

# Replace the floating-point tensors in the model's state with standard-normal draws.
# Integer entries (e.g. BatchNorm's num_batches_tracked counter) are left
# untouched: overwriting them with truncated random floats is meaningless.
# NOTE(review): normal draws can be negative, which is not valid for
# variance-like buffers such as BatchNorm's running_var -- prefer
# reset_parameters() for models that contain such layers.
state = Model.state_dict()
for key, tensor in state.items():
    if tensor.is_floating_point():
        state[key] = torch.randn(tensor.size())
Model.load_state_dict(state)
del state
Stogner answered 19/7, 2021 at 15:26 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.