PyTorch 1.12 changed the default fp32 math to "highest precision" and introduced the `torch.set_float32_matmul_precision` API, which lets users choose `medium`, `high`, or `highest` as the internal precision of float32 matrix multiplications.
From the documentation, I read that choosing a lower precision "may significantly increase performance, and in some programs the loss of precision has a negligible impact".
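For reference, my understanding is that the setting is global and just needs to be applied before the matmuls run, roughly like this (the shapes and device here are placeholders, not my actual workload):

```python
import torch

# Select the internal precision for float32 matmuls; "high" allows TF32
# (or equivalent reduced-precision) Tensor Core kernels on supported GPUs.
torch.set_float32_matmul_precision("high")

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
c = a @ b  # this matmul now uses the reduced internal precision selected above
```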
1. How do I determine whether my program would benefit from setting a lower precision? Is this purely empirical?
Similarly, when training with PyTorch Lightning, I get the following warning:
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')`
This seems to answer question 1 (i.e., use a lower precision when your GPU has Tensor Cores), but it doesn't suggest which of the two lower precisions to use.
2. How do you determine which of the lower precisions (`high` or `medium`) to use? Is this purely empirical? What's the suggested approach?
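For what it's worth, the only approach I can think of is a micro-benchmark along these lines (a rough sketch; the matrix size, iteration count, and error metric are arbitrary choices on my part), timing each setting and checking the numerical drift against a float64 reference. Is something like this the suggested approach, or is there a more principled rule of thumb?

```python
import time
import torch

def benchmark(precision: str, n: int = 8192, iters: int = 50):
    # Time a float32 matmul under the given precision setting and measure
    # its deviation from a float64 reference result.
    torch.set_float32_matmul_precision(precision)
    a = torch.randn(n, n, device="cuda")
    b = torch.randn(n, n, device="cuda")

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        c = a @ b
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / iters

    reference = (a.double() @ b.double()).float()
    rel_error = ((c - reference).norm() / reference.norm()).item()
    return elapsed, rel_error

for precision in ("highest", "high", "medium"):
    t, err = benchmark(precision)
    print(f"{precision:>8}: {t * 1e3:.2f} ms/iter, relative error {err:.2e}")
```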