How to load a checkpoint file into a PyTorch model?

In my PyTorch code, I'm initializing my model and optimizer like this:

model = MyModelClass(config, shape, x_tr_mean, x_tr_std)
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate)

And here is the path to my checkpoint file.

checkpoint_file = os.path.join(config.save_dir, "checkpoint.pth")

To resume training, I check whether the checkpoint file exists and, if it does and config.resume is set, load it into the model and optimizer:

if os.path.exists(checkpoint_file):
    if config.resume:
        torch.load(checkpoint_file)
        model.load_state_dict(torch.load(checkpoint_file))
        optimizer.load_state_dict(torch.load(checkpoint_file))

Also, here's how I'm saving my model and optimizer.

torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iter_idx': iter_idx,
            'best_va_acc': best_va_acc},
           checkpoint_file)
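To see why loading fails, it helps to look at what torch.load returns for a checkpoint saved this way. This is a minimal sketch that assumes a small nn.Linear as a hypothetical stand-in for MyModelClass (whose definition isn't shown) and writes to a temporary directory instead of config.save_dir.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in for MyModelClass, whose definition isn't shown.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Temporary path instead of config.save_dir (assumption for a self-contained demo).
checkpoint_file = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

# Same layout as in the question: one dict wrapping several entries.
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iter_idx': 100,
            'best_va_acc': 0.9}, checkpoint_file)

checkpoint = torch.load(checkpoint_file)
# The top-level keys are the wrapper's keys, not the model's parameter names.
top_level_keys = sorted(checkpoint.keys())
model_keys = list(checkpoint['model'].keys())
print(top_level_keys)  # ['best_va_acc', 'iter_idx', 'model', 'optimizer']
print(model_keys)      # ['weight', 'bias']
```

The wrapper keys are exactly the ones that show up as "Unexpected key(s)" when the whole dict is passed to model.load_state_dict().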

However, I keep getting the following error whenever I run this code:

model.load_state_dict(torch.load(checkpoint_file))
File "/home/Josh/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MyModelClass:
        Missing key(s) in state_dict: "mean", "std", "attribute.weight", "attribute.bias".
        Unexpected key(s) in state_dict: "model", "optimizer", "iter_idx", "best_va_acc"

Does anyone know why I'm getting this error?

Scapegoat answered 13/2, 2019 at 19:10

You saved the model and optimizer state_dicts inside a single dictionary, so torch.load() returns that whole dictionary. Load the checkpoint once, then index it with the same keys you used when saving:

if os.path.exists(checkpoint_file):
    if config.resume:
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
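A self-contained sketch of the full save/resume round trip, again assuming a hypothetical nn.Linear stand-in for MyModelClass and a temporary path. It also reads back iter_idx and best_va_acc, which the question saves as well. The map_location="cpu" argument is optional; it is useful when a checkpoint saved on a GPU is loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def make_model_and_optimizer():
    # Hypothetical stand-in for MyModelClass(config, ...) and its optimizer.
    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    return model, optimizer


checkpoint_file = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

# Take one training step so the optimizer has real state (momentum buffers).
model, optimizer = make_model_and_optimizer()
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'iter_idx': 1,
            'best_va_acc': 0.5}, checkpoint_file)

# Resume: build fresh objects, load the file once, then restore each entry by key.
resumed_model, resumed_optimizer = make_model_and_optimizer()
iter_idx, best_va_acc = 0, 0.0
if os.path.exists(checkpoint_file):
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    resumed_model.load_state_dict(checkpoint['model'])
    resumed_optimizer.load_state_dict(checkpoint['optimizer'])
    iter_idx = checkpoint['iter_idx']
    best_va_acc = checkpoint['best_va_acc']

weights_match = torch.equal(model.weight, resumed_model.weight)
print(weights_match, iter_idx, best_va_acc)  # True 1 0.5
```

Calling torch.load() once and reusing the result also avoids reading the file three times, as the original code does.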

See the official PyTorch tutorial on saving and loading models for more information.

Cymbre answered 13/2, 2019 at 19:26

You're storing the model state, the optimizer state and other values under separate keys of a single dictionary. model.load_state_dict() takes a dict (a key: value mapping) and checks that its keys match the keys of the model's own state_dict. For example, if the model has three weights w1, w2, w3 and a bias b1, then the dict you pass to load_state_dict() must contain exactly those four entries. Keys the model expects but the dict lacks are reported as "Missing key(s)", and extra keys are reported as "Unexpected key(s)"; with the default strict=True, either case raises a RuntimeError. To fix this, use:

model.load_state_dict(torch.load(checkpoint_file)['model'])

This loads only the model's saved weights and biases, not the other entries in the checkpoint that the model doesn't need. Similarly, you can load the optimizer state and the other saved values from their own keys.
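A small sketch of the key check described above, assuming a hypothetical nn.Linear stand-in (its state_dict keys are weight and bias) and a trimmed-down checkpoint dict. load_state_dict() returns a named tuple with missing_keys and unexpected_keys, which makes the two error categories easy to compare.

```python
import torch.nn as nn

# Hypothetical stand-in model; its state_dict keys are 'weight' and 'bias'.
model = nn.Linear(4, 2)
# Trimmed-down version of the question's checkpoint layout.
checkpoint = {'model': model.state_dict(), 'iter_idx': 100}

# Passing the whole wrapper dict with strict=True (the default) raises.
error_message = ""
try:
    model.load_state_dict(checkpoint)
except RuntimeError as err:
    error_message = str(err)
print(error_message)  # names both the missing and the unexpected keys

# strict=False reports the same mismatch instead of raising (nothing is loaded here).
loose_result = model.load_state_dict(checkpoint, strict=False)
print(loose_result.missing_keys)     # ['weight', 'bias']
print(loose_result.unexpected_keys)  # ['model', 'iter_idx']

# Indexing the 'model' key gives an exact match.
exact_result = model.load_state_dict(checkpoint['model'])
print(exact_result.missing_keys, exact_result.unexpected_keys)  # [] []
```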

Meingoldas answered 28/2, 2023 at 18:16 Comment(2)
That will still throw a missing-key error. If you want to ignore it, use load_state_dict(state_dict, strict=False). (Heavyduty)
How so, if you save the state of a model and load it back into the same model? (Meingoldas)
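Whether missing keys remain depends on whether the model you load into has the same architecture as the one you saved. A sketch, assuming two hypothetical nn.Sequential models where the newer one adds a second layer: loading checkpoint['model'] into the same architecture reports no missing keys, while loading it into the changed architecture with strict=False restores the overlapping layer and lists the rest as missing.

```python
import torch
import torch.nn as nn

# Hypothetical "old" and "new" architectures; the new one adds a second layer.
old_model = nn.Sequential(nn.Linear(4, 4))
new_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
checkpoint = {'model': old_model.state_dict()}

# Same architecture as when saving: keys match exactly, nothing is missing.
same_result = nn.Sequential(nn.Linear(4, 4)).load_state_dict(checkpoint['model'])
print(same_result.missing_keys)  # []

# Changed architecture: strict=False loads the overlapping keys and reports the rest.
changed_result = new_model.load_state_dict(checkpoint['model'], strict=False)
print(changed_result.missing_keys)  # ['1.weight', '1.bias']
first_layer_restored = torch.equal(new_model[0].weight, old_model[0].weight)
print(first_layer_restored)  # True
```

Note that strict=False leaves unmatched parameters at their freshly initialized values, so it's worth checking missing_keys rather than ignoring it.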

© 2022 - 2024 — McMap. All rights reserved.