These are my notes after watching Karpathy's "Let's build GPT: from scratch, in code, spelled out" lecture. A while ago, when I looked at images like the following that try to explain how a token attends to other tokens, I could not quite follow them. After watching this video and executing the code, I think I understand it better now.
When I was reading docs about attention and transformers, I had many doubts.
Did this lecture answer all my questions? I don't know, but I am at least better off than before.
Each token in the sequence has its own query, key, and value vectors.
import torch
from torch.nn import functional as F
import torch.nn as nn
torch.manual_seed(1337)
B, T, C = 4, 3, 5 # batch, time, channels
x = torch.randn(B, T, C)
head_size = 6
# What am I looking for?
query = nn.Linear(in_features=C, out_features=head_size, bias=False)
# What do I have?
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
# What can I share?
value = nn.Linear(in_features=C, out_features=head_size, bias=False)
k = key(x)
q = query(x)
v = value(x)
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # scaled dot product: (B, T, T) attention scores
tril = torch.tril(torch.ones(T, T)) # lower-triangular mask
wei = wei.masked_fill(tril == 0, float('-inf')) # block attention to future positions
wei = F.softmax(wei, dim=-1) # each row now sums to 1
out = wei @ v # weighted sum of values: (B, T, head_size)
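A quick sanity check of the shapes and of the attention pattern, continuing the snippet above:

```python
# Quick sanity check, continuing the snippet above.
print(k.shape, q.shape, v.shape)  # each (B, T, head_size) = (4, 3, 6)
print(wei.shape)                  # (B, T, T) = (4, 3, 3)
print(out.shape)                  # (B, T, head_size) = (4, 3, 6)

# Attention weights for the first batch element: lower-triangular,
# and every row sums to 1 after the softmax.
print(wei[0])
print(wei[0].sum(dim=-1))         # tensor([1., 1., 1.])
```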
I asked Claude to explain this code to a kid, and this is how it responded:
Alright, let's try to explain this in a simpler way, imagining we're talking about a classroom of kids:
Imagine we have a bunch of information about a classroom of kids (`x`). There are 4 groups of kids (`B=4`), each group has 3 kids (`T=3`), and each kid knows 5 things about themselves (`C=5`). Each kid writes down what they are looking for, what they have, and what they can share, each described with 6 numbers (`head_size=6`). The kids stand in a line, and each kid is only allowed to listen to the kids standing before them (that's the `tril` and `masked_fill` part). The final result (`out`) is what each kid learned from this game, based on what they were looking for and what the other kids before them could share. This game helps each kid learn things that are most relevant to what they're looking for, but only from the kids who came before them in line. It's a way of mixing and matching information in a smart way!
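That "only listen to the kids before you in line" rule is easy to check in code: if we change the last token in each sequence, the outputs for the earlier positions should not move. A small sketch of that check, reusing the layers and mask defined above:

```python
# Verify the causal mask: perturbing a later token must not change earlier outputs.
def attend(x):
    k, q, v = key(x), query(x), value(x)
    wei = q @ k.transpose(-2, -1) * head_size**-0.5
    wei = wei.masked_fill(tril == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    return wei @ v

x2 = x.clone()
x2[:, -1, :] = torch.randn(B, C)  # change only the last time step
out1, out2 = attend(x), attend(x2)
print(torch.allclose(out1[:, :-1], out2[:, :-1]))  # True: earlier positions unchanged
print(torch.allclose(out1[:, -1], out2[:, -1]))    # False: the last position does change
```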
The `v` matrix (the "value" matrix) represents the actual information or content from each position that we want to gather or aggregate. When we do `wei @ v`, we're essentially doing a weighted sum of the values, where the weights are determined by our attention mechanism.

Here's why we need the `v` matrix:
1. **Information content:** While `q` (query) and `k` (key) are used to calculate "where" to pay attention, they don't contain the actual information we want to gather. `v` contains the information we actually want to aggregate.
2. **Separation of concerns:**
   - `q` is like you asking "Who knows about cooking?" (your query)
   - `k` is like each person's expertise (their keys)
   - `wei` (after softmax) is like you figuring out who to listen to about cooking
   - `v` is the actual cooking tips each person has
   - `wei @ v` is like you gathering cooking tips, paying more attention to the people you determined know more about cooking

Without `v`, we would only know where to pay attention, but we wouldn't have any information to gather from those attention points. The `v` matrix gives us the actual content to aggregate based on our calculated attention weights.

In summary, `wei @ v` allows us to gather information (`v`) from different positions in our input, weighted by how relevant we've determined each position to be (`wei`). This is the core mechanism that allows attention to selectively combine information from across the input sequence.
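One way to convince yourself that `wei @ v` really is just a weighted sum: pick one batch element and one position and rebuild its output by hand. A minimal sketch using the tensors from the snippet above:

```python
# out[b, t] is the attention-weighted sum of the value vectors up to position t.
b, t = 0, 1                               # pick one batch element and one position
manual = torch.zeros(head_size)
for s in range(T):                        # future positions contribute weight 0 anyway
    manual += wei[b, t, s] * v[b, s]
print(torch.allclose(manual, out[b, t]))  # True
```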