I currently have a neural network module:
```python
import torch.nn as nn

class NN(nn.Module):
    def __init__(self, args, lambda_f, nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f  # e.g. a lambda; this is what plain pickling chokes on
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Parameter stuff etc...

    def forward(self, x):
        # some code using the fields; hypothetical placeholder:
        out = self.lambda_f(self.nn1(x))
        return out
```
I am trying to checkpoint it, but because PyTorch checkpoints are usually saved as `state_dict`s, I can't save the lambda functions I am actually using if I checkpoint with PyTorch's `torch.save` etc. I want to save literally everything without issues and reload it later to train on GPUs. I am currently using this:
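For context on why the lambdas are the sticking point: the stdlib `pickle` stores a function by reference (its module plus qualified name), and a lambda has no name it can be looked up by again. A minimal stdlib-only sketch, assuming a lambda that just doubles its input as a hypothetical stand-in for `lambda_f`; a `functools.partial` over an importable function (or a module-level `def`) can be pickled by reference instead:

```python
import functools
import operator
import pickle

def can_pickle(obj):
    """Return True if the stdlib pickle can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False

# Hypothetical stand-in for lambda_f: its qualified name is "<lambda>",
# which pickle cannot look up again, so serialization fails
double_lambda = lambda t: 2 * t

# Same behavior, built from an importable function (operator.mul),
# so pickle records it by name and rebuilds it on load
double_partial = functools.partial(operator.mul, 2)

print(can_pickle(double_lambda))   # False
print(can_pickle(double_partial))  # True

restored = pickle.loads(pickle.dumps(double_partial))
print(restored(21))  # 42
```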
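```python
from pathlib import Path

import dill as pickle  # dill can serialize lambdas, which the stdlib pickle cannot

def save_ckpt(path_to_ckpt, crazy_mdl):
    path_to_ckpt = Path(path_to_ckpt)
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / 'db'
    ## Collect everything to checkpoint (here just the model) in one dict
    db = {'crazy_mdl': crazy_mdl}
    ## 'wb' overwrites the file; 'ab' would append after any older checkpoint
    with open(ckpt_path_plus_path, 'wb') as db_file:
        pickle.dump(db, db_file)
```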
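One subtle issue that raises no error: opening the checkpoint in append mode (`'ab'`) stacks each new pickle after the previous ones, while a single `pickle.load` call returns only the first object in the file, i.e. the oldest checkpoint. A stdlib-only sketch, with small dicts as hypothetical stand-ins for `db`:

```python
import pickle
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    ckpt_file = Path(tmp) / 'db'
    # Simulate two checkpoints (epochs 1 and 2) saved with append mode
    for epoch in (1, 2):
        with open(ckpt_file, 'ab') as f:
            pickle.dump({'epoch': epoch}, f)
    # One pickle.load call reads only the first pickle in the file
    with open(ckpt_file, 'rb') as f:
        loaded = pickle.load(f)
        # The newer checkpoint is still there, hidden behind the old one
        newer = pickle.load(f)

print(loaded)  # {'epoch': 1}
print(newer)   # {'epoch': 2}
```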
Currently it throws no errors when I checkpoint, and the model does get saved.
I am worried that when I train it there might be a subtle bug even if no exceptions/errors are raised, or that something unexpected might happen (e.g. odd behavior when saving to disk on the clusters, who knows).
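On the worry about disks on the cluster: one common precaution (not something PyTorch or dill does for you) is to write the checkpoint to a temporary file in the same directory and then swap it into place with `os.replace`, so a job killed mid-write leaves the previous checkpoint intact. A sketch under those assumptions; `atomic_pickle_dump` is a hypothetical helper name, and the rename's atomicity may be weaker on some network filesystems:

```python
import os
import pickle
import tempfile
from pathlib import Path

def atomic_pickle_dump(obj, path, pickle_module=pickle):
    """Hypothetical helper: `path` ends up holding either the old or the complete new checkpoint."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Temp file in the same directory, so the final rename stays on one filesystem
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=path.name + '.', suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            pickle_module.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # ask the OS to flush the bytes to disk before renaming
        os.replace(tmp_name, path)  # overwrite the old checkpoint in one step
    except BaseException:
        # Remove the partial temp file, then re-raise the original error
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
```

Passing `pickle_module=dill` (which also provides `dump`) would keep the lambda support of the dill-based approach.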
Is this safe to do with PyTorch classes/nn models? Especially if we want to resume training on GPUs?
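Regarding resuming later: `dill` serializes a lambda's code by value, so a checkpoint produced that way may fail to load under a different Python or library version. One alternative, sketched here with the stdlib only, is to keep the checkpoint to plain data and store each callable as a name that is looked up in a registry when the model is rebuilt; in PyTorch terms that name would be saved alongside the `state_dict`. `ACTIVATIONS`, `make_ckpt`, and `rebuild_fn` are hypothetical names:

```python
import json

# Hypothetical registry: every callable the model may use, keyed by a stable name.
# Lambdas are fine here because they are never pickled, only looked up.
ACTIVATIONS = {
    'double': lambda t: 2 * t,
    'square': lambda t: t * t,
}

def make_ckpt(fn_name, epoch):
    """Checkpoint metadata holds only plain data, so any serializer (even JSON) works."""
    if fn_name not in ACTIVATIONS:
        raise KeyError(f'unknown function name: {fn_name!r}')
    return {'lambda_f': fn_name, 'epoch': epoch}

def rebuild_fn(ckpt):
    """Look the callable up again at load time instead of unpickling it."""
    return ACTIVATIONS[ckpt['lambda_f']]

saved = json.dumps(make_ckpt('square', epoch=3))
ckpt = json.loads(saved)
fn = rebuild_fn(ckpt)
print(ckpt['epoch'], fn(4))  # 3 16
```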
Cross-posted:
- How does one pickle arbitrary pytorch models that use lambda functions?
- https://discuss.pytorch.org/t/how-does-one-pickle-arbitrary-pytorch-models-that-use-lambda-functions/79026
- https://www.reddit.com/r/pytorch/comments/gagpjg/how_does_one_pickle_arbitrary_pytorch_models_that/?
- https://www.quora.com/unanswered/How-does-one-pickle-arbitrary-PyTorch-models-that-use-lambda-functions