How to choose which internal precision of float32 matrix multiplications to use in PyTorch?

PyTorch 1.12 changed the default fp32 math to be "highest precision", and introduced the torch.set_float32_matmul_precision API, allowing users to specify which precision out of medium, high and highest to use for the internal precision of float32 matrix multiplications.
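
For concreteness, the setting itself is a single global call, and the current mode can be read back:

```python
import torch

# "highest" keeps full float32 internal math; "high" and "medium" allow
# lower-precision internal formats on hardware that supports them.
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # -> "high"
```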

From the documentation, I read that choosing a lower precision "may significantly increase performance, and in some programs the loss of precision has a negligible impact".

1. How do I determine whether my program would benefit from setting a lower precision? Is this purely empirical?
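
For reference, the kind of micro-benchmark I would run to check the performance side looks something like this (a sketch; the 8192×8192 size and iteration count are arbitrary choices):

```python
import time
import torch

def time_matmul(mode, n=8192, iters=20):
    # Time an n x n float32 matmul on the GPU under the given precision mode.
    torch.set_float32_matmul_precision(mode)
    a = torch.randn(n, n, device="cuda")
    b = torch.randn(n, n, device="cuda")
    a @ b  # warm-up (cuBLAS initialisation, kernel selection)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

for mode in ("highest", "high", "medium"):
    print(f"{mode}: {time_matmul(mode) * 1e3:.1f} ms per matmul")
```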

Similarly, when training with PyTorch Lightning, I get the following warning:

You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')`

This seems to answer 1. (i.e., when your GPU has tensor cores, use a lower precision), but doesn't suggest which of the two lower precisions to use.

2. How do you determine which of the lower precisions ("high" or "medium") to use? Is this purely empirical? What's the suggested approach?

Contention answered 30/7, 2023 at 9:39 Comment(1)
Predicting how precision affects quality is hard. AFAICT most people determine it empirically. If the model works well enough in lower precision, great. If it doesn't, use higher precision or fix the model.Bacteriostat

How do you determine which of the lower precisions ("high" or "medium") to use?

Well, first, you need to figure out what that setting actually means. I'm not a PyTorch expert, but I know my GPUs, so I can tell you that NVIDIA GPUs, at the moment, don't use single precision internally in their tensor-core matrix multiplications. Not quite sure why, but you get either double precision (FP64); half precision (FP16); brain half precision (BF16, which keeps FP32's 8-bit exponent but shortens the mantissa to 7 bits); TF32 (the 10-bit mantissa of FP16 with the exponent of FP32); or even lower precisions which you likely don't care about.
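
As a quick illustration of what those mantissa widths mean (nothing PyTorch-specific here, just the rounding behaviour of the two 16-bit formats):

```python
import torch

x = torch.tensor(1.0 + 1.0 / 1024)  # needs 10 mantissa bits to store exactly
print(x.half().item())      # FP16, 10-bit mantissa: 1.0009765625, kept exactly
print(x.bfloat16().item())  # BF16, 7-bit mantissa: 1.0, the low bits are rounded away
```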

I am guessing that medium precision means operations using TF32 and high precision means FP64. But you need to make sure that's actually the case.
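
One way to make sure, without digging through the cuBLAS documentation, is to probe the setting empirically: run the same float32 matmul under each mode and compare it against a float64 reference. A rough sketch (the matrix size is arbitrary, and the absolute error scale depends on it):

```python
import torch

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
ref = a.double() @ b.double()  # float64 reference result

for mode in ("highest", "high", "medium"):
    torch.set_float32_matmul_precision(mode)
    err = (a @ b).double().sub(ref).abs().max().item()
    print(f"{mode}: max abs error vs float64 = {err:.3e}")
```

The error should grow as you step down through the modes, and its magnitude hints at how many mantissa bits each mode really keeps internally.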

Once you've determined what this actually means (and assuming it is as I suspect), you need to evaluate the numerical problem you're working on:

  • Either you are able to obtain an upper bound on its sensitivity to internal numerical errors, i.e. how far the output can stray from the perfect-precision output given an error of up to $\delta$ in each floating-point operation;

  • Or you can establish ground-truth / perfect-precision / near-perfect-precision results, faithfully sample your input space, and rely on empirical estimates of the error behavior: simply check how bad things get when you set the precision to 'medium' (see the sketch after this list).

The second alternative is riskier and less sound.
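
A minimal sketch of that second, empirical route; a toy two-layer network stands in for whatever model and validation data you actually have, and the 'highest' setting serves as the near-perfect-precision baseline:

```python
import torch

# Stand-ins for your real model and validation batch.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 10)
).cuda()
x = torch.randn(256, 1024, device="cuda")

outputs = {}
for mode in ("highest", "high", "medium"):
    torch.set_float32_matmul_precision(mode)
    with torch.no_grad():
        outputs[mode] = model(x)

baseline = outputs["highest"]
for mode in ("high", "medium"):
    drift = (outputs[mode] - baseline).abs().max().item()
    print(f"{mode}: max output drift vs 'highest' = {drift:.3e}")
```

In practice you would compare a task metric (validation loss, accuracy) rather than raw outputs, but the structure is the same: pick the cheapest mode whose degradation you can live with.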

Printery answered 4/9, 2024 at 12:27 Comment(1)
This is what the documentation says; there is no mention of FP64: pytorch.org/docs/stable/generated/…Bonded
