I'm trying to serve a pytorch model in a flask app. This code was working when I ran this on a jupyter notebook earlier but now I'm running this within a virtual env and apparently it can't get attribute 'Net' even though the class definition is right there. All the other similar questions tell me to add the class definition of the saved model in the same script. But it still doesn't work. The torch version is 1.0.1 (where the saved model was trained as well as the virtualenv) What am I doing wrong? Here's my code.
import os
import numpy as np
from flask import Flask, request, jsonify
import requests
import torch
from torch import nn
from torch.nn import functional as F
MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'
r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = torch.sigmoid(self.fc1(x))
x = torch.sigmoid(self.fc2(x))
x = self.fc3(x)
return F.log_softmax(x, dim=-1)
model = torch.load('model.pth')
app = Flask(__name__)
@app.route("/")
def hello():
return "Binary classification example\n"
@app.route('/predict', methods=['GET'])
def predict():
x_data = request.args['x_data']
x_data = x_data.split()
x_data = list(map(float, x_data))
sample = np.array(x_data)
sample_tensor = torch.from_numpy(sample).float()
out = model(sample_tensor)
_, predicted = torch.max(out.data, -1)
if predicted.item() == 0:
pred_class = "Has no liver damage - ", predicted.item()
elif predicted.item() == 1:
pred_class = "Has liver damage - ", predicted.item()
return jsonify(pred_class)
Here's the full traceback:
Traceback (most recent call last):
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask", line 10, in <module>
sys.exit(main())
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 894, in main
cli.main(args=args, prog_name=name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 557, in main
return super(FlaskGroup, self).main(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 717, in main
rv = self.invoke(ctx)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 956, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
return callback(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/decorators.py", line 64, in new_func
return ctx.invoke(f, obj, *args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
return callback(*args, **kwargs)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 767, in run_command
app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 293, in __init__
self._load_unlocked()
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 317, in _load_unlocked
self._app = rv = self.loader()
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 372, in load_app
app = locate_app(self, import_name, name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 235, in locate_app
__import__(module_name)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/app.py", line 34, in <module>
model = torch.load('model.pth')
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 368, in load
return _load(f, map_location, pickle_module)
File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 542, in _load
result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask'>
This doesn't solve my issue. I do not want to change the way I persist the model. torch.save() worked fine for me outside the virtual env. I don't mind adding the class definition to the script. I'm trying to see what's causing the error despite that.
save
the model? did you save entire model or just itsstate_dict
? – Gildstate_dict
rather than the model to be robust to changes in your environment. – Gildprint(__name__)
line in your code? I'm guessing that the__name__
of your script was equal to__main__
when saving the pickle but is something different now, when you're running it with flask, causing an attribute lookup error. – Snowinsummer