Calculation of mean and variance in batch normalization in a convolutional neural network

May I ask if the following understanding of batch normalization in a convolutional neural network is correct?

As shown in the diagram below, the mean and variance are calculated using all the cells of the corresponding feature maps generated from the respective examples in the current mini-batch, i.e. they are calculated across the h, w and m axes.

[diagram: a mini-batch of m feature maps, each of size h × w, with the mean and variance taken over all cells of all m maps]
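
In code terms, this is the reduction I have in mind (a rough PyTorch sketch with made-up shapes, m=4 examples and c=3 feature maps of size h=w=5):

import torch

# made-up mini-batch: m examples, c feature maps of size h x w each
m, c, h, w = 4, 3, 5, 5
x = torch.randn(m, c, h, w)

# my understanding: one mean and one variance per feature map,
# reduced over the m, h and w axes
mean = x.mean(dim=(0, 2, 3))   # shape (c,)
var = x.var(dim=(0, 2, 3))     # shape (c,)
print(mean.shape, var.shape)   # torch.Size([3]) torch.Size([3])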

Romalda answered 7/1, 2021 at 13:56 Comment(0)

It seems you are correct. The empirical mean and variance are measured over all dimensions except the feature (channel) dimension. The z-score is then computed to standardize the mini-batch to mean 0 and standard deviation 1, and the result is scaled and shifted by two learnable per-channel parameters, gamma and beta.

Here is a description of a batch normalization layer:

Input: a mini-batch of activations x of shape (N, C, H, W)
Parameters: a learnable scale gamma and shift beta, one of each per channel, plus a small constant eps
Output: y, the normalized, scaled and shifted activations, with the same shape as x

And the calculation details:

Mini-batch mean: mu = (1/m) * sum_{i=1..m} x_i
Mini-batch variance: sigma^2 = (1/m) * sum_{i=1..m} (x_i - mu)^2
Normalize: x_hat_i = (x_i - mu) / sqrt(sigma^2 + eps)
Scale & shift: y_i = gamma * x_hat_i + beta

Here is a quick implementation to show you the normalization process without the scale-shift:

>>> import torch
>>> a = torch.eye(2,4).reshape(2,2,2)
>>> b = torch.arange(8.).reshape(2,2,2)
>>> x = torch.stack([a, b])
>>> x
tensor([[[[1., 0.],
          [0., 0.]],

         [[0., 1.],
          [0., 0.]]],


        [[[0., 1.],
          [2., 3.]],

         [[4., 5.],
          [6., 7.]]]])

We are looking to measure the mean and variance on all axes except the channel axis. So we start by permuting the batch axis with the channel axis, then flatten all axes but the first. Finally we take the average and variance.

>>> x_ = x.permute(1,0,2,3).flatten(start_dim=1)
>>> mean, var = x_.mean(dim=-1), x_.var(dim=-1)
>>> mean, var
(tensor([0.8750, 2.8750]), tensor([1.2679, 8.6964]))

>>> y = (x - mean.view(1,-1,1,1)) / (var.view(1,-1,1,1) + 1e-8).sqrt()
>>> y
tensor([[[[ 0.1110, -0.7771],
          [-0.7771, -0.7771]],

         [[-0.9749, -0.6358],
          [-0.9749, -0.9749]]],


        [[[-0.7771,  0.1110],
          [ 0.9991,  1.8872]],

         [[ 0.3815,  0.7206],
          [ 1.0597,  1.3988]]]])

Notice the shapes of mean and variance: they are vectors whose length equals the number of input channels, which is why they are reshaped to (1, C, 1, 1) before being broadcast against x. The same could be said about the shapes of gamma and beta.
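
For completeness, here is a small sketch of the scale-shift step that the demo above leaves out, plus a sanity check against torch.nn.BatchNorm2d; the gamma and beta values are made up for illustration:

import torch
import torch.nn as nn

x = torch.randn(2, 2, 2, 2)                                 # (batch, channels, h, w)

mean = x.mean(dim=(0, 2, 3), keepdim=True)                  # shape (1, C, 1, 1)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)    # biased, as in BatchNorm
x_hat = (x - mean) / (var + 1e-5).sqrt()

# scale-shift: one gamma and one beta per channel, broadcast over (N, H, W)
gamma = torch.tensor([2.0, 0.5]).view(1, -1, 1, 1)          # made-up values
beta = torch.tensor([1.0, -1.0]).view(1, -1, 1, 1)
y = gamma * x_hat + beta

# the normalization matches PyTorch's layer (whose gamma=1, beta=0 at init)
bn = nn.BatchNorm2d(2, eps=1e-5).train()
print(torch.allclose(x_hat, bn(x), atol=1e-4))              # True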

Mcgehee answered 7/1, 2021 at 14:33 Comment(4)
The formula to calculate the mini-batch mean is a bit confusing. It divides the sum by m. Shouldn't it be divided by m*h*w? Also, it looks like x_i represents a tensor of shape CxHxW. If that is the case, the formula seems to indicate calculating a mean and var for each pixel of CxHxW across the batch size N. Could you please clarify? – Shawn
What you said is correct: you need to divide the sum by m*h*w (m being the batch size). Instead of summing and then dividing, you can apply torch.mean on the appropriate dimensions. In my post, the permutation + reshaping operation might be a little confusing; I didn't know at the time of writing that you can pass a tuple of dimensions to torch.mean and torch.var. So mean = x.mean(dim=(0,2,3)) and var = x.var(dim=(0,2,3)). Let me know if that seems clear to you! – Mcgehee
Edit: sorry, I just realized you were referring to the m in the table. There, x_i with i going over [1, m] covers all the pixels of all the images in the batch for one of the channels, which means that the m in the table is really (batch size)*h*w. – Mcgehee
Thanks Ivan. I thought m went over the batch dimension only, hence my confusion. Now it is clear. – Shawn

The picture depicts BatchNorm correctly.

In BatchNorm we compute the mean and variance using the spatial feature maps of the same channel across the whole batch. The picture you've attached may look confusing because the data there is single-channel, so each grid/matrix represents one data sample. Coloured images, on the other hand, need 3 such grids/matrices to represent one data sample, since they have 3 channels (RGB). So in your picture, you can think of pooling all the cells of all m grids/matrices together and calculating their mean and variance.

So your picture does show the computation of mean and variance for BatchNorm correctly; however, it is only good for understanding single-channel data, and the multi-channel case might still be confusing. To make that case clearer, think of a colour image dataset. Every batch contains a number of images, and each image has 3 channels: RED, GREEN and BLUE (to visualize, think of RED, GREEN and BLUE each as a matrix, so 3 matrices per image). In BatchNorm, assuming a batch size of 32, you would take all 32 RED matrices and calculate their mean and variance, then repeat the process for the GREEN and BLUE channels. That is how it is done for multi-channel data.
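
As a rough sketch of that multi-channel case (assuming a made-up batch of 32 RGB images of size 28x28):

import torch

# made-up batch: 32 colour images, 3 channels (RED, GREEN, BLUE), 28x28 pixels
images = torch.rand(32, 3, 28, 28)

# one mean and one variance per channel: all 32 RED matrices are pooled together,
# and likewise for the GREEN and BLUE channels
channel_mean = images.mean(dim=(0, 2, 3))   # shape (3,)
channel_var = images.var(dim=(0, 2, 3))     # shape (3,)
print(channel_mean.shape, channel_var.shape)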

Cantle answered 7/1, 2021 at 15:6 Comment(2)
Thanks, Khalid. Yes, you're right that the diagram refers to "one feature map" only. It's actually part of a larger diagram I'm creating to help myself understand convolutional neural networks, in which BN is applied to each feature map in a convolutional layer. Apologies, I should have asked the question in its original context. Anyway, I'll post the image of the "context" in my own reply to this post later. – Romalda
Great! You've done a good job on the diagram, very expressive. – Cantle
