How to calculate batch normalization with python?

I got confused while implementing batch normalization in Python from scratch. A paper presents figures of several normalization methods, and I think the one for batch normalization may not be correct: both its description and its figure seem wrong to me.

Description from the paper:

[Image: description from the paper]

Figure from the paper: [Image: original figure of BN]

As far as I can tell, batch normalization is represented incorrectly in the original paper. I am posting the issue here for discussion. I think batch normalization should be depicted like the following figure.

[Image: proposed figure of BN]

The key point is how to calculate the mean and std. With feature maps of shape (batch_size, channel_number, width, height), the mean could be computed either as mean = X.mean(axis=(0, 2, 3), keepdims=True) or as mean = X.mean(axis=(0, 1), keepdims=True).
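A quick way to see what the two candidate expressions compute is to compare the shapes of their results. The array X below is random test data with arbitrary sizes (an assumption for illustration, not taken from the question), laid out as (batch_size, channel_number, width, height):

```python
import numpy as np

# Hypothetical batch: 8 samples, 3 channels, 5 x 4 spatial grid.
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 3, 5, 4))

# Option 1: reduce over the batch axis and both spatial axes,
# leaving one mean per channel.
mean_a = X.mean(axis=(0, 2, 3), keepdims=True)

# Option 2: reduce over the batch axis and the channel axis,
# leaving one mean per spatial position.
mean_b = X.mean(axis=(0, 1), keepdims=True)

print(mean_a.shape)  # (1, 3, 1, 1)
print(mean_b.shape)  # (1, 1, 5, 4)
```

Because keepdims=True keeps the reduced axes with size 1, either result broadcasts back against X when subtracting, so the shapes are the main clue to what each option actually normalizes over.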

Which one is correct?

Backward asked 8/1, 2020 at 17:57

You should calculate the mean and std across all pixels of all images in the batch, separately for each channel, so use the axis=(0, 2, 3) parameter. If the channels have roughly the same distribution, you may calculate the mean and std across channels as well; in that case just call mean() and std() without the axis parameter.
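A minimal NumPy sketch of the forward pass this answer describes, assuming the (N, C, H, W) layout from the question. The function name batch_norm_forward, the eps value, and the optional learnable scale and shift gamma/beta are illustrative assumptions (gamma/beta come from the usual batch-norm formulation, not from this thread). It uses the biased variance that np.var computes by default, adds eps before the square root, and omits the running averages used at inference time:

```python
import numpy as np


def batch_norm_forward(X, gamma=None, beta=None, eps=1e-5, per_channel=True):
    """Normalize a batch of feature maps shaped (N, C, H, W).

    per_channel=True: statistics over the batch and spatial axes (0, 2, 3),
    i.e. one mean and one variance per channel.
    per_channel=False: a single mean and variance over the whole array
    (the "no axis parameter" variant from the answer).
    """
    axes = (0, 2, 3) if per_channel else None
    mean = X.mean(axis=axes, keepdims=True)
    var = X.var(axis=axes, keepdims=True)  # biased variance (ddof=0)
    X_hat = (X - mean) / np.sqrt(var + eps)  # eps guards against var == 0
    # Optional per-channel scale and shift (hypothetical parameters here).
    if gamma is not None:
        X_hat = X_hat * np.asarray(gamma).reshape(1, -1, 1, 1)
    if beta is not None:
        X_hat = X_hat + np.asarray(beta).reshape(1, -1, 1, 1)
    return X_hat


# Usage: random data with a nonzero mean and non-unit spread.
rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=2.0, size=(16, 3, 8, 8))
Y = batch_norm_forward(X)
# Each channel of Y now has mean close to 0 and variance close to 1.
print(Y.mean(axis=(0, 2, 3)))
print(Y.var(axis=(0, 2, 3)))
```

Setting per_channel=False corresponds to the answer's suggestion of calling mean() and std() without an axis, which the answer presents as reasonable only when the channels share a similar distribution.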

The figure in the article is correct: it takes the mean and std across H and W (the image dimensions) for each batch. Obviously, channel is not shown in the 3d cube.

Reservoir answered 8/1, 2020 at 22:08 (6 comments)
Backward: Thanks. Could you please also see here for more details about this question? I hope to hear your answer again.
Reservoir: The mean and std for a batch of images should each be just a single number (or 3 numbers if you choose to compute them for each channel). You then subtract this mean from each pixel value and divide each pixel value by the std. Run this algorithm and see that it works.
Backward: Thanks, it works. What confuses me is why some people think batch_size=1 will lead to division by zero.
Reservoir: Those who think batch_size=1 will lead to division by zero assume that the std is calculated separately for each pixel. But even with batch_size=1, it is calculated across H*W pixels. Please accept the answer.
Backward: Thank you very much for your answer! Love you!
Jackstraws: "Obviously, channel is not shown in the 3d cube." However, there is a C axis on the plots. It's not clear how to read these plots.
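One way to relate a plot that does have a C axis to the 4-D array is to merge the two spatial axes into one. Assuming the plot's third axis stands for the flattened spatial positions H*W (an assumption about the drawing, not something stated in the thread), each channel is one slab of an (N, C, H*W) cube, and reducing that view over axes (0, 2) gives exactly the per-channel statistics of axis=(0, 2, 3):

```python
import numpy as np

rng = np.random.default_rng(1)
N, C, H, W = 4, 3, 5, 6
X = rng.standard_normal((N, C, H, W))

# Flatten height and width into one axis: an (N, C, H*W) "cube".
cube = X.reshape(N, C, H * W)

# Per-channel statistics from the cube: reduce over batch and flattened space.
# cube[:, c, :] is the slab of the cube that belongs to channel c.
mean_cube = cube.mean(axis=(0, 2))
var_cube = cube.var(axis=(0, 2))

# The same statistics computed directly on the 4-D array.
mean_4d = X.mean(axis=(0, 2, 3))
var_4d = X.var(axis=(0, 2, 3))

print(np.allclose(mean_cube, mean_4d))  # True
print(np.allclose(var_cube, var_4d))    # True
```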
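The batch_size=1 point from the comments can be checked numerically. The sketch below uses random data (illustrative only) to compare per-channel statistics, which still pool H*W values, with a hypothetical per-pixel variant that pools only over the batch axis:

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.standard_normal((1, 3, 4, 4))  # batch_size = 1

# Per-channel statistics over (batch, height, width):
# each channel still pools 1 * 4 * 4 = 16 values.
var_per_channel = X.var(axis=(0, 2, 3))

# Hypothetical per-pixel statistics over the batch axis only:
# with a single sample, every position's variance is exactly zero.
var_per_pixel = X.var(axis=0)

print(np.all(var_per_channel > 0))  # True
print(np.all(var_per_pixel == 0))   # True
```

Common implementations also add a small eps to the variance before taking the square root, which rules out a literal division by zero; with per-pixel statistics and a single sample, however, every normalized value would then be exactly zero.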

© 2022 - 2025 — McMap. All rights reserved.