Problem with missing and unexpected keys while loading my model in PyTorch

I'm trying to load a model using this tutorial: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference . Unfortunately I'm a complete beginner and I've run into some problems.

I have created checkpoint:

checkpoint = {'epoch': epochs,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': loss}
torch.save(checkpoint, 'checkpoint.pth')

Then I wrote a class for my network and tried to load the file:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 1000)
        self.fc3 = nn.Linear(1000, 102)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)
        return x

Like this:

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = Network()
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

model = load_checkpoint('checkpoint.pth')

I got this error (edited to show the whole message):

RuntimeError: Error(s) in loading state_dict for Network:
    Missing key(s) in state_dict: "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias". 
    Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.3.weight", "features.3.bias", "features.6.weight", "features.6.bias", "features.8.weight", "features.8.bias", "features.10.weight", "features.10.bias", "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias", "classifier.fc3.weight", "classifier.fc3.bias". 

This is my model.state_dict().keys():

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 
'features.3.bias', 'features.6.weight', 'features.6.bias', 
'features.8.weight', 'features.8.bias', 'features.10.weight', 
'features.10.bias', 'classifier.fc1.weight', 'classifier.fc1.bias', 
'classifier.fc2.weight', 'classifier.fc2.bias', 'classifier.fc3.weight', 
'classifier.fc3.bias'])

This is my model:

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (fc1): Linear(in_features=9216, out_features=4096, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=4096, out_features=1000, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=1000, out_features=102, bias=True)
    (output): LogSoftmax()
  )
)

It's my first network ever and I'm blundering along. Thanks for steering me in the right direction!

Calefacient asked 23/12, 2018 at 20:50 Comment(6)
What if you just rename the corresponding keys in your model.state_dict().keys(), so that features.3.weight becomes fc3.weight, and so on?Nicolnicola
I'll try and let you know in a momentCalefacient
It's weird but when I do it, after loading the model is NoneCalefacient
Ah OK, so because you are not using a return value on the function, when you call load_checkpoint it returns nothing; hence NoneType. If you want to return the model from your function, you need to add return model to the bottom of your function. If you do not need to return it, remove the model = from the model = load_checkpoint('checkpoint.pth') which will just call the function.Nicolnicola
If you want to return multiple variables, you would need to return them individually. E.g. return checkpoint, model, epoch, loss etc., and where you call the function, you will need to catch each return value in another variable. E.g. checkpoint, model, epoch, loss = load_checkpoint('checkpoint.pth')Nicolnicola
Thanks for the help, Adam. I'm getting closer and closer. The problem right now is that I only get the fc1/fc2/fc3 model without the AlexNet features, so I can't really go back and train it.Calefacient

So your Network is essentially the classifier part of AlexNet and you're looking to load pretrained AlexNet weights into it. The problem is that the keys in state_dict are "fully qualified", which means that if you look at your network as a tree of nested modules, a key is just a list of modules in each branch, joined with dots like grandparent.parent.child. You want to

  1. Keep only the tensors whose names start with "classifier."
  2. Strip the "classifier." prefix from those keys

so try

model = Network()
loaded_dict = checkpoint['model_state_dict']
prefix = 'classifier.'
n_clip = len(prefix)
# keep only the classifier tensors and strip the prefix from their keys
adapted_dict = {k[n_clip:]: v for k, v in loaded_dict.items()
                if k.startswith(prefix)}
model.load_state_dict(adapted_dict)
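
If you also want the convolutional features back (see the comments below), one option is to rebuild the full model and load the checkpoint into it unchanged. A minimal sketch, assuming the checkpoint came from torchvision's AlexNet with its classifier replaced by the named layers shown in the question:

import torch
import torch.nn as nn
from collections import OrderedDict
from torchvision import models

# Assumes the checkpoint was produced from torchvision's AlexNet whose
# classifier was swapped for the named layers shown in the question.
model = models.alexnet()
model.classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(9216, 4096)),
    ('relu1', nn.ReLU()),
    ('fc2', nn.Linear(4096, 1000)),
    ('relu2', nn.ReLU()),
    ('fc3', nn.Linear(1000, 102)),
    ('output', nn.LogSoftmax(dim=1)),
]))

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])  # keys now line up

With the module structure matching, the fully qualified keys agree and no renaming is needed.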
Kapp answered 23/12, 2018 at 21:13 Comment(5)
It doesn't return any error but when I print(model) it shows None. I mean after this model = load_checkpoint('checkpoint.pth')Calefacient
I tried to achieve 1) and 2) in a more understandable way for me but the outcome is still the same. The model after loading is empty.Calefacient
Right, as mentioned by Adam above, you need to return the model if you want to get it as the return value of your function.Kapp
What do you mean? AlexNet is comprised of two parts called features and classifier. Your Network implements only classifier so yes, you're losing the features part. I was assuming that's what you meant to doKapp
I'd like to save the whole model to be able to train it in the futureCalefacient

In my case, I had to remove the "module." prefix from the state dict before loading it.

    model = Model()
    state_dict = torch.load(model_path)
    remove_prefix = 'module.'
    # strip the 'module.' prefix (if present) from every key
    state_dict = {k[len(remove_prefix):] if k.startswith(remove_prefix) else k: v
                  for k, v in state_dict.items()}

After that,

    model.load_state_dict(state_dict)

worked!
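
For context, this "module." prefix usually appears when the checkpoint was saved from a model wrapped in nn.DataParallel, which registers the original network under a .module attribute. A minimal illustration (Model here stands in for whatever architecture was saved):

    import torch.nn as nn

    # wrapping for multi-GPU training prefixes every state_dict key with 'module.'
    wrapped = nn.DataParallel(Model())
    print(list(wrapped.state_dict().keys())[:2])
    # e.g. ['module.fc1.weight', 'module.fc1.bias']

The prefix can also be avoided at save time by saving wrapped.module.state_dict() instead of the wrapper's state dict.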

Camshaft answered 18/11, 2022 at 17:48 Comment(0)
