This page contains notes from Karpathy's YouTube lecture Building makemore Part 3: Activations & Gradients, BatchNorm. I downloaded the video and generated captions using generate_captions.py. Wherever "I" appears, it refers to Karpathy.
We are using batch normalization to control the statistics of activations in the neural net. It is common to sprinkle batch normalization layers throughout the neural net, and usually we place them after layers that have multiplications, for example a linear layer or a convolutional layer.
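As a quick reference, here is a minimal sketch of that placement using PyTorch's built-in nn modules (the layer sizes are hypothetical, chosen only for illustration):
import torch.nn as nn
n_embd, n_hidden, vocab_size = 30, 200, 27  # hypothetical sizes for illustration
model = nn.Sequential(
    nn.Linear(n_embd, n_hidden, bias=False),  # multiplicative layer; bias is redundant since batch norm adds its own beta
    nn.BatchNorm1d(n_hidden),                 # placed right after the linear layer
    nn.Tanh(),
    nn.Linear(n_hidden, vocab_size),
)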
From Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift:
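(Presumably the relevant part is the batch normalizing transform, Algorithm 1 in the paper, which the code below implements. For a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$:
$$\mu_{\mathcal{B}} = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_{\mathcal{B}}^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_{\mathcal{B}}\right)^2, \qquad \hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}, \qquad y_i = \gamma \hat{x}_i + \beta.)$$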
import torch
%matplotlib inline
import matplotlib.pyplot as plt
# Linear layer
embcat = torch.randn(32, 30)
W1 = torch.randn(30, 200)
hpreact = embcat @ W1
# Batch norm layer
gamma = torch.ones(1, 200)
beta = torch.zeros(1, 200)
eps=1e-5
batch_mean = hpreact.mean(dim=0, keepdim=True)
batch_var = hpreact.var(dim=0, keepdim=True)
batch_norm = (hpreact - batch_mean) / torch.sqrt(batch_var + eps)
batch_norm = gamma * batch_norm + beta
# hpreact.shape, batch_norm.shape
plt.figure(figsize=(20, 5))
plt.subplot(121)
plt.hist(hpreact.view(-1).tolist(), 50);
plt.title('PRE batch-norm')
plt.subplot(122)
plt.hist(batch_norm.view(-1).tolist(), 50);
plt.title('POST batch-norm');
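A quick sanity check (not from the notebook, just a continuation of the cell above): each of the 200 columns of batch_norm should now have roughly zero mean and unit standard deviation, while the columns of hpreact do not:
# per-neuron statistics across the batch dimension
print(hpreact.mean(0).abs().mean().item(), hpreact.std(0).mean().item())        # nonzero means, std around sqrt(30) ~ 5.5
print(batch_norm.mean(0).abs().mean().item(), batch_norm.std(0).mean().item())  # ~0.0 and ~1.0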
From makemore_part3_bn:
class BatchNorm1d:

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)

  def __call__(self, x):
    # calculate the forward pass
    if self.training:
      xmean = x.mean(0, keepdim=True) # batch mean
      xvar = x.var(0, keepdim=True) # batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]
The self.training flag determines which statistics are used in the forward pass. gamma and beta are trained with backprop; these are the scale & shift parameters. At inference time, running_mean and running_var are used instead; they are called buffers in PyTorch nomenclature. These buffers are updated explicitly with an exponential moving average on every training forward pass; they are not part of backpropagation or stochastic gradient descent, so they are not parameters of this layer. That is why parameters() returns only gamma and beta, and not the running mean and variance.
Use of scale & shift:
- We want the pre-activations to be roughly Gaussian, but only at initialization; we don't want them to be forced to be Gaussian always.
- We'd like to allow the neural net to move this distribution around: to potentially make it more diffuse or more sharp, to make some tanh neurons more trigger-happy or less trigger-happy. So we'd like this distribution to move around, and we'd like backpropagation to tell us how it should move.
- And so, in addition to standardizing the activations at any point in the network, we also introduce the additional component described in the paper as scale and shift.
So basically what we're doing is taking these normalized inputs and additionally scaling them by some gain and offsetting them by some bias to get the final output of this layer.
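A tiny sketch (toy tensors, not from the lecture) of what "backprop tells us how the distribution should move" means in practice: if gamma and beta require gradients, any loss produces gradients for them just like for any other parameter:
x = torch.randn(32, 200)
gamma = torch.ones(1, 200, requires_grad=True)
beta = torch.zeros(1, 200, requires_grad=True)
xhat = (x - x.mean(0, keepdim=True)) / torch.sqrt(x.var(0, keepdim=True) + 1e-5)
out = torch.tanh(gamma * xhat + beta)
loss = out.pow(2).mean()   # stand-in loss just to drive a backward pass
loss.backward()
print(gamma.grad.shape, beta.grad.shape)   # (1, 200) each -- SGD can now scale / shift the distribution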
Now, the stability offered by batch normalization comes at a terrible cost, and that cost is that something strange and unnatural is happening here. It used to be that a single example fed into the neural net, we calculated its activations and its logits, and this was a deterministic process, so you arrive at some logits for this example. Then, for efficiency of training, we started to use batches of examples, but those examples were processed independently; batching was just an efficiency thing.
But now, in batch normalization, because of the normalization across the batch, we are coupling these examples mathematically, in both the forward pass and the backward pass of the neural net. So the hidden-state pre-activations hpreact and the logits for any one input example are not just a function of that example and its input; they are also a function of all the other examples that happen to come along for the ride in that batch, and those examples are sampled randomly.
And so what's happening is, for example, when you look at hpreact, which feeds into the hidden-state activations h: for any one input example it is going to change slightly depending on what other examples are in the batch. Depending on what other examples happen to come along for the ride, h is going to jitter, if you imagine sampling different batches, because the statistics of the mean and standard deviation are impacted. So you get a jitter for h, and you get a jitter for the logits. You would think this is a bug or something undesirable, but in a very strange way it actually turns out to be good for neural network training, as a side effect.
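A small demonstration of this coupling (toy numbers, with a hypothetical helper bnorm): the batch-norm output for the very same example differs depending on which other examples share its batch:
x0 = torch.randn(1, 200)                         # one fixed example
batch_a = torch.cat([x0, torch.randn(31, 200)])  # x0 plus one set of random neighbours
batch_b = torch.cat([x0, torch.randn(31, 200)])  # x0 plus a different set
def bnorm(x, eps=1e-5):
    return (x - x.mean(0, keepdim=True)) / torch.sqrt(x.var(0, keepdim=True) + eps)
out_a = bnorm(batch_a)[0]   # row for x0 in the first batch
out_b = bnorm(batch_b)[0]   # row for x0 in the second batch
print((out_a - out_b).abs().max().item())   # clearly nonzero -- the 'jitter'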
And the reason for that is that you can think of this as a kind of regularizer. You have your input, and depending on the other examples in the batch it jitters a bit. That effectively pads out any one of these input examples and introduces a little bit of entropy; because of this padding, it acts like a form of data augmentation. The input is being augmented and jittered slightly, and that makes it harder for the neural net to overfit these concrete, specific examples. So by introducing all this noise, batch normalization pads out the examples and regularizes the neural net. That is one of the reasons why, deceivingly, as a second-order effect, it is actually a regularizer, and that has made it harder to remove the use of batch normalization, because basically no one likes the property that the examples in a batch are coupled mathematically in the forward pass.
So people have tried to deprecate the use of batch normalization and move to other normalization techniques that do not couple the examples of a batch, for example layer normalization, instance normalization, and group normalization. But, long story short, batch normalization was the first normalization layer to be introduced; it worked extremely well, it happened to have this regularizing effect, and it stabilized training. People have been trying to remove it and move to some of the other normalization techniques, but it has been hard because it just works quite well, and part of the reason it works so well is, again, this regularizing effect and the fact that it is quite effective at controlling the activations and their distributions.
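For contrast, a minimal sketch of layer normalization (a hypothetical layernorm helper): the statistics are taken over the features of each example rather than over the batch, so the examples are not coupled:
def layernorm(x, eps=1e-5):
    # statistics over dim=1 (the features of each example), not dim=0 (the batch)
    mean = x.mean(1, keepdim=True)
    var = x.var(1, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)
x0 = torch.randn(1, 200)
out_a = layernorm(torch.cat([x0, torch.randn(31, 200)]))[0]
out_b = layernorm(torch.cat([x0, torch.randn(31, 200)]))[0]
print((out_a - out_b).abs().max().item())   # 0.0 -- each example only normalizes against itself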