How do I re-use trained fastai models?

How do I load a pretrained model using the fastai implementation on top of PyTorch? In scikit-learn I can use pickle to dump a model to a file, then load and use it later. I've used the .load() method after declaring the learn instance, as below, to load previously saved weights:

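from fastai.conv_learner import *  # fastai 0.7 API

arch = resnet34
# PATH (data folder) and sz (image size) are assumed to be defined earlier
data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, sz))
learn = ConvLearner.pretrained(arch, data, precompute=False)
learn.load('resnet34_test')  # weights previously saved with learn.save('resnet34_test')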

Then to predict the class of an image:

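trn_tfms, val_tfms = tfms_from_model(arch, 100)
img = open_image('circle/14.png')
im = val_tfms(img)                     # apply the validation transforms
preds = learn.predict_array(im[None])  # im[None] adds a batch dimension of size 1
print(np.argmax(preds))                # index of the predicted class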

But it raises this error:

ValueError: Expected more than 1 value per channel when training, got input size [1, 1024]

This code works if I call learn.fit(0.01, 3) instead of learn.load(). What I really want is to avoid the training step in my application.

Smalto answered 21/3, 2018 at 4:34 Comment(1)
The best way is to print the model details and the names and shapes of the tensors in the pretrained model to see what is going wrong. From your current description, it is unclear what the actual problem is. Vaduz

This error occurs whenever a batch that passes through a BatchNorm layer in training mode contains a single element.

Solution 1: Call learn.predict() after learn.load('resnet34_test').

Solution 2: Remove one data point from your training set, so that the last batch no longer contains a single element.

PyTorch issue

Fastai forum issue description
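A minimal PyTorch sketch of where the message comes from, without fastai. It assumes the failing layer is a BatchNorm1d over 1024 features (inferred from the [1, 1024] in the error, not checked against the actual model). A freshly built module starts in training mode and cannot normalize a batch of one; the same module in evaluation mode uses its running statistics and works. This is also a plausible reason why Solution 1 and calling learn.fit() first both help: as far as I can tell, fastai 0.7 switches the model to evaluation mode when it runs over the validation set and leaves it there.

```python
import torch
import torch.nn as nn

# Stand-in for the layer in the network head (assumed: BatchNorm1d with 1024 features).
bn = nn.BatchNorm1d(1024)
x = torch.randn(1, 1024)  # a "batch" holding the features of a single image

# Freshly constructed modules start in training mode.
print(bn.training)  # True

try:
    bn(x)  # batch statistics over one sample are undefined, so PyTorch raises
except ValueError as err:
    print("train mode:", err)

bn.eval()         # inference mode: normalize with the stored running statistics
out = bn(x)       # now a batch of one is fine
print(out.shape)  # torch.Size([1, 1024])
```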

Supralapsarian answered 2/5, 2018 at 17:6 Comment(0)

This could be an edge case where the batch size equals 1 for some batch. Make sure none of your batches has size 1 (usually it is the last batch).
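A small sketch of checking for this in plain PyTorch, assuming you build the DataLoader yourself (fastai 0.7 creates its own loaders, so this illustrates the idea rather than the fastai API). The sample count and batch size are made-up numbers chosen so that the last batch would hold one sample; drop_last=True discards that incomplete batch during training.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset: 65 samples with batch size 64 leaves 1 sample in the last batch.
n_samples, batch_size = 65, 64
ds = TensorDataset(torch.randn(n_samples, 3), torch.zeros(n_samples, dtype=torch.long))

# Size of the final batch (a full batch when n_samples divides evenly).
last_batch = n_samples % batch_size or batch_size
print("last batch size:", last_batch)  # 1

# Drop the incomplete final batch so BatchNorm never sees a batch of one while training.
train_dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)
sizes = [xb.shape[0] for xb, _ in train_dl]
print(sizes)  # [64]
```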

Apotheosis answered 29/3, 2018 at 1:54 Comment(0)

During training you will get this error if a training batch contains only one sample.

If you are using the model to predict outputs, make sure the underlying PyTorch model is in evaluation mode:

learner.model.eval()
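A sketch of the same idea in plain PyTorch. In fastai 0.7 the Learner keeps its PyTorch module in learn.model, so in the question's setup this would mean calling learn.model.eval() after learn.load('resnet34_test') and before learn.predict_array(...). The helper predict_one below is hypothetical, not part of fastai; it also wraps the forward pass in torch.no_grad(), which is an extra choice here rather than something this answer requires. The toy model is an assumption standing in for a classifier head that contains BatchNorm1d.

```python
import torch
import torch.nn as nn


def predict_one(model: nn.Module, x: torch.Tensor) -> int:
    """Return the predicted class index for a single unbatched input x.

    Hypothetical helper: switches the model to evaluation mode (BatchNorm uses
    running statistics, Dropout is disabled) before running the forward pass.
    """
    model.eval()
    with torch.no_grad():  # no gradients needed for inference
        logits = model(x.unsqueeze(0))  # add a batch dimension of size 1
    return int(logits.argmax(dim=1).item())


# Toy stand-in for a classifier head containing BatchNorm1d.
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 3))
print(model.training)  # True: freshly built modules start in training mode
pred = predict_one(model, torch.randn(8))
print(model.training, pred)  # False, followed by a class index between 0 and 2
```

Without the model.eval() call inside the helper, the same forward pass would raise the "Expected more than 1 value per channel when training" error from the question.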
Occasionally answered 12/10, 2018 at 12:57 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.