Batch Normalization

  • 2024/08/10

About

This page contains notes from Karpathy's YouTube lecture Building makemore Part 3: Activations & Gradients, BatchNorm. I downloaded the video and generated captions using generate_captions.py. Wherever "I" appears in the text, it refers to Karpathy.


What does Batchnorm do?

We use batch normalization to control the statistics of activations in the neural net. It is common to sprinkle batch normalization layers throughout the network, usually placing them after layers that involve multiplications, such as a linear layer or a convolutional layer.


Paper

From Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift:

./images/batch-normalization-formula.png
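
In case the image does not render, the batch normalizing transform from the paper can be written out as follows, for a mini-batch of m values x_1 ... x_m of a single activation, with learned parameters gamma and beta:

```latex
\begin{aligned}
\mu_\mathcal{B} &= \frac{1}{m}\sum_{i=1}^{m} x_i \\
\sigma_\mathcal{B}^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_\mathcal{B}\right)^2 \\
\hat{x}_i &= \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} \\
y_i &= \gamma\,\hat{x}_i + \beta
\end{aligned}
```

Note that the paper divides by m when computing the variance, whereas torch.var (used in the code on this page) applies Bessel's correction by default and divides by m - 1. For reasonably large batches the difference is small.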


Sample

import torch  
import matplotlib.pyplot as plt  
# Jupyter/IPython magic; omit the next line when running as a plain script  
%matplotlib inline  
  
# Linear layer  
embcat = torch.randn(32, 30)   
W1 = torch.randn(30, 200)  
hpreact = embcat @ W1  
  
# Batch norm layer  
gamma = torch.ones(1, 200)  
beta = torch.zeros(1, 200)  
eps=1e-5  
  
batch_mean = hpreact.mean(dim=0, keepdim=True)  
batch_var = hpreact.var(dim=0, keepdim=True)  
batch_norm = (hpreact - batch_mean) / torch.sqrt(batch_var + eps)  
batch_norm = gamma * batch_norm + beta  
  
# hpreact.shape, batch_norm.shape
plt.figure(figsize=(20, 5))  
plt.subplot(121)  
plt.hist(hpreact.view(-1).tolist(), 50);  
plt.title('PRE batch-norm')  
  
plt.subplot(122)  
plt.hist(batch_norm.view(-1).tolist(), 50);  
plt.title('POST batch-norm');

Implementing Batchnorm layer

From makemore_part3_bn:

class BatchNorm1d:  
    def __init__(self, dim, eps=1e-5, momentum=0.1):  
        self.eps = eps  
        self.momentum = momentum  
        self.training = True  
  
        # parameters (trained with backprop)  
        self.gamma = torch.ones(dim)  
        self.beta = torch.zeros(dim)  
  
        # buffers (trained with a running 'momentum update')  
        self.running_mean = torch.zeros(dim)  
        self.running_var = torch.ones(dim)  
  
    def __call__(self, x):  
        # calculate the forward pass  
        if self.training:  
            xmean = x.mean(0, keepdim=True)  # batch mean  
            xvar = x.var(0, keepdim=True)  # batch variance  
        else:  
            xmean = self.running_mean  
            xvar = self.running_var  
  
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize to unit variance  
        self.out = self.gamma * xhat + self.beta  
  
        # update the buffers  
        if self.training:  
            with torch.no_grad():  
                self.running_mean = (  
                    1 - self.momentum  
                ) * self.running_mean + self.momentum * xmean  
                self.running_var = (  
                    1 - self.momentum  
                ) * self.running_var + self.momentum * xvar  
  
        return self.out  
  
    def parameters(self):  
        return [self.gamma, self.beta]

  • Like other modules, BatchNorm behaves differently depending on whether it is in training or evaluation mode. The self.training flag determines this.
  • gamma and beta are trained with backprop. These are the scale and shift parameters.
  • During training, the current batch's mean and variance are calculated.
  • During inference, running_mean and running_var are used. They're called buffers in PyTorch nomenclature. These buffers are updated explicitly here, on every training forward pass, using an exponential moving average. They are not part of back-propagation and stochastic gradient descent, so they are not parameters of this layer. That's why parameters() returns only gamma and beta, and not the running mean and variance.
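
To make the buffer update concrete, here is a minimal sketch of the same exponential moving average for a single feature. It uses plain Python floats instead of tensors, and the per-batch statistics (mean 2.0, variance 4.0 for every batch) are made up for illustration:

```python
# Exponential moving average, mirroring the buffer update in BatchNorm1d above,
# written for one feature with plain floats instead of tensors.
momentum = 0.1      # same default as BatchNorm1d.__init__
running_mean = 0.0  # buffers start at mean 0 ...
running_var = 1.0   # ... and variance 1

# Hypothetical statistics: every batch reports mean 2.0 and variance 4.0.
batch_stats = [(2.0, 4.0)] * 100

history = []  # running_mean after each update
for batch_mean, batch_var in batch_stats:
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    running_var = (1 - momentum) * running_var + momentum * batch_var
    history.append(running_mean)

print(history[0])                 # one update moves 10% of the way toward 2.0
print(running_mean, running_var)  # after 100 updates: close to 2.0 and 4.0
```

No gradient is involved anywhere in this loop, which is why the buffers are not returned by parameters(). A smaller momentum would make the running statistics smoother but slower to settle.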

Use of scale & shift:

  • We want the pre-activations to be roughly Gaussian, but only at initialization. We don't want them to be forced to be Gaussian always.
  • We'd like to allow the neural net to move this distribution around: to make it more diffuse or more sharp, or to make some tanh neurons more trigger-happy or less trigger-happy. So we'd like this distribution to move around, and we'd like back-propagation to tell us how it should move.
  • So in addition to this idea of standardizing the activations at any point in the network, we also have to introduce the additional component that the paper describes as scale and shift.

So basically, we take these normalized inputs, scale them by some gain, and offset them by some bias to get the final output of this layer.
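
The effect of the gain and bias can be checked numerically. The sketch below is a single-feature, plain-Python version of the normalization (not the class above); the helper names and the input values are made up for illustration:

```python
import math

def batchnorm_1feature(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Batch-normalize one feature over a batch, then apply scale and shift."""
    m = len(xs)
    mean = sum(xs) / m
    var = sum((x - mean) ** 2 for x in xs) / (m - 1)  # unbiased, like torch.var's default
    xhat = [(x - mean) / math.sqrt(var + eps) for x in xs]
    return [gamma * v + beta for v in xhat]

def mean_and_std(xs):
    """Sample mean and (unbiased) standard deviation of a list of floats."""
    m = len(xs)
    mean = sum(xs) / m
    std = math.sqrt(sum((x - mean) ** 2 for x in xs) / (m - 1))
    return mean, std

xs = [3.0, 7.0, 8.0, 12.0, 15.0]  # made-up pre-activations of one neuron

# At initialization (gamma=1, beta=0) the output is standardized.
init_mean, init_std = mean_and_std(batchnorm_1feature(xs))

# If training moves gamma and beta, the output distribution moves with them.
moved_mean, moved_std = mean_and_std(batchnorm_1feature(xs, gamma=2.5, beta=-1.0))

print(init_mean, init_std)    # approximately 0 and 1
print(moved_mean, moved_std)  # approximately -1 and 2.5
```

So the output mean is set by beta and the output spread by gamma, regardless of the statistics of the incoming batch.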


How Batchnorm is used

./images/batchnorm-placement.svg


Regularizing effect of Batch Normalization

Now, the stability offered by batch normalization actually comes at a terrible cost, and that cost is that, if you think about what's happening here, something terribly strange and unnatural is happening. It used to be that we had a single example feeding into a neural net, and then we calculated its activations and its logits. This is a deterministic sort of process, so you arrive at some logits for this example. Then, for efficiency of training, we suddenly started to use batches of examples, but those batches of examples were processed independently, and it was just an efficiency thing.

But now suddenly in batch normalization, because of the normalization through the batch, we are coupling these examples mathematically, in both the forward pass and the backward pass of the neural net. So now the hidden state activations, hpreact, and the logits for any one input example are not just a function of that example and its input; they're also a function of all the other examples that happen to come along for the ride in that batch. And these examples are sampled randomly.

So what's happening is that hpreact, which feeds into h (the hidden state activations), is going to change slightly for any one input example depending on what other examples are in the batch. Depending on which other examples happen to come along for the ride, h is going to change subtly and jitter, if you imagine sampling different examples, because the statistics of the mean and the standard deviation are going to be impacted. So you'll get a jitter for h, and you'll get a jitter for the logits. You'd think that this would be a bug or something undesirable, but in a very strange way, this actually turns out to be good in neural network training as a side effect.
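
This jitter is easy to reproduce. In the sketch below, which uses plain Python for a single feature with made-up numbers, the same pre-activation value is normalized inside two different batches and comes out with two different values:

```python
import math

def batchnorm_1feature(xs, eps=1e-5):
    """Normalize one feature using the mean and variance of the batch xs."""
    m = len(xs)
    mean = sum(xs) / m
    var = sum((x - mean) ** 2 for x in xs) / (m - 1)  # unbiased, like torch.var's default
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

example = 1.0  # the pre-activation we keep track of

# Two hypothetical batches that share the tracked example but nothing else.
batch_a = [example, 0.5, -0.3, 2.0]
batch_b = [example, -1.2, 0.1, 0.4]

out_a = batchnorm_1feature(batch_a)[0]
out_b = batchnorm_1feature(batch_b)[0]

print(out_a, out_b)  # same input, different normalized outputs
```

Because the batch-mates are sampled randomly during training, the normalized value of any given example varies from step to step in exactly this way.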

The reason for that is that you can think of this as a kind of regularizer. You have your input, and then, depending on the other examples, its activations are jittering a bit. What that does is effectively pad out any one of these input examples and introduce a little bit of entropy. Because of the padding out, it's actually kind of like a form of data augmentation: it augments the input a little bit and jitters it, and that makes it harder for the neural net to overfit to these concrete, specific examples. So by introducing all this noise, it pads out the examples and regularizes the neural net. That's one of the reasons why, deceptively, as a second-order effect, this is actually a regularizer, and that has made it harder for us to remove the use of batch normalization, because basically no one likes this property that the examples in the batch are coupled mathematically in the forward pass.

So people have tried to deprecate the use of batch normalization and move to other normalization techniques that do not couple the examples of a batch, such as layer normalization, instance normalization, and group normalization. But long story short, batch normalization was the first kind of normalization layer to be introduced. It worked extremely well, it happened to have this regularizing effect, and it stabilized training. People have been trying to remove it and move to some of the other normalization techniques, but it's been hard because it just works quite well. Part of the reason it works quite well is, again, this regularizing effect, and the fact that it is quite effective at controlling the activations and their distributions.
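
For contrast, here is a minimal sketch of the idea behind layer normalization: each example is normalized over its own features, so no batch statistics are involved. This is a simplified plain-Python illustration with made-up numbers; it leaves out the learnable gain and bias, and the function name is hypothetical:

```python
import math

def layernorm_rows(batch, eps=1e-5):
    """Normalize every example (row) over its own features; no batch statistics."""
    out = []
    for row in batch:
        n = len(row)
        mean = sum(row) / n
        var = sum((v - mean) ** 2 for v in row) / n  # variance over this row's features
        out.append([(v - mean) / math.sqrt(var + eps) for v in row])
    return out

example = [1.0, 2.0, 4.0]  # features of the example we keep track of

# Two hypothetical batches that share the tracked example but nothing else.
batch_a = [example, [0.5, -0.3, 2.0]]
batch_b = [example, [9.0, -7.0, 0.1]]

out_a = layernorm_rows(batch_a)[0]
out_b = layernorm_rows(batch_b)[0]

print(out_a == out_b)  # True: the output does not depend on the batch-mates
```

The coupling between examples disappears, and with it the jitter described above.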


Papers