It's clear why a 16-bit floating-point format has started seeing use for machine learning; it reduces the cost of storage and computation, and neural networks turn out to be surprisingly insensitive to numeric precision.
What I find particularly surprising is that practitioners abandoned the already-defined IEEE half-precision format in favor of one that allocates only 7 bits to the significand but a full 8 bits to the exponent, as many as 32-bit floating point. (Wikipedia compares the brain-float bfloat16 layout against IEEE binary16 and some 24-bit formats.)
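To make the layouts concrete, here is a small bit-twiddling sketch (plain Python, my own illustration, not taken from any spec): bfloat16 is literally the top 16 bits of the binary32 encoding, while IEEE binary16 redistributes its 16 bits as 1 sign, 5 exponent and 10 fraction bits.

```python
import struct

def f32_bits(x):
    """Raw 32-bit IEEE binary32 encoding of a Python float."""
    return struct.unpack(">I", struct.pack(">f", x))[0]

def show_layouts(x):
    bits = f32_bits(x)
    print(f"binary32 : sign={bits >> 31} "
          f"exponent={(bits >> 23) & 0xFF:08b} fraction={bits & 0x7FFFFF:023b}")
    # bfloat16 is just the top half of binary32: 1 sign, 8 exponent, 7 fraction bits.
    b = bits >> 16
    print(f"bfloat16 : sign={b >> 15} "
          f"exponent={(b >> 7) & 0xFF:08b} fraction={b & 0x7F:07b}")
    # IEEE binary16 instead uses 1 sign, 5 exponent, 10 fraction bits,
    # so it cannot be produced by simple truncation of binary32.

show_layouts(3.14159)
```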
Why so many exponent bits? So far, the only explanation I have found is Google's own blog post, https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus, which says:
Based on our years of experience training and deploying a wide variety of neural networks across Google’s products and services, we knew when we designed Cloud TPUs that neural networks are far more sensitive to the size of the exponent than that of the mantissa. To ensure identical behavior for underflows, overflows, and NaNs, bfloat16 has the same exponent size as FP32. However, bfloat16 handles denormals differently from FP32: it flushes them to zero. Unlike FP16, which typically requires special handling via techniques such as loss scaling [Mic 17], BF16 comes close to being a drop-in replacement for FP32 when training and running deep neural networks.
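To illustrate the range point the quote is making, here is a rough emulation (assuming NumPy is available; I emulate bfloat16 by truncating binary32, which is cruder rounding than real hardware uses, but the exponent range is identical either way):

```python
import numpy as np

def to_bfloat16(x):
    # Emulate bfloat16 by zeroing the low 16 bits of the binary32 encoding.
    # This is truncation; hardware usually rounds to nearest even instead.
    bits = np.array(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

# IEEE binary16 overflows just above 65504 and underflows below ~6e-8:
print(np.float16(70000.0))   # inf
print(np.float16(1e-8))      # 0.0

# The truncated-binary32 "bfloat16" keeps roughly the 1e-38 .. 3e38 range:
print(to_bfloat16(70000.0))  # 69632.0 -- coarse, but finite
print(to_bfloat16(1e30))     # large values survive
print(to_bfloat16(1e-30))    # tiny values survive too (normals down to ~1e-38)
```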
I haven't run neural network experiments on anything like Google scale, but in those I have run, a weight or activation with absolute value much greater than 1.0 means the network has gone into the weeds and is going to spiral off into infinity, and the computer would be doing you a favor if it promptly crashed with an error message. I have never seen or heard of any case that needs dynamic range anywhere near the 1e38 of single-precision floating point.
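For reference, the ranges involved can be computed directly from the field widths; this is my own back-of-envelope arithmetic, ignoring subnormals:

```python
# Finite range implied by an IEEE-style layout, ignoring subnormals.
# Purely to quantify "the 1e38 of single precision".
def format_range(exp_bits, frac_bits):
    bias = 2 ** (exp_bits - 1) - 1
    min_normal = 2.0 ** (1 - bias)
    max_finite = (2 - 2 ** -frac_bits) * 2.0 ** bias
    return min_normal, max_finite

for name, e, f in [("binary16", 5, 10), ("bfloat16", 8, 7), ("binary32", 8, 23)]:
    lo, hi = format_range(e, f)
    print(f"{name}: smallest normal ~{lo:.3g}, largest finite ~{hi:.3g}")
# binary16: ~6.1e-05 .. ~6.55e+04
# bfloat16: ~1.18e-38 .. ~3.39e+38
# binary32: ~1.18e-38 .. ~3.4e+38
```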
So what am I missing?
Are there cases where neural networks really need huge dynamic range? If so, how and why?
Is there some reason why it is considered very beneficial for bfloat16 to use the same exponent width as single precision, even though the significand is so much smaller?
Or was the real goal simply to shrink the significand to the absolute minimum that would do the job, in order to minimize the chip area and energy cost of the multipliers, which are the most expensive part of an FPU? That minimum happened to be around 7 bits; the total width should be a power of two for alignment reasons; it would not quite fit in 8 bits; going up to 16 left surplus bits that might as well be used for something, and the most elegant use for them was to keep the full 8-bit exponent?
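The multiplier-cost part of that guess is easy to put rough numbers on, under the common rule of thumb that an n-bit significand multiplier costs on the order of n squared; these ratios are a back-of-envelope model of mine, not measured silicon:

```python
# Back-of-envelope only: treat significand-multiplier cost as roughly
# proportional to the square of the significand width (implicit leading
# bit included).  Real area depends heavily on the implementation.
widths = {"binary32": 23 + 1, "binary16": 10 + 1, "bfloat16": 7 + 1}
base = widths["binary32"] ** 2
for name, w in widths.items():
    print(f"{name}: {w}x{w} multiplier, ~{w * w / base:.0%} of the binary32 multiplier")
# binary32: 24x24 -> 100%, binary16: 11x11 -> ~21%, bfloat16: 8x8 -> ~11%
```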
…float, then keeping the same exponent range usually means you're ok with bfloat. Maybe? Not posting as an answer because this is conjecture based on vague memory of something I read a while ago. – Cloudland