Unpickling saved pytorch model throws AttributeError: Can't get attribute 'Net' on <module '__main__' despite adding class definition inline
T

8

10

I'm trying to serve a PyTorch model in a Flask app. This code worked when I ran it in a Jupyter notebook earlier, but now I'm running it inside a virtualenv and it apparently can't get the attribute 'Net', even though the class definition is right there. All the other similar questions tell me to add the class definition of the saved model to the same script, but it still doesn't work. The torch version is 1.0.1 (both where the saved model was trained and in the virtualenv). What am I doing wrong? Here's my code.

import os
import numpy as np
from flask import Flask, request, jsonify 
import requests

import torch
from torch import nn
from torch.nn import functional as F


MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'


r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

model = torch.load('model.pth')

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():


    x_data = request.args['x_data']

    x_data =  x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data) 

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0: 
        pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)

Here's the full traceback:

Traceback (most recent call last):
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask", line 10, in <module>
    sys.exit(main())
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 894, in main
    cli.main(args=args, prog_name=name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 557, in main
    return super(FlaskGroup, self).main(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/decorators.py", line 64, in new_func
    return ctx.invoke(f, obj, *args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 767, in run_command
    app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 293, in __init__
    self._load_unlocked()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 317, in _load_unlocked
    self._app = rv = self.loader()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 372, in load_app
    app = locate_app(self, import_name, name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 235, in locate_app
    __import__(module_name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/app.py", line 34, in <module>
    model = torch.load('model.pth')
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 368, in load
    return _load(f, map_location, pickle_module)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 542, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask'>

This doesn't solve my issue. I do not want to change the way I persist the model. torch.save() worked fine for me outside the virtual env. I don't mind adding the class definition to the script. I'm trying to see what's causing the error despite that.

Toreutics answered 3/4, 2019 at 6:50 Comment(8)
This has nothing to do with that. torch.save() was working fine for me outside the virtualenv. I'm just trying to figure out how to fix the error. I don't want to change how I persist the model. - Toreutics
How did you save the model? Did you save the entire model or just its state_dict? - Gild
The entire model, not the state_dict. I can load it and use it successfully locally, but not within the virtualenv. I'm trying to deploy it to AWS Lambda. - Toreutics
This is exactly what the "duplicate" thread is telling you: save the state_dict rather than the model to be robust to changes in your environment. - Gild
It will work if I just use the state_dict. I'm trying to understand why pickle throws the AttributeError despite adding the class definition. - Toreutics
How are you running your app? Can you add a print(__name__) line in your code? I'm guessing that the __name__ of your script was equal to __main__ when saving the pickle but is something different now that you're running it with flask, causing an attribute lookup error. - Snowinsummer
You may look at #1 and #2 to understand why pickle throws the error. - Plausible
@Gild although the semi-duplicate link is useful, I don't think this is a duplicate. This question is more about how to wrestle through saving both the model and the weights, which is a non-trivial question in its own right. - Impenetrability
I
8

(This is a partial answer)

I don't think torch.save(model,'model.pt') works from the command prompt, or when a model is saved from one script running as '__main__' and loaded from another.

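The reason is that pickle, which torch.save uses under the hood, stores only a reference to the class: the name of the module that defined it (taken from __name__ at save time) and the class name. On load, that module is imported and the class is looked up in it. In the question's traceback, __main__ is the flask executable, which has no Net.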
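The lookup can be reproduced without torch. The following is a minimal sketch using only the standard library; "training_script" is a hypothetical module name standing in for whatever module defined Net when the model was saved, and the plain Net class is a stand-in for the real nn.Module.

```python
import pickle
import pickletools
import sys
import types

# "training_script" is a hypothetical stand-in for the module that defined Net
# when the model was saved (in the question, that module was __main__).
training = types.ModuleType("training_script")
exec(
    "class Net:\n"
    "    def __init__(self):\n"
    "        self.weights = [0.1, 0.2]\n",
    training.__dict__,
)
sys.modules["training_script"] = training

payload = pickle.dumps(training.Net(), protocol=4)

# The stream holds the module name and class name as strings, not the class body.
names_in_stream = [
    arg for _, arg, _ in pickletools.genops(payload) if isinstance(arg, str)
]

# Simulate the loading process: a module with that name exists but has no Net.
sys.modules["training_script"] = types.ModuleType("training_script")
try:
    pickle.loads(payload)
    load_error = None
except AttributeError as exc:
    load_error = str(exc)
finally:
    del sys.modules["training_script"]

print(names_in_stream)
print(load_error)
```

The printed error has the same shape as the one in the question, with the stand-in module name in place of __main__.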

Now for the partial part: It's unclear how to fix this issue, especially when you have virtualenvs in the mix.

Thanks to Jatentaki for starting the conversation in this direction.

Impenetrability answered 28/8, 2020 at 17:36 Comment(0)
G
5

I know I am late to answer this, but I figured out a way to load the model from another package instead of from "__main__".

If the attribute is set dynamically before loading the model, as below, it works.

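import __main__

# Net must already be defined or imported here; expose it on __main__, where the pickle looks for it.
setattr(__main__, "Net", Net)
model = torch.load(os.path.join(parent_dir,"<path to pickle>"), map_location=torch.device("cpu"))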

Note: if the "__main__" is a binary, then this hack will not work.
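To see why the hack works, here is a torch-free sketch that uses plain pickle. DemoNet and build_demo_class are hypothetical names chosen for illustration; the class is forced to claim it lives in __main__, the way a class defined in a directly-run training script would.

```python
import pickle
import sys

main_module = sys.modules["__main__"]


def build_demo_class():
    class DemoNet:
        def __init__(self):
            self.weights = [0.1, 0.2]

    # Mimic a class that was defined in a script run directly.
    DemoNet.__module__ = "__main__"
    DemoNet.__qualname__ = "DemoNet"
    return DemoNet


net_class = build_demo_class()

# Saving side: the class is reachable as __main__.DemoNet, so pickling succeeds.
setattr(main_module, "DemoNet", net_class)
payload = pickle.dumps(net_class())

# Loading side: __main__ is some other program that has no DemoNet attribute.
delattr(main_module, "DemoNet")
try:
    pickle.loads(payload)
    failed_without_patch = False
except AttributeError:
    failed_without_patch = True

# The hack: put the class back on __main__ before loading.
setattr(main_module, "DemoNet", net_class)
restored = pickle.loads(payload)
delattr(main_module, "DemoNet")  # clean up the demo attribute

print(failed_without_patch, restored.weights)
```

The same pattern applies to torch.load, since it resolves classes through the same pickle lookup.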

Gnome answered 5/12, 2022 at 6:25 Comment(1)
This technique also solves the issue faced when loading a model saved by torch.save. - Vancevancleave
L
3

Here, the saving and loading of the model is done under the hood using pickle. The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact module layout used when the model is saved. This is because pickle does not save the model class itself. Rather, it saves a reference to the class (the import path of its module plus the class name), which is resolved at load time. Because of this, your code can break in various ways when used in other projects or after refactors.

Solution: Instead of using torch.save(model, PATH), use the following:

# Saving model
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save('model_scripted.pt') # Save

# Loading model
model = torch.jit.load('model_scripted.pt')
model.eval()
Loredo answered 7/4, 2023 at 12:26 Comment(2)
Confirming that this solved the issue of saving the model in one file and loading it in another file on my end! - Discovery
Any idea why pickle needs the path to the file containing the class? Suppose we pickle an object of class A. Then we modify class A and finally we unpickle the pickled object. There would be no error in this case. - Peddling
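The scenario in the last comment can be checked with a small standard-library sketch. It suggests the answer: only the instance's attributes are stored, and the class is looked up by module and name when loading, so whatever definition of A exists at that moment is the one used. The module name "shapes_demo" and the helper define_module are hypothetical, introduced only for this illustration.

```python
import pickle
import sys
import types


def define_module(source):
    """Create (or replace) the hypothetical module 'shapes_demo' from source."""
    module = types.ModuleType("shapes_demo")
    exec(source, module.__dict__)
    sys.modules["shapes_demo"] = module
    return module


# Version 1 of class A, used when pickling.
v1 = define_module(
    "class A:\n"
    "    def __init__(self):\n"
    "        self.value = 1\n"
    "    def describe(self):\n"
    "        return 'v1'\n"
)
payload = pickle.dumps(v1.A())

# The class is modified after pickling: same module, same name, new behavior.
v2 = define_module(
    "class A:\n"
    "    def describe(self):\n"
    "        return 'v2 with value %d' % self.value\n"
)

# Unpickling resolves shapes_demo.A by name, so the new definition is used
# and the stored attribute dict is attached to it.
restored = pickle.loads(payload)
description = restored.describe()
del sys.modules["shapes_demo"]

print(description)
```

So there is no error when the class changes, as long as a class with that name can still be found in a module with that name; the error in the question arises when the lookup itself fails.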
W
1

First I initialized an empty model and then loaded the saved model; this solved the issue for some reason.

Weakly answered 13/5, 2021 at 16:6 Comment(1)
I ran into the same issue. Importing the class definition solved the problem for me. - Ichthyology
D
1

One easy solution to your problem is to define "class Net(nn.Module):" before loading your model. That will solve this issue.

Desai answered 16/3, 2022 at 2:35 Comment(0)
P
1

I stumbled on the same problem recently and solved it by saving my model a different way.

When I was saving it like this:

torch.save(model, 'model_name.pth')

and then was loading it like this:

loaded_model = torch.load('model_name.pth')

in the Flask app, I was getting an error about not being able to find the custom class my model was declared with during training, even when the code for that class was copied into the Flask app just before the model-loading line.

However when I changed the code for saving model to:

torch.save(model.state_dict(), 'model_name.pth')

and loading code to:

loaded_model = TheModelClass(*args, **kwargs)
loaded_model.load_state_dict(torch.load('model_name.pth'))

everything worked fine. (Of course, like the doc says, you need to declare custom model class before loading the model in flask app code.)
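Why saving only the state sidesteps the problem can be illustrated without torch. This is a rough analogy, not PyTorch's implementation: TinyModel and its state_dict/load_state_dict methods are hypothetical stand-ins that mimic the nn.Module methods of the same names. The point is that a payload holding only plain data contains no reference to a user-defined class, so nothing has to be looked up on __main__ at load time.

```python
import pickle
import pickletools


class TinyModel:
    """Hypothetical stand-in for a model; mimics the state_dict idea."""

    def __init__(self):
        self.weights = [0.0, 0.0]
        self.bias = 0.0

    def state_dict(self):
        # Plain containers and numbers only, no reference to TinyModel.
        return {"weights": list(self.weights), "bias": self.bias}

    def load_state_dict(self, state):
        self.weights = list(state["weights"])
        self.bias = state["bias"]


def class_lookups(payload):
    """Return the opcodes in a pickle stream that trigger a class lookup."""
    return [
        op.name
        for op, _, _ in pickletools.genops(payload)
        if op.name in ("GLOBAL", "STACK_GLOBAL")
    ]


trained = TinyModel()
trained.weights, trained.bias = [0.5, -1.5], 0.25

# Save only the state.
state_payload = pickle.dumps(trained.state_dict(), protocol=4)
lookups = class_lookups(state_payload)

# Load: build the model from the class in the current code, then fill it in.
fresh = TinyModel()
fresh.load_state_dict(pickle.loads(state_payload))

print(lookups, fresh.weights, fresh.bias)
```

A real state_dict file still references torch's own tensor types, but those are importable wherever torch is installed, unlike a class that only existed in the training script.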

Hope this helps!

Portfolio answered 28/12, 2023 at 9:28 Comment(0)
S
0

Simple solution: you just need to create an instance of class Net(nn.Module) before loading the model, as follows, and then it will run fine. I faced the same problem and solved it this way.

import numpy as np
import requests
import torch
from flask import Flask, request, jsonify
from torch import nn
from torch.nn import functional as F


MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'


r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

model = Net()  # <---- extra line added: instantiate the class before loading
model = torch.load('model.pth', map_location=torch.device('cpu'))  # <---- if running on a CPU, else 'cuda'

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():


    x_data = request.args['x_data']

    x_data =  x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data) 

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0: 
        pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)
Stumble answered 6/4, 2022 at 12:23 Comment(0)
G
-1

This might not be a very popular answer, but I find that the dill package makes my code work very consistently. In my case I am not even trying to load a model; I am trying to unpickle a custom helper object, but the loader can't find it for some reason. I don't know why, but dill has been a better option for pickling in my experience:

    # - path to files
    path = Path(path2dataset).expanduser()
    path2file_data_prep = Path(path2file_data_prep).expanduser()
    # - create dag dataprep obj
    print(f'path to data set {path=}')
    dag_prep = SplitDagDataPreparation(path)
    # - save data prep splits object
    print(f'saving to {path2file_data_prep=}')
    torch.save({'data_prep': dag_prep}, path2file_data_prep, pickle_module=dill)
    # - load the data prep splits object to test it loads correctly
    db = torch.load(path2file_data_prep, pickle_module=dill)
    db['data_prep']
    print(db)
    return path2file_data_prep
Grati answered 17/5, 2021 at 17:54 Comment(0)
