How does one pickle arbitrary pytorch models that use lambda functions?

I currently have a neural network module:

import torch.nn as nn

class NN(nn.Module):
    def __init__(self, args, lambda_f, nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Params stuff etc...

    def forward(self, x):
        #some code using fields
        return out

I am trying to checkpoint it, but because PyTorch saves with state_dicts, I can't preserve the lambda functions I was actually using if I checkpoint with torch.save etc. I literally want to save everything without issue and re-load it later to train on GPUs. I am currently using this:

def save_ckpt(path_to_ckpt, crazy_mdl):
    from pathlib import Path
    import dill as pickle
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt = Path(path_to_ckpt)
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / 'db'

    ## Pickle the whole model object (dill handles the lambda fields)
    db = {'crazy_mdl': crazy_mdl}
    with open(ckpt_path_plus_path, 'wb') as db_file:  # 'wb' so each call overwrites the old checkpoint
        pickle.dump(db, db_file)

Currently it throws no errors when I checkpoint it, and the file is saved.
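
For reference, a minimal sketch (not part of my original code) of how I would load such a checkpoint back; the path and the 'crazy_mdl' key mirror the hypothetical names above:

def load_ckpt(path_to_ckpt):
    from pathlib import Path
    import dill as pickle
    ckpt_path_plus_path = Path(path_to_ckpt) / 'db'
    with open(ckpt_path_plus_path, 'rb') as db_file:
        db = pickle.load(db_file)
    return db['crazy_mdl']

# e.g. resume training on GPU afterwards (sketch, path is made up):
# crazy_mdl = load_ckpt('./ckpt_dir')
# crazy_mdl = crazy_mdl.to('cuda')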

I am worried that when I train it there might be a subtle bug, even if no exceptions/errors are raised, or that something unexpected might happen (e.g. weird saving on disk on the clusters, etc., who knows).

Is this safe to do with pytorch classes/nn models? Especially if we want to resume training with GPUs?

Belligerency answered 29/4, 2020 at 20:11 Comment(1)
This is not a good idea. If you do this, then if your code moves to a different GitHub repo it will be hard to restore models that took a lot of time to train. The cycles spent recovering them or retraining are not worth it. I recommend instead doing it the PyTorch way and only saving the weights, as recommended in the PyTorch docs.Belligerency

This is not a good idea. If you do this, then if your code moves to a different GitHub repo it will be hard to restore models that took a lot of time to train. The cycles spent recovering them or retraining are not worth it. I recommend instead doing it the PyTorch way and only saving the weights, as recommended in the PyTorch docs.
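
A minimal sketch of the weights-only approach recommended here (the model, optimizer, and file name are placeholders, not from the answer):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# saving: persist only tensors, never the module object itself
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pt')

# loading: re-create the model from code, then restore the tensors
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])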

Belligerency answered 29/4, 2020 at 20:11 Comment(0)

I'm the dill author. I use dill (and klepto) to save classes that contain trained ANNs inside of lambda functions. I tend to use combinations of mystic and sklearn, so I can't speak directly to pytorch, but I can assume it works the same way.

The place where you have to be careful is when you have a lambda that contains a pointer to an object external to the lambda... for example y = 4; f = lambda x: x+y. This might seem obvious, but dill will pickle the lambda and, depending on the rest of the code and the serialization variant, may not serialize the value of y. So I've seen many cases where people serialize a trained estimator inside some function (or lambda, or class) and then the results aren't "correct" when they restore the function from serialization. The overarching cause is that the function wasn't encapsulated in a way that guarantees all the objects it needs to yield the correct results are stored in the pickle. Even in that case you can get the "correct" results back, but you'd need to re-create the same environment you had when you pickled the estimator (i.e. all the same values it depends on in the surrounding namespace).

The takeaway: try to make sure that all variables used in the function are defined within the function. Here's a portion of a class I've recently started to use myself (it should be in the next release of mystic):

import numpy as np  # module-level import: the lambda below resolves np from this enclosing namespace

class Estimator(object):
    "a container for a trained estimator and transform (not a pipeline)"
    def __init__(self, estimator, transform):
        """a container for a trained estimator and transform

    Input:
        estimator: a fitted sklearn estimator
        transform: a fitted sklearn transform
        """
        self.estimator = estimator
        self.transform = transform
        self.function = lambda *x: float(self.estimator.predict(self.transform.transform(np.array(x).reshape(1,-1))).reshape(-1))
    def __call__(self, *x):
        "f(*x) for x of xtest and predict on fitted estimator(transform(xtest))"
        import numpy as np
        return self.function(*x)

Note that when the function is called, everything it uses (including np) is defined in the surrounding namespace. As long as pytorch estimators serialize as expected (without external references), you should be fine if you follow the above guidelines.
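
As a quick illustration (not part of the original class; the data and estimator choices are arbitrary), a round-trip through dill, assuming sklearn is installed:

import numpy as np
import dill
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

X = np.random.rand(50, 3)
y = X @ np.array([1.0, 2.0, 3.0])

transform = StandardScaler().fit(X)
estimator = LinearRegression().fit(transform.transform(X), y)

f = Estimator(estimator, transform)         # everything the lambda needs lives on self
f2 = dill.loads(dill.dumps(f))              # round-trip through dill
print(f(0.1, 0.2, 0.3), f2(0.1, 0.2, 0.3))  # the two predictions should match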

Starfish answered 30/4, 2020 at 12:47 Comment(4)
Ok! So I guess dill works a little differently from what I expected. I assumed that dill, for lambda functions, would save the "closure" (i.e. the function name, body AND the program environment/namespace) as it was during the pickling execution. It seems you are saying that is NOT how it works, and instead what it does is save the names and then use the current local environment to resolve the lambda functions. Is that correct? That is sad to me because it means I can't use dill to pickle things without worrying (though I don't want to sound ungrateful, I'm sure this is a hard problem).Belligerency
Is the snippet of code you are showing trying to teach me/demo how to do what I need? i.e. having the right program env by saving the lambda function after the field definitions for the class that we pickled? That way, when it is restored, it uses the data values it pickled when restoring the lambda function. Is that basically what you're trying to show me?Belligerency
dill has several serialization variants, all of which treat the global namespace differently. One doesn't save any of the global namespace (like pickle), one tries to save all members of the namespace that are referenced directly (like cloudpickle), and then there are two variants unique to dill -- save the global namespace as a dict, and save an object by extracting the generating code. So, dill does save the namespace. What I'm saying is you have to help it out to make sure the references are hooked up as expected... and you can do that by encapsulating the variables as above.Starfish
What I'm showing with the above code is a strategy that allows you to write a lambda that uses pointer references to the enclosing namespace -- and will still serialize as expected. A naked lambda will serialize, but you are more likely to have an issue where there's a reference from within the lambda that is not pointing to the expected value.Starfish
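
To make that concrete, here is a small illustration (not from the answer) of binding a value into a lambda at definition time, so the pickled function no longer depends on the surrounding namespace, whatever serialization variant is used:

import dill

y = 4
f_free  = lambda x: x + y        # resolves y from the namespace at call time
f_bound = lambda x, y=y: x + y   # y's current value is baked in as a default argument

payload = dill.dumps(f_bound)
y = 1000                         # the surrounding namespace changes...
print(dill.loads(payload)(1))    # ...but this still prints 5: the value travels with the pickle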

Yes, I think it is safe to use dill to pickle lambda functions etc. I have been using torch.save with dill to save the state dict and have had no problems resuming training on GPU as well as CPU, unless the model class was changed. Even if the model class was changed (adding/deleting some parameters), I could load the state dict, modify it, and load it into the model.

Also, people usually don't save the model objects but only the state dicts, i.e. the parameter values, to resume training, along with the hyperparameters/model arguments needed to re-create the same model object later.

Saving the model object can sometimes be problematic, as changes to the model class (code) can make the saved object useless. If you don't plan on changing your model class/code at all, and hence the model object won't change, then saving objects can work well, but generally it is not recommended to pickle a module object.
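Roughly what this looks like in code (a sketch; the model, file name, and strictness choice are placeholders, not from the answer):

import dill
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # placeholder model

# save only the state dict, telling torch to serialize with dill
torch.save(model.state_dict(), 'ckpt.pt', pickle_module=dill)

# resume later on CPU or GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
state_dict = torch.load('ckpt.pt', map_location=device, pickle_module=dill)
# (recent PyTorch versions may also require weights_only=False when passing pickle_module)
model.load_state_dict(state_dict, strict=False)  # strict=False tolerates added/deleted parameters
model.to(device)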

Scrivner answered 30/4, 2020 at 20:18 Comment(0)
