Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same

I am trying to run this notebook on an Apple M1 (1st gen) Mac running macOS 12.4.

Installed libraries:


>pip3 freeze
anyio @ file:///private/tmp/jupyterlab--anyio-20211211-70040-1yv1wmx/anyio-3.4.0
appnope==0.1.2
argon2-cffi @ file:///private/tmp/jupyterlab--argon2-cffi-20211211-70040-1er07d0/argon2-cffi-21.2.0
argon2-cffi-bindings @ file:///private/tmp/jupyterlab--argon2-cffi-bindings-20211211-70040-o64kwi/argon2-cffi-bindings-21.2.0
asttokens==2.0.5
attrs @ file:///private/tmp/jupyterlab--attrs-20211211-70040-6u3qxt/attrs-21.2.0
Babel==2.9.1
backcall @ file:///private/tmp/jupyterlab--backcall-20211211-70040-acdr42/backcall-0.2.0
beniget==0.4.1
black==21.12b0
bleach==4.1.0
certifi==2022.5.18.1
cffi==1.15.0
charset-normalizer==2.0.12
click==8.0.3
cycler==0.10.0
Cython==0.29.24
debugpy @ file:///private/tmp/jupyterlab--debugpy-20211211-70040-2j9lay/debugpy-1.5.1
decorator==5.1.0
defusedxml @ file:///private/tmp/jupyterlab--defusedxml-20211211-70040-uowur4/defusedxml-0.7.1
entrypoints @ file:///private/tmp/jupyterlab--entrypoints-20211211-70040-1r2y5g4/entrypoints-0.3
et-xmlfile==1.1.0
executing==0.8.2
finnhub-python==2.4.5
gast==0.5.2
GDAL==3.4.0
gensim==4.1.2
graphviz==0.19.1
idna==3.3
imageio==2.13.5
ipykernel==6.6.0
ipython==7.30.1
ipython-genutils==0.2.0
ipywidgets==7.6.5
jedi==0.18.1
Jinja2==3.0.3
joblib==1.1.0
json5==0.9.6
jsonschema @ file:///private/tmp/jupyterlab--jsonschema-20211211-70040-1np642r/jsonschema-4.2.1
jupyter==1.0.0
jupyter-client==7.1.0
jupyter-console==6.4.0
jupyter-core==4.9.1
jupyter-server @ file:///private/tmp/jupyterlab--jupyter-server-20211211-70040-1u7h7vl/jupyter_server-1.13.1
jupyterlab @ file:///private/tmp/jupyterlab-20211211-70040-1ltrjpx/jupyterlab-3.2.5
jupyterlab-pygments==0.1.2
jupyterlab-server @ file:///private/tmp/jupyterlab--jupyterlab-server-20211211-70040-iufjhi/jupyterlab_server-2.8.2
jupyterlab-widgets==1.0.2
kiwisolver==1.3.2
lxml==4.6.3
MarkupSafe==2.0.1
matplotlib==3.4.3
matplotlib-inline==0.1.3
midi @ git+https://github.com/vishnubob/python-midi.git@abb85028c97b433f74621be899a0b399cd100aaa
midi-to-dataframe @ git+https://github.com/TaylorPeer/midi-to-dataframe@35347f787f01a2326234ad278d8c40bee3817f1d
mido==1.2.10
mistune==0.8.4
multitasking==0.0.9
mypy-extensions==0.4.3
nbclassic @ file:///private/tmp/jupyterlab--nbclassic-20211211-70040-1fah2fe/nbclassic-0.3.4
nbclient @ file:///private/tmp/jupyterlab--nbclient-20211211-70040-ptwp5d/nbclient-0.5.9
nbconvert==6.3.0
nbformat==5.1.3
nest-asyncio @ file:///private/tmp/jupyterlab--nest-asyncio-20211211-70040-72pz5e/nest_asyncio-1.5.4
networkx==2.6.3
notebook==6.4.6
numpy==1.23.0rc1
openpyxl==3.0.9
packaging @ file:///private/tmp/jupyterlab--packaging-20211211-70040-1f14ddt/packaging-21.3
pandas==1.4.2
pandocfilters==1.5.0
parso==0.8.3
pathspec==0.9.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.1.1
platformdirs==2.4.1
ply==3.11
prometheus-client==0.12.0
prompt-toolkit @ file:///private/tmp/jupyterlab--prompt-toolkit-20211211-70040-hcpjwc/prompt_toolkit-3.0.24
ptyprocess @ file:///private/tmp/jupyterlab--ptyprocess-20211211-70040-wjbvpa/ptyprocess-0.7.0
pure-eval==0.2.1
pybind11==2.8.0
pycparser==2.21
Pygments==2.10.0
pyparsing==3.0.6
pyrsistent @ file:///private/tmp/jupyterlab--pyrsistent-20211211-70040-1fnadg/pyrsistent-0.18.0
python-dateutil==2.8.2
pythran==0.10.0
pytz==2022.1
PyWavelets==1.2.0
PyYAML==6.0
pyzmq @ file:///private/tmp/jupyterlab--pyzmq-20211211-70040-2xtuon/pyzmq-22.3.0
qtconsole==5.2.2
QtPy==2.0.0
requests==2.27.1
scikit-image==0.19.1
scikit-learn==1.1.dev0
scipy==1.8.1
seaborn==0.11.2
Send2Trash==1.8.0
six==1.16.0
smart-open==5.2.1
sniffio @ file:///private/tmp/jupyterlab--sniffio-20211211-70040-wu3dri/sniffio-1.2.0
squarify==0.4.3
stack-data==0.1.4
terminado @ file:///private/tmp/jupyterlab--terminado-20211211-70040-dw1vl6/terminado-0.12.1
testpath @ file:///private/tmp/jupyterlab--testpath-20211211-70040-895z1/testpath-0.5.0
threadpoolctl==3.0.0
tifffile==2021.11.2
tomli==1.2.3
torch==1.13.0.dev20220528
torchaudio==0.11.0
torchsummary==1.5.1
torchtext==0.10.0
torchvision==0.14.0a0+f0f8a3c
torchviz==0.0.2
tornado==6.1
tqdm==4.62.3
traitlets @ file:///private/tmp/jupyterlab--traitlets-20211211-70040-ru76xv/traitlets-5.1.1
typing_extensions==4.2.0
urllib3==1.26.9
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.2.3
wget==3.2
widgetsnbextension==3.5.2
yfinance==0.1.64

In the code, I am setting device = torch.device('mps').

At the line history = [evaluate(model, valid_dl)] I get a RuntimeError:

Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same

Trace:


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<timed exec> in <module>

/opt/homebrew/Cellar/jupyterlab/3.2.5/libexec/lib/python3.9/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/var/folders/mz/qfpvpvf550s039lrnxg70whh0000gn/T/ipykernel_11483/1143432410.py in evaluate(model, val_loader)
      3 def evaluate(model, val_loader):
      4     model.eval()
----> 5     outputs = [model.validation_step(batch) for batch in val_loader]
      6     return model.validation_epoch_end(outputs)
      7 

/var/folders/mz/qfpvpvf550s039lrnxg70whh0000gn/T/ipykernel_11483/1143432410.py in <listcomp>(.0)
      3 def evaluate(model, val_loader):
      4     model.eval()
----> 5     outputs = [model.validation_step(batch) for batch in val_loader]
      6     return model.validation_epoch_end(outputs)
      7 

/var/folders/mz/qfpvpvf550s039lrnxg70whh0000gn/T/ipykernel_11483/446280773.py in validation_step(self, batch)
     16     def validation_step(self, batch):
     17         images, labels = batch
---> 18         out = self(images)                   # Generate prediction
     19         loss = F.cross_entropy(out, labels)  # Calculate loss
     20         acc = accuracy(out, labels)          # Calculate accuracy

/opt/homebrew/Cellar/jupyterlab/3.2.5/libexec/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/var/folders/mz/qfpvpvf550s039lrnxg70whh0000gn/T/ipykernel_11483/3789274317.py in forward(self, xb)
     29 
     30     def forward(self, xb): # xb is the loaded batch
---> 31         out = self.conv1(xb)
     32         out = self.conv2(out)
     33         out = self.res1(out) + out

/opt/homebrew/Cellar/jupyterlab/3.2.5/libexec/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/opt/homebrew/Cellar/jupyterlab/3.2.5/libexec/lib/python3.9/site-packages/torch/nn/modules/container.py in forward(self, input)
    137     def forward(self, input):
    138         for module in self:
--> 139             input = module(input)
    140         return input
    141 

/opt/homebrew/Cellar/jupyterlab/3.2.5/libexec/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/opt/homebrew/Cellar/jupyterlab/3.2.5/libexec/lib/python3.9/site-packages/torch/nn/modules/conv.py in forward(self, input)
    457 
    458     def forward(self, input: Tensor) -> Tensor:
--> 459         return self._conv_forward(input, self.weight, self.bias)
    460 
    461 class Conv3d(_ConvNd):

/opt/homebrew/Cellar/jupyterlab/3.2.5/libexec/lib/python3.9/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    453                             weight, bias, self.stride,
    454                             _pair(0), self.dilation, self.groups)
--> 455         return F.conv2d(input, weight, bias, self.stride,
    456                         self.padding, self.dilation, self.groups)
    457 

RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same

MPS support is still new and I am trying to figure out the cause here; any suggestions are welcome. The code runs fine if the torch device is set to CPU; it just takes much longer.
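The message says the first convolution received an input batch that already lives on the MPS device (MPSFloatType) while its weight is still an ordinary CPU tensor (torch.FloatTensor). A quick way to confirm which side is where, sketched with a hypothetical helper devices_of and a toy Conv2d standing in for the notebook's model (both are assumptions, not code from the notebook):

```python
import torch
import torch.nn as nn


def devices_of(model: nn.Module, batch: torch.Tensor) -> tuple:
    """Return the device types of the model's parameters and of an input batch."""
    param_devices = {p.device.type for p in model.parameters()}
    return param_devices, batch.device.type


# Toy stand-ins for the notebook's model and one validation batch (assumption)
model = nn.Conv2d(3, 8, kernel_size=3)
images = torch.randn(4, 3, 32, 32)

param_devices, input_device = devices_of(model, images)
print("parameters:", param_devices, "input:", input_device)
# In the failing run one would expect parameters on {'cpu'} and the input on 'mps'
```

If the two results disagree, the forward pass fails with exactly this kind of error.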

Thanks, Deep Kamal Singh

Modification answered 29/5, 2022 at 5:58

My guess is that the model has not been placed onto the MPS device.

If you place your model onto the MPS device (by calling model.to(device)), does your code work?
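Following that suggestion, a minimal sketch of the whole pattern: pick the device once, move the model with model.to(device), and move each batch before the forward pass. The small nn.Sequential network and the random TensorDataset are stand-ins for the notebook's ResNet and validation loader (assumptions), and this evaluate returns a plain average loss rather than the notebook's validation_epoch_end output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Use MPS when this PyTorch build supports it (1.12+), otherwise fall back to the CPU
use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = torch.device("mps" if use_mps else "cpu")

# Toy classifier standing in for the notebook's ResNet (assumption, not the original model)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
# Module.to() moves parameters and buffers in place and returns the same module
model = model.to(device)


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Average validation loss; every batch is moved to the model's device first."""
    model.eval()
    losses = []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        losses.append(F.cross_entropy(model(images), labels).item())
    return sum(losses) / len(losses)


# Random stand-in data for the validation set (assumption)
valid_ds = TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 10, (16,)))
valid_dl = DataLoader(valid_ds, batch_size=8)
history = [evaluate(model, valid_dl)]
```

Moving the model once, before the first evaluate or training call, is enough. The PyTorch optimizer documentation recommends doing this move before constructing the optimizer for the model.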

Countercheck answered 25/10, 2022 at 7:53

I ran into the same situation and solved it by adding .to('mps') to the model, for instance:

import torch, torchvision
model = torchvision.models.resnet50(pretrained=True).to('mps').to(torch.float16)  # weights on MPS, cast to float16
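One caveat: the extra .to(torch.float16) converts the weights to half precision, so the input batches must be converted too; otherwise a dtype mismatch of the same kind is typically raised. A minimal sketch, assuming a toy Conv2d in place of resnet50, that casts each batch to whatever device and dtype the model's first parameter has (Tensor.to(other) takes both from the reference tensor):

```python
import torch
import torch.nn as nn

use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = torch.device("mps" if use_mps else "cpu")
# float16 convolutions are not implemented on the CPU in many PyTorch versions,
# so this sketch only uses half precision on MPS
dtype = torch.float16 if use_mps else torch.float32

# Toy layer standing in for resnet50 (assumption)
model = nn.Conv2d(3, 8, kernel_size=3).to(device=device, dtype=dtype)

images = torch.randn(2, 3, 32, 32)  # float32 on the CPU, as a DataLoader would yield
# Tensor.to(other) copies both the device and the dtype of the reference tensor
reference = next(model.parameters())
images = images.to(reference)
out = model(images)
```

Matching inputs to a parameter of the model keeps device and dtype in sync even if the model's precision is changed later.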
Kodak answered 22/7 at 2:46
