How do I save a trained model in PyTorch?

How do I save a trained model in PyTorch? I have read that:

  1. torch.save()/torch.load() is for saving/loading a serializable object.
  2. model.state_dict()/model.load_state_dict() is for saving/loading model state.
Rothberg answered 9/3, 2017 at 19:6 Comment(7)
I think it's because torch.save() saves all the intermediate variables as well, like intermediate outputs for backpropagation. But you only need to save the model parameters, like weight/bias etc. Sometimes the former can be much larger than the latter.Nnw
I tested torch.save(model, f) and torch.save(model.state_dict(), f). The saved files have the same size. Now I am confused. Also, I found using pickle to save model.state_dict() extremely slow. I think the best way is to use torch.save(model.state_dict(), f) since you handle the creation of the model, and torch handles the loading of the model weights, thus eliminating possible issues. Reference: discuss.pytorch.org/t/saving-torch-models/838/4Nnw
Seems like PyTorch has addressed this a bit more explicitly in their tutorials section—there's lots of good info there that's not listed in the answers here, including saving more than one model at a time and warm starting models.Ephemeron
what is wrong with using pickle?Rodin
@CharlieParker torch.save is based on pickle. The following is from the tutorial linked above: "[torch.save] will save the entire module using Python’s pickle module. The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. The reason for this is because pickle does not save the model class itself. Rather, it saves a path to the file containing the class, which is used during load time. Because of this, your code can break in various ways when used in other projects or after refactors."Bartram
@DavidMiller actually I only need to save a nn.Sequential model. Do you know how to do that? I don't have a model class definition. For sequential I wrote this, hopefully a reputable answerer will confirm: #62923552Rodin
When I use torch.save(model, file), it seems to drop batchnorm/dropout parameters: I never get the same result after reloading the model, even if I call model.eval() (PyTorch 1.5). So you should use the recommended way of saving, just to avoid bugs, I think.Pokeweed

Found this page on their GitHub repo:

Recommended approach for saving a model

There are two main approaches for serializing and restoring a model.

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

Then later:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

The second saves and loads the entire model:

torch.save(the_model, PATH)

Then later:

the_model = torch.load(PATH)

However, in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects or after some serious refactors.
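For concreteness, a minimal end-to-end sketch of the recommended approach (TheModelClass, its layers, and the file name below are illustrative placeholders for your own model and path):

import torch
import torch.nn as nn

# Stand-in for your own model class
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

PATH = "model_weights.pth"

# Save only the learned parameters
the_model = TheModelClass()
torch.save(the_model.state_dict(), PATH)

# Later: re-create the architecture first, then load the parameters into it
the_model = TheModelClass()
the_model.load_state_dict(torch.load(PATH))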


See also: Save and Load the Model section from the official PyTorch tutorials.

Sarpedon answered 6/5, 2017 at 10:18 Comment(14)
According to @smth discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/… the model reloads in train mode by default, so you need to manually call the_model.eval() after loading if you are loading it for inference rather than resuming training.Pellagra
The second method gives the #53798509 error on Windows 10; I wasn't able to solve it.Geerts
Is there any option to save without a need for an access for model class?Tiphani
With that approach how do you keep track of the *args and **kwargs you need to pass in for the load case?Hooker
actually I only need to save a nn.Sequential model. Do you know how to do that? I don't have a model class definition.Rodin
@CharlieParker sorry no idea, haven’t worked on pytorch for too longSarpedon
@Sarpedon the_model = TheModelClass(*args, **kwargs). Running this command says NameError: name 'TheModelClass' is not defined. How should I go about thisDickens
@NarenBabuR that was just a dummy name. You are to replace it with the actual model class you created or whichever predefined model you used from PyTorch. The args and kwargs are whatever you used for defining the model. You can save both the weights and biases as well as the extra parameters required to load the model in. See the second answer below.Ortega
I saved the rnn model using second method. After loading, i want to do predictions and calculate metrics such as F1. However, prediction doesn't work. I get this error IndexError: index out of range in self when running this: predictions = model(batch.textt).squeeze(1)Cleocleobulus
Hi guys, could anyone tell me what is the extension for model dict file(.pth?) and the extension for the entire model file(.pkl)?? Am I correct?Brockwell
Note that torch.save silently overwrites an existing file. Be aware when doing hyperparameter sweeps or similar automated training runs ;)Metamorphosis
In the linked page, I see that the loading terminates with model.eval(): does it mean we cannot restart training?Olander
What file extension should you use with torch.save(model.state_dict(), filepath)?Dissatisfaction
Does saving the model along with the model definition still work ?Culmiferous

It depends on what you want to do.

Case #1: Save the model to use it yourself for inference: You save the model, you restore it, and then you change the model to evaluation mode. This is done because you usually have BatchNorm and Dropout layers that by default are in train mode on construction:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Case #2: Save model to resume training later: If you need to keep training the model that you are about to save, you need to save more than just the model. You also need to save the state of the optimizer, epochs, score, etc. You would do it like this:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

To resume training you would do things like: state = torch.load(filepath), and then, to restore the state of each individual object, something like this:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Since you are resuming training, DO NOT call model.eval() once you restore the states when loading.
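Putting Case #2 together, a minimal runnable sketch (the Linear model, SGD optimizer, and epoch number are illustrative stand-ins for your own training objects):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
filepath = "checkpoint.pth"

# Save everything needed to resume
state = {
    'epoch': 5,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(state, filepath)

# Later: rebuild the model and optimizer, then restore their states
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
state = torch.load(filepath)
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
start_epoch = state['epoch'] + 1  # continue training from the next epoch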

Case #3: Model to be used by someone else with no access to your code: In TensorFlow you can create a .pb file that defines both the architecture and the weights of the model. This is very handy, especially when using TensorFlow Serving. The equivalent way to do this in PyTorch would be:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

This approach is still not bulletproof, and since PyTorch is still undergoing a lot of changes, I wouldn't recommend it.
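As the comments note, for this use case a model exchange format such as ONNX may be the better fit. A minimal sketch using torch.onnx.export (the model and the example input shape are assumptions):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for your trained model
model.eval()
dummy_input = torch.randn(1, 10)  # an example input of the shape your model expects
torch.onnx.export(model, dummy_input, "model.onnx")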

Fissi answered 2/3, 2018 at 23:34 Comment(10)
Is there a recommended file ending for the 3 cases? Or is it always .pth?Unblock
In the Case #3 torch.load returns just an OrderedDict. How do you get the model in order to make predictions?Gangling
Hi, may I know how to do the mentioned "Case #2: Save model to resume training later"? I managed to load the checkpoint into the model, but then I was unable to resume training with something like model.to(device); model = train_model_epoch(model, criterion, optimizer, sched, epochs).Juni
Hi, for case one, which is for inference, the official PyTorch docs say you must save the optimizer state_dict for either inference or completing training: "When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model's state_dict. It is important to also save the optimizer's state_dict, as this contains buffers and parameters that are updated as the model trains."Conceivable
In case #3, model class should be defined somewhere.Tiphani
Another question dives into some of the mess of Case #3. I've written a similar model saver for Keras, and this process of saving the model source code is inherently messy.Otherdirected
For use case #3, you probably want a model exchange format like ONNX rather than pickling the whole model.Hirohito
Same question here, what are the extensions for these 3 cases? I know there are .pth, .pkl, but could someone do the mapping between extension names and the cases??Brockwell
@VerenaHaunschmid yes, it would be a .pt file.Irish
In case #1, does it function by itself or do we need have certain files in the directory? I'm not fully understanding the difference between case #1 and #3Unessential

The pickle Python library implements binary protocols for serializing and de-serializing a Python object.

When you import torch (or when you use PyTorch) it will import pickle for you and you don't need to call pickle.dump() and pickle.load() directly, which are the methods to save and to load the object.

In fact, torch.save() and torch.load() will wrap pickle.dump() and pickle.load() for you.

The state_dict that the other answers mention deserves a few more notes.

What state_dict do we have inside PyTorch? There are actually two state_dicts.

A PyTorch model is a torch.nn.Module, which has a model.parameters() call to get the learnable parameters (w and b). These learnable parameters, once randomly initialized, will update over time as we learn. They are the first state_dict.

The second state_dict is the optimizer's state dict. You recall that the optimizer is used to improve our learnable parameters. The optimizer state_dict contains no learnable parameters itself, only the optimizer's internal state and the hyperparameters used.

Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.

Let's create a super simple model to explain this:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

This code will output the following:

Model's state_dict:
weight      torch.Size([2, 5])
bias      torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state      {}
param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Note this is a minimal model. You may try stacking layers with nn.Sequential, for example:

D_in, H, D_out = 64, 32, 10  # illustrative sizes

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.ReLU(),
          torch.nn.Linear(H, D_out),
        )
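Since nn.Sequential is itself an nn.Module, it can be saved and restored the same way without writing a model class, as long as you rebuild an identical Sequential before loading; a minimal sketch:

import torch
import torch.nn as nn

seq = nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 2))
torch.save(seq.state_dict(), "seq.pth")

# Later: rebuild the same architecture, then load the weights into it
seq_restored = nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 2))
seq_restored.load_state_dict(torch.load("seq.pth"))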

Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm layers) have entries in the model's state_dict.

Non-learnable things belong to the optimizer object state_dict, which contains information about the optimizer's state, as well as the hyperparameters used.

The rest of the story is the same: in the inference phase (the phase when we use the model after training), we predict based on the parameters we learned. So for inference, we just need to save the parameters with model.state_dict():

torch.save(model.state_dict(), filepath)

And to use it later:

model.load_state_dict(torch.load(filepath))
model.eval()

Note: Don't forget the last line, model.eval(); this is crucial after loading the model.

Also, don't try torch.save(model.parameters(), filepath); model.parameters() returns just a generator object.
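A quick standalone check of this (pickle cannot serialize generators, which is why saving model.parameters() fails):

import torch

model = torch.nn.Linear(5, 2)
print(type(model.parameters()))       # <class 'generator'>, not the tensors themselves
print(len(list(model.parameters())))  # 2: the Linear layer's weight and bias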

On the other hand, torch.save(model, filepath) saves the model object itself, but keep in mind the model doesn't have the optimizer's state_dict. Check the other excellent answer by @Jadiel de Armas to save the optimizer's state dict.

Samos answered 17/4, 2019 at 19:1 Comment(2)
Although it is not a straightforward solution, the essence of the problem is deeply analyzed! Upvote.Integumentary
in model.load_state_dict, what is the model? If I'm exporting it to another environment, what would be the code sample to export and then import my model? I'm training a Bert. not sure what class to use to then load these parameters intoUnessential

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Save/Load Entire Model

Save:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Load:

(Model class must be defined somewhere)

model = torch.load(path)
model.eval()
Ronni answered 13/5, 2019 at 20:46 Comment(1)
it raised: AttributeError: 'dict' object has no attribute 'eval'Collection

If you want to save the model and resume training later:

Single GPU

Save:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath = 'checkpoint.t7'
torch.save(state, savepath)

Load:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Multiple GPUs

Save:

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath = 'checkpoint.t7'
torch.save(state, savepath)

Load:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

# Don't call DataParallel before loading the model, otherwise you will get an error

model = nn.DataParallel(model)  # skip this line if you are loading on a single GPU
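To make that ordering explicit, a minimal sketch of the multi-GPU round trip (nn.Linear stands in for your model):

import torch
import torch.nn as nn

# Pretend this model was wrapped in DataParallel during training
# and saved via model.module.state_dict(), as shown above
model = nn.Linear(10, 2)
torch.save({'epoch': 0, 'state_dict': model.state_dict()}, 'checkpoint.t7')

# Load into the bare (unwrapped) model first; the keys match because
# the checkpoint was saved from model.module.state_dict()
checkpoint = torch.load('checkpoint.t7')
model = nn.Linear(10, 2)
model.load_state_dict(checkpoint['state_dict'])

# Only wrap in DataParallel after the weights are loaded
model = nn.DataParallel(model)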
Boice answered 21/5, 2020 at 18:23 Comment(0)

Saving locally

How you save your model depends on how you want to access it in the future. If you can call a new instance of the model class, then all you need to do is save/load the weights of the model with model.state_dict():

# Save:
torch.save(old_model.state_dict(), PATH)

# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))

If you cannot for whatever reason (or prefer the simpler syntax), then you can save the entire model (actually a reference to the file(s) defining the model, along with its state_dict) with torch.save():

# Save:
torch.save(old_model, PATH)

# Load:
new_model = torch.load(PATH)

But since this is a reference to the location of the files defining the model class, this code is not portable unless those files are also ported in the same directory structure.

Saving to cloud - TorchHub

If you wish your model to be portable, you can easily allow it to be imported with torch.hub. If you add an appropriately defined hubconf.py file to a GitHub repo, this can be easily called from within PyTorch to enable users to load your model with/without weights:

hubconf.py (github.com/repo_owner/repo_name)

dependencies = ['torch']
from my_module import mymodel as _mymodel

def mymodel(pretrained=False, **kwargs):
    return _mymodel(pretrained=pretrained, **kwargs)

Loading model:

new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)
Merovingian answered 2/4, 2021 at 15:54 Comment(0)

pip install pytorch-lightning

Make sure your model class subclasses pl.LightningModule instead of nn.Module.

Saving and loading checkpoints using pytorch lightning

import pytorch_lightning as pl

model = MyLightningModule(hparams)
trainer = pl.Trainer()  # configure devices, epochs, etc. as needed
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")
Hydrops answered 8/11, 2021 at 21:5 Comment(0)

I use this approach; hope it will be useful for you.

First, save the trained model's state_dict (here to the 'roberta_model' path):

torch.save(model.state_dict(), '/dbfs/FileStore/tables/PM/TC/roberta_model')

Then, to load it back later:

from transformers import RobertaForSequenceClassification

num_labels = len(test_label_cols)
robertaclassificationtrain = '/dbfs/FileStore/tables/PM/TC/roberta_model'
robertaclassificationpath = "/dbfs/FileStore/tables/PM/TC/ROBERTACLASSIFICATION"

model = RobertaForSequenceClassification.from_pretrained(robertaclassificationpath,
                                                         num_labels=num_labels)
model.cuda()

model.load_state_dict(torch.load(robertaclassificationtrain))
model.eval()
Irish answered 25/10, 2022 at 4:22 Comment(0)

Exporting/loading a model in TorchScript format is another way of saving a model.

Another common way to do inference with a trained model is to use TorchScript, an intermediate representation of a PyTorch model that can be run in Python as well as in C++.

NOTE: Using the TorchScript format, you will be able to load the exported model and run inference without defining the model class.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

Export:

model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save('model_scripted.pt') # Save

Load [ Works w/o defining model class ]:

model = torch.jit.load('model_scripted.pt')
model.eval()

[Figure: the exported model's architecture visualized in Netron]
Toothwort answered 7/4, 2023 at 18:20 Comment(0)

These days everything is written in the official tutorial: https://pytorch.org/tutorials/beginner/saving_loading_models.html

You have several options for how to save and what to save, and it is all explained in that tutorial.

Lopeared answered 15/7, 2021 at 9:13 Comment(0)

save: torch.save(model, file_path)

load: model = torch.load(file_path)

Note:

  • comment out the training loop once the model has been trained
  • keep the model class definition (e.g., the RNN class) available in the loading code
  • as long as you only load the model without retraining, there will be no errors
Valdez answered 28/1 at 15:39 Comment(0)
