BatchNorm and the curious case of training vs. inference variance
Batch Normalization (BatchNorm) was one of the great innovations in neural networks of the last decade. It made training deep neural networks stable and well-behaved, and eased the stringent requirements on initialization. Although many other normalization techniques have since been introduced, it remains popular in computer vision and is still relevant for learning the fundamentals.
However, if you read the paper, you may notice one curious discrepancy: during training, the mini-batch variance is computed with the biased sample variance formula, while during inference, the variance is calculated using an unbiased estimator over the entire training set. (The mean and variance are used to normalize the outputs of the previous layer to zero mean and unit variance.)
This apparent discrepancy has prompted numerous discussions online, with some speculating that it was an error in the paper, and that the authors implied in a later paper that the unbiased variance estimate (Bessel's correction) should have been used consistently in both training and inference. There's even a long-running open issue about this in PyTorch, which faithfully implements BatchNorm as described in the original paper.
What gives? Is there really a bug in BatchNorm? Does it even matter?
Training and inference variance details
During training, the mini-batch variance is defined as (page 3 of the paper):
$$ \sigma^2_\mathcal{B} = {1\over{m}}\sum_{i=1}^{m}(x_i - \mu_\mathcal{B})^2 $$
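This is the biased sample variance: it divides by the mini-batch size $m$ rather than by $m-1$. Here $\mu_\mathcal{B} = {1\over{m}}\sum_{i=1}^{m}x_i$ is the mini-batch mean.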
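Why dividing by $m$ gives a biased estimate can be seen with a short derivation, assuming (as an idealization of real activations) that the $x_i$ in a mini-batch are i.i.d. with mean $\mu$ and variance $\sigma^2$. Expanding the square gives $\sigma^2_\mathcal{B} = {1\over{m}}\sum_{i=1}^{m}x_i^2 - \mu_\mathcal{B}^2$, and using $\mathbb{E}[x_i^2] = \sigma^2 + \mu^2$ and $\mathbb{E}[\mu_\mathcal{B}^2] = {\sigma^2\over{m}} + \mu^2$:

$$ \mathbb{E}[\sigma^2_\mathcal{B}] = (\sigma^2 + \mu^2) - \left({\sigma^2\over{m}} + \mu^2\right) = {{m-1}\over{m}}\,\sigma^2 $$

On average, the mini-batch variance therefore underestimates $\sigma^2$ by a factor of $(m-1)/m$, which is exactly what the ${m\over{m-1}}$ factor in the inference-time formula compensates for. For $m = 4$ this is a 25% shortfall; for $m = 32$ it is about 3%.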
During inference, there is no mini-batch, so the variance has to be obtained another way: it is estimated over essentially the entire training set, and that value is stored for use at inference time. The paper computes it as (page 4):
$$ Var[x] = {m\over{m-1}}\cdot\mathbb{E}_{\mathcal{B}} [\sigma^2_\mathcal{B}] $$
This is an unbiased variance estimate: the ${m\over{m-1}}$ factor in front (Bessel's correction) undoes the bias of the mini-batch variances $\sigma^2_\mathcal{B}$, and the expectation is taken over training mini-batches of size $m$. In practice, this expectation (along with the mean used at inference) is approximated as training progresses with an exponential moving average, rather than with a separate pass over the training set.
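To see the correction at work, here is a minimal pure-Python sketch. It assumes synthetic "activations" drawn i.i.d. from a normal distribution with variance 9 and mini-batches of size $m = 4$. `paper_population_var` averages the biased mini-batch variances and applies the ${m\over{m-1}}$ factor, as in the formula above; `ema_population_var` is a hypothetical moving-average variant whose `momentum` and `init` values are illustrative choices, not taken from the paper. (For reference, PyTorch's BatchNorm documentation states that it normalizes with the biased variance but stores the unbiased variance in its running estimate.)

```python
import random


def biased_var(xs):
    """Mini-batch variance sigma^2_B: divides by m (training-time formula)."""
    m = len(xs)
    mu = sum(xs) / m
    return sum((x - mu) ** 2 for x in xs) / m


def paper_population_var(batches):
    """Var[x] = m/(m-1) * E_B[sigma^2_B], with E_B taken as a plain average
    over mini-batches that all have the same size m."""
    m = len(batches[0])
    avg_biased = sum(biased_var(b) for b in batches) / len(batches)
    return m / (m - 1) * avg_biased


def ema_population_var(batches, momentum=0.01, init=1.0):
    """Hypothetical moving-average variant: blend in each mini-batch's
    bias-corrected variance as training proceeds."""
    running = init
    for b in batches:
        m = len(b)
        unbiased = m / (m - 1) * biased_var(b)
        running = (1 - momentum) * running + momentum * unbiased
    return running


random.seed(0)
m = 4
# Synthetic data: i.i.d. samples from N(5, 3^2), so the true variance is 9.
batches = [[random.gauss(5.0, 3.0) for _ in range(m)] for _ in range(20000)]

avg_biased = sum(biased_var(b) for b in batches) / len(batches)
print(f"average biased mini-batch variance: {avg_biased:.2f}")  # expectation: 3/4 * 9 = 6.75
print(f"paper estimate, m/(m-1) applied:    {paper_population_var(batches):.2f}")  # expectation: 9
print(f"moving-average estimate:            {ema_population_var(batches):.2f}")
```

Averaging the raw $\sigma^2_\mathcal{B}$ values alone should settle near ${3\over4}\cdot 9 = 6.75$, while both corrected estimates should settle near the true variance of 9.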
Is this a mismatch?
At first glance, this looks like an obvious training-inference mismatch, which is why so many issues have been filed against PyTorch for simply following the BatchNorm paper to the letter.
Andrej Karpathy, in one of his excellent video lectures on neural networks, also discusses this discrepancy in detail and recommends always using the unbiased sample variance to eliminate the training-inference mismatch. (This also yields a slightly different gradient for the backward pass through BatchNorm than the one given in the paper; he derives it in the lecture as well.)
My guess as to why this difference exists:
- During training, we don't care about estimating the population variance. Instead, we want the activations of each mini-batch, after normalization but before the scale and shift, to have zero mean and unit variance, which keeps training well-behaved and stable.
- During inference, there is no mini-batch: we want to normalize each input as it arrives. In this case, we have to use an estimate of the population variance, and the natural choice is the unbiased sample variance over the entire training set.
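The first point can be checked directly. Below is a minimal sketch in pure Python that normalizes one mini-batch of size $m = 4$ with each variance formula and then measures the mean and the mean square ${1\over{m}}\sum_i \hat{x}_i^2$ of the result (for a zero-mean batch, the mean square equals its biased variance). The `normalize` helper is hypothetical, and $\epsilon$ is set to 0 so the numbers come out exact; the paper adds a small $\epsilon$ to the variance for numerical stability.

```python
import math


def normalize(xs, unbiased=False, eps=0.0):
    """Normalize one mini-batch with its own mean and variance.
    unbiased=False divides by m (the paper's training-time formula);
    unbiased=True divides by m - 1 (Bessel's correction)."""
    m = len(xs)
    mu = sum(xs) / m
    sq_dev = sum((x - mu) ** 2 for x in xs)
    var = sq_dev / (m - 1) if unbiased else sq_dev / m
    return [(x - mu) / math.sqrt(var + eps) for x in xs]


def mean_and_mean_square(ys):
    """Return (mean, mean of squares) of a list of numbers."""
    m = len(ys)
    return sum(ys) / m, sum(y * y for y in ys) / m


batch = [1.0, 2.0, 3.0, 4.0]  # m = 4
for unbiased in (False, True):
    mean, mean_sq = mean_and_mean_square(normalize(batch, unbiased=unbiased))
    # biased: mean_sq is 1 up to rounding; unbiased: (m-1)/m = 0.75 up to rounding
    print(f"unbiased={unbiased}: mean={mean:.6f}, mean square={mean_sq:.6f}")
```

With the biased formula, the normalized mini-batch has zero mean and a mean square of exactly 1 (up to floating-point rounding); with Bessel's correction, the mean square drops to $(m-1)/m = 0.75$, so the normalized training activations would fall slightly short of the unit-variance target, by an amount that shrinks as $m$ grows.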
I contacted the BatchNorm authors and asked them about this discrepancy. Both of them gave similar answers, to the effect of:
- Whether you use the unbiased variance during training should not matter greatly.
- However, they did try both ways, and experimentally the approach in the paper worked better.
In my (very limited) experiments, I found little if any difference between using the unbiased and the biased variance in training, which is consistent with the authors' view that the choice should not matter greatly.