When using torch.autocast, how do I force individual layers to float32

I'm trying to train a model in mixed precision. However, I want a few of the layers to be in full precision for stability reasons. How do I force an individual layer to be float32 when using torch.autocast? In particular, I'd like this to be ONNX-compilable.

Is it something like:

with torch.autocast(device_type='cuda', enabled=False, dtype=torch.float16):
    out = my_unstable_layer(inputs.float())

Edit:

Looks like this is indeed the official method. See the torch docs.
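
For reference, a minimal sketch of how that pattern fits into a full forward pass (the module and layer names here are made up for illustration):

import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stable_layer = torch.nn.Linear(16, 16)
        self.my_unstable_layer = torch.nn.Linear(16, 16)

    def forward(self, x):
        # Runs in float16 under the surrounding autocast context.
        x = self.stable_layer(x)
        # Locally disable autocast and cast up to float32 so this
        # layer runs in full precision.
        with torch.autocast(device_type='cuda', enabled=False):
            x = self.my_unstable_layer(x.float())
        return x

model = MyModel().cuda()
inputs = torch.randn(4, 16, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    out = model(inputs)
print(out.dtype)  # torch.float32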

Sulphide answered 22/8, 2022 at 18:03

I think the motivation of torch.autocast is to automate the reduction of precision (not the increase).

If you have functions that need a particular dtype, you should consider using custom_fwd:

import torch

# custom_fwd casts floating-point CUDA tensor inputs to the given dtype,
# but only when autocast is enabled.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.complex128)
def get_custom(x):
    print('  Decorated function received', x.dtype)

def regular_func(x):
    print('  Regular function received', x.dtype)
    get_custom(x)

x = torch.tensor(0.0, dtype=torch.half, device='cuda')
with torch.cuda.amp.autocast(False):
    print('autocast disabled')
    regular_func(x)
with torch.cuda.amp.autocast(True):
    print('autocast enabled')
    regular_func(x)

Output:

autocast disabled
  Regular function received torch.float16
  Decorated function received torch.float16
autocast enabled
  Regular function received torch.float16
  Decorated function received torch.complex128
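
For the float32 case in the question, the same decorator can go on the forward of the layer you want to keep in full precision, with cast_inputs=torch.float32 (a sketch; the Linear layer is just a stand-in for the unstable layer):

import torch

class StableLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    # When autocast is enabled, inputs are cast to float32 and the
    # decorated forward runs with autocast disabled.
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, x):
        return self.linear(x)

layer = StableLayer().cuda()
x = torch.randn(4, 16, device='cuda', dtype=torch.half)
with torch.cuda.amp.autocast(True):
    print(layer(x).dtype)  # torch.float32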

Edit: Using torchscript

I am not sure how much you can rely on this, due to a comment in the documentation; however, that comment is apparently outdated.

Here is an example where I trace the model with autocast enabled, freeze it, and then use it; the value is indeed cast to the specified type:

class Cast(torch.nn.Module):
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float64)
    def forward(self, x):
        return x

# The input must exist before tracing, since trace runs the model on it.
x = torch.tensor(0.0, dtype=torch.half, device='cuda')

with torch.cuda.amp.autocast(True):
    model = torch.jit.trace(Cast().eval(), x)
model = torch.jit.freeze(model)

print(model(x).dtype)

Output:

torch.float64

But I suggest you validate this approach before using it in a serious application.
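
For the ONNX requirement in the question, a similarly unvalidated sketch: export the traced-and-frozen model from above and check that the cast actually shows up in the exported graph (the file name is arbitrary, and whether the frozen module exports cleanly should be verified):

import onnx
import torch

# Reuses `model` and `x` from the snippet above; the float64 cast was
# recorded during tracing, so it should be baked into the traced graph.
torch.onnx.export(model, x, 'cast.onnx')

# Inspect the exported ops; a Cast node is expected if the cast survived.
print([node.op_type for node in onnx.load('cast.onnx').graph.node])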

Ruprecht answered 29/8, 2022 at 12:53 Comment(4)
Do you know if this will work with torchscript? – Sulphide
Not sure; maybe this means that it doesn't. – Ruprecht
When I use the approach I listed in my question above, it does appear to work in torch. It's just in torchscript that it fails. So I don't think the decorator is needed. – Sulphide
Check the example I appended to the answer. Does it help? – Ruprecht
