It seems you are correct. The empirical mean and variance are measured over all dimensions except the feature (channel) dimension. The z-score is then computed to standardize the mini-batch to mean = 0 and std = 1. Finally, it is scaled and shifted by two learnable parameters, `gamma` and `beta`.
Here is a description of a batch normalization layer:

| | Description |
|---|---|
| Input | a mini-batch of shape $(N, C, H, W)$ |
| Parameters | $\gamma$, $\beta$: learnable vectors of length $C$ |
| Output | the normalized, scaled and shifted tensor, with the same shape as the input |
And the calculation details:

| Name | Intermediate operations |
|---|---|
| Mini-batch mean | $\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i$ |
| Mini-batch variance | $\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2$ |
| Normalize | $\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$ |
| Scale & shift | $y_i = \gamma \hat{x}_i + \beta$ |
Here is a quick implementation to show you the normalization process without the scale-shift:
>>> import torch
>>> a = torch.eye(2, 4).reshape(2, 2, 2)
>>> b = torch.arange(8).reshape(2, 2, 2)
>>> x = torch.stack([a, b])
>>> x
tensor([[[[1., 0.],
          [0., 0.]],

         [[0., 1.],
          [0., 0.]]],


        [[[0., 1.],
          [2., 3.]],

         [[4., 5.],
          [6., 7.]]]])
We want to measure the mean and variance over all axes except the channel axis. So we start by permuting the batch axis with the channel axis, then flatten all axes but the first. Finally, we take the mean and variance.
>>> x_ = x.permute(1, 0, 2, 3).flatten(start_dim=1)
>>> mean, var = x_.mean(dim=-1), x_.var(dim=-1)
>>> mean, var
(tensor([0.8750, 2.8750]), tensor([1.2679, 8.6964]))
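As an aside, the same statistics can be obtained without the permute-and-flatten step, since `Tensor.mean` and `Tensor.var` accept a tuple of dims (a sketch reproducing the values above):

```python
import torch

a = torch.eye(2, 4).reshape(2, 2, 2)
b = torch.arange(8).reshape(2, 2, 2)
x = torch.stack([a, b])  # shape (2, 2, 2, 2): (batch, channel, H, W)

# Reduce over every axis except the channel axis (dim=1).
mean = x.mean(dim=(0, 2, 3))
var = x.var(dim=(0, 2, 3))  # unbiased, same as x_.var(dim=-1) above
```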
The statistics must be reshaped to `(1, C, 1, 1)` so they broadcast along the channel axis rather than the last axis:

>>> y = (x - mean[None, :, None, None]) / (var[None, :, None, None] + 1e-8).sqrt()
>>> y
tensor([[[[ 0.1110, -0.7771],
          [-0.7771, -0.7771]],

         [[-0.9749, -0.6358],
          [-0.9749, -0.9749]]],


        [[[-0.7771,  0.1110],
          [ 0.9991,  1.8872]],

         [[ 0.3815,  0.7206],
          [ 1.0597,  1.3988]]]])
Notice the shapes of `mean` and `var`: vectors whose length equals the number of input channels. The same holds for the shapes of `gamma` and `beta`.
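To close the loop on the scale-shift, here is a quick sanity check against `torch.nn.BatchNorm2d` (a sketch, not part of the original discussion). One caveat: in training mode `BatchNorm2d` normalizes with the *biased* variance (dividing by $N$ rather than $N-1$), so the manual computation below passes `unbiased=False`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 2, 2, 2)  # (batch, channel, H, W)

# Manual batch norm without the learnable scale-shift.
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased, like BatchNorm2d
y_manual = (x - mean) / (var + 1e-5).sqrt()

# gamma (bn.weight) initializes to ones and beta (bn.bias) to zeros,
# so at initialization the layer is a pure normalization.
bn = nn.BatchNorm2d(num_features=2, eps=1e-5).train()
y_bn = bn(x)

print(torch.allclose(y_manual, y_bn, atol=1e-6))  # True
```

After initialization, `bn.weight` (gamma) and `bn.bias` (beta) are updated by gradient descent like any other parameters, and the output becomes `gamma * y_manual + beta` per channel.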