Is there any way I can download the pre-trained models available in PyTorch to a specific path?
4

16

I am referring to the models that can be found here: https://pytorch.org/docs/stable/torchvision/models.html#torchvision-models

Cairn answered 3/10, 2018 at 13:35 Comment(0)
33

As @dennlinger mentioned in his answer, torch.utils.model_zoo is called internally when you load a pre-trained model.

More specifically, torch.utils.model_zoo.load_url() is called every time a pre-trained model is loaded. Its documentation states:

The default value of model_dir is $TORCH_HOME/models where $TORCH_HOME defaults to ~/.torch.

The default directory can be overridden with the $TORCH_HOME environment variable.

This can be done as follows:

import os

import torch
import torchvision

# Suppose you want the pre-trained ResNet weights stored in models/resnet
# (models\resnet on Windows).

# Set the environment variable before loading the model.
os.environ['TORCH_HOME'] = os.path.join('models', 'resnet')
resnet = torchvision.models.resnet18(pretrained=True)

I came across this solution after raising an issue in PyTorch's GitHub repository: https://github.com/pytorch/vision/issues/616

That issue led to an improvement in the documentation, i.e. the solution described above.
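
If you don't want the override to stay in place for the rest of the process, you could wrap it in a small context manager. This is only a sketch: the helper name torch_home is made up for illustration, and it assumes (as the snippet above does) that TORCH_HOME is read at the moment the weights are loaded.

```python
import os
from contextlib import contextmanager


@contextmanager
def torch_home(path):
    """Temporarily point TORCH_HOME at `path`, restoring the old value on exit.

    Hypothetical helper, not part of PyTorch.
    """
    previous = os.environ.get("TORCH_HOME")
    os.environ["TORCH_HOME"] = str(path)
    try:
        yield path
    finally:
        if previous is None:
            # The variable was not set before, so remove it again.
            os.environ.pop("TORCH_HOME", None)
        else:
            os.environ["TORCH_HOME"] = previous


# Usage (requires torchvision; the directory name is only an example):
# with torch_home(os.path.join("models", "resnet")):
#     resnet = torchvision.models.resnet18(pretrained=True)
```

Outside the with block the variable is back to whatever it was before, so other code in the same process keeps using the default cache location.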

Cairn answered 12/10, 2018 at 17:44 Comment(3)
For me it was in ~/.cache/torch, without explicitly setting it. (Sendai)
For me it was in ~/.cache/torch/checkpoints/. For example: wget download.pytorch.org/models/vgg19-dcbb9e9d.pth -P ~/.cache/torch/checkpoints/ (Pastoralize)
Is there a model similar to YamNet from tf.hub in the model zoo? (Inearth)
9

Yes, you can simply copy the URLs and use wget to download the files to the desired path. Here's an illustration:

For AlexNet:

$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth

For Google Inception (v3):

$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth

For SqueezeNet:

$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth

For MobileNetV2:

$ wget -c https://download.pytorch.org/models/mobilenet_v2-b0353104.pth

For DenseNet201:

$ wget -c https://download.pytorch.org/models/densenet201-c1103571.pth

For MNASNet1_0:

$ wget -c https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth

For ShuffleNetv2_x1.0:

$ wget -c https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

If you want to do it in Python, then use something like:

In [11]: from six.moves import urllib

# resnet 101 host url
In [12]: url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"

# download and rename the file to `resnet_101.pth`
In [13]: urllib.request.urlretrieve(url, "resnet_101.pth")
Out[13]: ('resnet_101.pth', <http.client.HTTPMessage at 0x7f7fd7f53438>)

P.S.: You can find the download URLs in the respective Python modules of torchvision.models.
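
To put the file into a specific directory rather than the current one, something along these lines might work. It is a rough sketch using only the Python 3 standard library (urllib.request instead of six); the function name download_to and the directory names are made up for illustration.

```python
import os
import urllib.request
from urllib.parse import urlparse


def download_to(url, dest_dir, filename=None):
    """Download `url` into `dest_dir` and return the local file path.

    Hypothetical helper. Skips the download if the target file already exists.
    """
    os.makedirs(dest_dir, exist_ok=True)
    if filename is None:
        # Default to the last path component of the URL.
        filename = os.path.basename(urlparse(url).path)
    target = os.path.join(dest_dir, filename)
    if not os.path.exists(target):
        urllib.request.urlretrieve(url, target)
    return target


# Usage (needs network access; the directory name is only an example):
# path = download_to(
#     "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
#     os.path.join("models", "resnet"),
# )
```

The returned path can then be passed to torch.load(), and the resulting state dict handed to the model's load_state_dict() method.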

Fagen answered 3/10, 2018 at 18:7 Comment(0)
0

There is a script available that will output a list of URLs across the entire package.

From within the pytorch/vision package execute the following:

python scripts/collect_model_urls.py .

# ...
# https://download.pytorch.org/models/swin_v2_b-781e5279.pth
# https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth
# https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth
# https://download.pytorch.org/models/vgg11-8a719046.pth
# https://download.pytorch.org/models/vgg11_bn-6002323d.pth
# ...
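
The file names in these URLs end in a short hex string. According to the torch.hub documentation for the check_hash option, that string is the leading digits of the file's SHA-256 hash, so a manually downloaded file can be checked against its own name. Here is a minimal sketch, assuming the file follows that naming convention; the function name matches_name_hash is made up for illustration.

```python
import hashlib
import os
import re

# Matches the hex digits between the last "-" and the extension, e.g. "8a719046".
_HASH_IN_NAME = re.compile(r"-([a-f0-9]+)\.[^.-]+$")


def matches_name_hash(path):
    """Return True if the file's SHA-256 starts with the hash in its name.

    Hypothetical helper, not part of PyTorch.
    """
    match = _HASH_IN_NAME.search(os.path.basename(path))
    if match is None:
        raise ValueError(f"no hash suffix found in {path!r}")
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read in chunks so large checkpoints need not fit in memory.
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest().startswith(match.group(1))


# Usage (the path is only an example):
# ok = matches_name_hash(os.path.join("models", "vgg11-8a719046.pth"))
```

A False result would suggest a truncated or corrupted download, in which case fetching the file again is probably the simplest fix.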
Ribald answered 24/11, 2022 at 10:13 Comment(0)
-1

TL;DR: No, it is not possible directly, but you can easily adapt it.

I think what you want to look at is torch.utils.model_zoo, which is called internally when you load a pre-trained model.

If we look at the code for the pre-trained models, for example AlexNet here, we can see that it simply calls the previously mentioned model_zoo function without specifying a save location. You can either modify the PyTorch source to specify one (that would actually be a great addition IMO, so maybe open a pull request for it), or adapt the code in the second link to your liking (saving it to a custom location under a different name) and insert the relevant location there manually.

If you want to update PyTorch regularly, I would strongly recommend the second method, since it doesn't involve altering PyTorch's code base directly and so doesn't risk errors during updates.

Methodology answered 3/10, 2018 at 16:31 Comment(0)
