How to convert float64 to make it work on Apple silicon?

I am trying to load pre-trained weights onto the mps GPU device of an Apple M1. To reproduce the issue minimally, I can run this:

torch.load('yolov7_training.pt', map_location='mps')

which produces the following exception:

  File "train.py", line 619, in <module>
    train(hyp, opt, device, tb_writer)
  File "train.py", line 72, in train
    torch.load('yolov7_training.pt', map_location='mps')
  File "/Users/smahasanulhaque/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/serialization.py", line 789, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/Users/smahasanulhaque/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/serialization.py", line 1131, in _load
    result = unpickler.load()
  File "/Users/smahasanulhaque/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/_utils.py", line 153, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/Users/smahasanulhaque/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/_utils.py", line 146, in _rebuild_tensor
    t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

I am a beginner at PyTorch and see no option to cast the weights to float32 while loading, as the exception suggests. How can I make this work?

My naive workaround would be to load the weights onto the CPU, convert them to float32, and then move them to the mps device, but I'm not sure how to do that, or whether it would even work.
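A minimal sketch of that three-step idea, assuming the checkpoint unpickles directly to an nn.Module (whether it does depends on how the file was saved; see the comments below for this particular checkpoint):

import torch

# 1. Load on the CPU, where float64 storage is supported.
model = torch.load('yolov7_training.pt', map_location='cpu')

# 2. Cast all floating-point parameters and buffers to float32.
model = model.float()

# 3. Only now move the model to the Apple GPU.
model = model.to('mps')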

Abrahamabrahams asked 8/11, 2022 at 20:11. Comments (4):
None of the Apple GPUs support double-precision floating-point arithmetic. – Warrenne
You could try something like model = torch.load('yolov7_training.pt', map_location="cpu"); model = model.float(); model.to("mps") – Decrial
@MatthewR. That's exactly how I wanted to do it. The problem is that the object in your example is a dict, not the model itself; model['model'] contains the actual model. So neither .float() nor .to('mps') works on it, as it's just a dict. – Abrahamabrahams
Oh right, I forgot the load_state_dict part. If you aren't loading the state dict into a model, could you do something like for k in loaded_state_dict: loaded_state_dict[k] = loaded_state_dict[k].float().to("mps")? – Decrial
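
Combining the thread, a sketch that accounts for the checkpoint being a dict with the nn.Module stored under its 'model' key (as established in the comments above); the state-dict variant is illustrative and assigns converted tensors back rather than discarding them:

import torch

# Load the checkpoint dict on the CPU first, where float64 is allowed.
ckpt = torch.load('yolov7_training.pt', map_location='cpu')

# The actual module lives under the 'model' key: cast its floating-point
# parameters and buffers to float32, then move it to the MPS device.
model = ckpt['model'].float().to('mps')

# State-dict route: downcast only float64 tensors (integer buffers such
# as BatchNorm's num_batches_tracked keep their dtype), then move all
# tensors to MPS, keeping the converted results.
state_dict = {
    k: (v.float() if v.dtype == torch.float64 else v).to('mps')
    for k, v in ckpt['model'].state_dict().items()
}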
