Table of Contents
- What will you gain by reading this
- Motivation
- How does the attention mechanism fit into the LLM architecture
- Why attention?
- Why the "self" in "self-attention"?
- How does a self-attention mechanism work?
- Why softmax?
- Generalising calculation of attention scores to all query tokens
- Bringing in the Q, K, V
- Why do we scale attention scores with square root of d_k
- Back to regular programming
- Generalising to calculation of context vectors for all queries
- What have we achieved and what is next
- Appendix A - Learning with LLM
- Appendix B - Why is nn.Linear better at initialising weights that lead to more stable model training
What will you gain by reading this
By the end of this article, you will understand the why and how of the softmax self-attention mechanism, i.e.:
This is the beginning of a series of articles on LLMs that I will write. I chose to begin with the self-attention mechanism because it "serves as the cornerstone of every LLM based on the transformer architecture" (Sebastian Raschka, Build a Large Language Model (From Scratch), section 3.3).
A large part of the content of this article is derived from Sebastian Raschka's Build a Large Language Model (From Scratch). If you do not have access to his book, you can also watch his video here, or read his article here.
There is nothing new in this article that is not already covered by Sebastian's book, so if you prefer to read from the OG, please read his book directly (highly recommended). My value add is that I tried to explain the concepts in a way that makes intuitive sense for me, and I expanded on details which Sebastian might have decided to give a lighter treatment.
Motivation
With the rapid development of LLM in recent months, I thought a quick, fun and meaningful way to stay engaged and relevant would be to learn how to build the key components of LLM from scratch. I believe the best way to retain knowledge is to be hands on and document what I have learnt.
I began my journey deepening my understanding of LLMs by watching Andrej Karpathy's Deep Dive into LLM's like ChatGPT and reading Sebastian Raschka's Build a Large Language Model (From Scratch). In the process, I realised that LLMs were an excellent learning aid. Whenever I came across a term which I did not fully understand while reading Sebastian's book, I would ask an LLM to explain it, and I could quickly understand (at least I thought I did) not just the concepts, but also how the concepts fit into the bigger literature and the problems each concept sought to solve.
Documenting my AI-assisted learning process is also an attempt to validate my hypothesis that with LLMs, any concept can be quickly learnt even without much formal background in the topic. Seems foolhardy? I thought so too, but I saw no harm in trying! This means that when I try to read the original paper and implement the code, I am also seeking the help of LLMs to understand the concepts, and even the code. However, for the purpose of learning, I will personally go through every word and every line of code written here, and only put down in writing what I have understood.
If you see any errors, please let me know. I will fix these errors and also mention what was fixed, which I think is a way to document understanding and growth. If the topic of learning to learn with LLMs interests you, I have included a segment about this in Appendix A.
How does the attention mechanism fit into the LLM architecture
While the attention mechanism is a key component powering LLMs, it is only one of the many components that make up an LLM. Below is a diagram of the key components of an LLM. If the diagram looks too complicated, just know that the attention mechanism is a key component of an LLM, and understanding how the attention mechanism works is understanding a large part of the inner workings of an LLM. In fact, the creators of the transformer architecture boldly titled their paper "Attention Is All You Need". I will go through each of the other components in future articles.
Why attention?
Before the rise of LLMs, recurrent neural networks (RNNs) were the workhorses used for translation. RNNs were used to overcome the limitations of word-by-word translation by taking into account differences in grammatical structure between the original and translated languages. Without going too deep into how RNNs work, the attention mechanism helped to overcome a key limitation of RNNs, i.e. RNNs cannot "directly access earlier hidden states from the encoder during the decoding phase", which could lead to a loss of context. This is especially problematic for complex sentences with references that span the sentence (source: Build a Large Language Model (From Scratch), chapter 3.1).
Self-attention overcame the limitations of RNNs by allowing "each position in the input sequence to consider the relevancy of, or attend to, all other positions in the same sequence when computing the representation of a sequence" (source: Building LLM from Scratch, chapter 3.2).
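Reproduced below is the standard softmax attention equation. Scary looking? I know, I felt the same when I first saw it. I will unpack every bit of the equation slowly, explaining both the "why" and the "how".
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
Here d_k is the dimension of the key vectors.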
Why the "self" in "self-attention"?
I think Sebastian explained it best in his book, so I shall quote him directly:
In self-attention, the “self” refers to the mechanism’s ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image.
This is in contrast to traditional attention mechanisms, where the focus is on the relationships between elements of two different sequences, such as in sequence-to-sequence models where the attention might be between an input sequence and an output sequence...
To give a more concrete example, say we have a task to translate from English to French. The language model has already translated part of the input sentence to French, and now it needs to predict the next word.
Input sentence: "The animal didn't cross the street because it was too tired."
Output sentence: "L'animal n'a pas traversé la rue car il était trop ______?"
A traditional attention mechanism will have an encoder that processes the input sentence, which is in English, storing a key and value pair of vectors for each token in the input sentence. Think of keys as labels for each token in the input sentence, and values as the content for each key. The decoder state is currently "L'animal...trop", which forms the query. We can then match the query from the decoder against each key in the encoder, and the key with the highest score can then be translated into French and appended to the query.
Whereas with a self-attention mechanism, there is only a single input, i.e.
Translate from English to French: "The animal didn't cross the street because it was too tired" French: "L'animal n'a pas traversé la rue car il était trop".
The LLM has one and only job, which is to predict the next token after "trop".
Below is a table summarising the differences between attention and self-attention.
Feature | Modern Self-Attention Model (Decoder-Only) | Older Attention Model (Encoder-Decoder) |
---|---|---|
Model Structure | A single, unified stack of Transformer blocks. | Two separate stacks: an Encoder and a Decoder. |
Input Format | One continuous sequence containing instructions, source, and target. | Two separate sequences fed into the Encoder and Decoder. |
Attention Scope | At any step, a token can attend to any previous token in the combined sequence (both source and target). | The decoder attends to its own previous tokens AND separately attends to the encoder's final output. |
How it "Links" | By learning patterns within the single combined sequence. | Through a dedicated "cross-attention" mechanism. |
How does a self-attention mechanism work?
For any text, an LLM reads it as a list of tokens, and each token is represented as a vector. A token, unlike a word, is a sequence of characters that are grouped together as a useful semantic unit for processing. So the word "don't" may be represented as the tokens "do" and "n't". Sentences are broken into tokens by tokenisers, and different LLMs may have different tokenisers. A vector is a list of numbers.
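To make the idea of tokens concrete, here is a toy sketch. The splitting rule, the name toy_tokenise and the tiny vocabulary are all made up for illustration; real LLM tokenisers learn their splits from data and will generally break text up differently.

```python
import re

def toy_tokenise(text):
    """Split text into word-like tokens, peeling "n't" off contractions.

    This is a hand-written illustrative rule, not how any real LLM
    tokeniser works.
    """
    pattern = r"\w+(?=n't)|n't|\w+|[^\w\s]"
    return re.findall(pattern, text)

tokens = toy_tokenise("I don't know")
print(tokens)  # ['I', 'do', "n't", 'know']

# Map each distinct token to an integer id (hypothetical vocabulary).
vocab = {tok: idx for idx, tok in enumerate(sorted(set(tokens)))}
token_ids = [vocab[tok] for tok in tokens]
print(token_ids)  # [0, 1, 3, 2]
```

In a real model, each token id would then be looked up in an embedding table to obtain the vector for that token.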
For example, given a text of "the quick brown fox jumps over", assuming each token is embedded as a vector of length 3, the text could be represented as tokens, and each token represented as an embedding. Taking token x(1) (i.e. "the") as an example, our aim is to calculate the attention weights that weigh the importance of each token with respect to x(1). For example, a_12 is the attention weight of x(1) with respect to x(2). We can then compute the context vector z(1) for the token x(1) as a combination of all input vectors, each weighted by its attention weight with respect to x(1). The method to compute this combination will become clearer in the next few sections.
We now translate the text above into code. Below is the Python code representing what is captured in the diagram:
import torch
inputs = torch.tensor(
    [[0.3, 0.2, 0.9],   # the   (x^1)
     [0.1, 0.5, 0.2],   # quick (x^2)
     [0.6, 0.4, 0.3],   # brown (x^3)
     [0.8, 0.4, 0.3],   # fox   (x^4)
     [0.7, 0.2, 0.5],   # jumps (x^5)
     [0.9, 0.4, 0.7]]   # over  (x^6)
)
Assuming "the" is the query token, we take the dot product of "the" with each of the tokens (including "the" itself). A worked example of the dot product is shown in the first blue box. Because we are multiplying a 1x3 matrix with a 3x1 matrix for each pair of token embeddings, we get a scalar at the end of the dot product.
query = inputs[0]  # embedding of "the"
attn_scores = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    # one scalar score per token: dot product of that token with the query
    attn_scores[i] = torch.dot(x_i, query)
print(attn_scores)
By applying the dot product, we get the following attention scores. Each attention score represents the relationship of a token with the query token "the". The computed attention scores are below.
tensor([0.9400, 0.3100, 0.5300, 0.5900, 0.7000, 0.9800])
Source: Building LLM from Scratch, chapter 3.3
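To spell out the arithmetic behind these scores, the sketch below computes the first two of them in plain Python (no PyTorch), using the same embedding values as above. The helper name dot is just for illustration.

```python
the = [0.3, 0.2, 0.9]
quick = [0.1, 0.5, 0.2]

def dot(a, b):
    """Multiply matching elements and add up the products."""
    return sum(a_i * b_i for a_i, b_i in zip(a, b))

# "the" with itself: 0.3*0.3 + 0.2*0.2 + 0.9*0.9 = 0.09 + 0.04 + 0.81
print(round(dot(the, the), 4))    # 0.94
# "the" with "quick": 0.3*0.1 + 0.2*0.5 + 0.9*0.2 = 0.03 + 0.10 + 0.18
print(round(dot(the, quick), 4))  # 0.31
```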
We can then normalise the attention scores to get attention weights that sum up to 1, so that each attention weight can be read as a probability, which is useful for ease of interpretation and "maintaining training stability of LLM" (source: Building LLM from Scratch, chapter 3.3.1). The softmax function is commonly used for this normalisation.
attn_weights = torch.softmax(attn_scores, dim=-1)
print("Attention weights:", attn_weights)
print("Sum:", attn_weights.sum())
The attention weights after applying softmax normalisation on the attention scores are below.
Attention weights: tensor([0.2115, 0.1126, 0.1404, 0.1490, 0.1664, 0.2201])
Sum: tensor(1.0000)
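For readers curious about what torch.softmax is doing under the hood, here is a plain-Python sketch of the same calculation on the scores above. Subtracting the maximum score before exponentiating is a common numerical-stability trick that does not change the result; it is my addition and not something the steps above rely on.

```python
import math

attn_scores = [0.94, 0.31, 0.53, 0.59, 0.70, 0.98]

def softmax(scores):
    """Exponentiate each score, then divide by the sum of the exponentials."""
    peak = max(scores)  # subtracting the max avoids overflow for large scores
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

attn_weights = softmax(attn_scores)
print([round(w, 4) for w in attn_weights])
# [0.2115, 0.1126, 0.1404, 0.149, 0.1664, 0.2201]
print(round(sum(attn_weights), 4))  # 1.0
```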
Why Softmax?
A natural question to ask is why not simply divide each attention score by the total sum of all scores to obtain the attention weight, i.e.
a_i = w_i / (w_1 + w_2 + ... + w_n)
Instead of the more complicated softmax:
a_i = e^(w_i) / (e^(w_1) + e^(w_2) + ... + e^(w_n))
This is because softmax "is better at managing extreme values and offers more favorable gradient properties during training" (source: Building LLM from Scratch, 3.3.1). Sounds abstract, so let us work with a simple example to understand the implications:
Imagine there are three attention scores:
w_1: 10
w_2: 5
w_3: 1
With the first approach of normalisation, the attention weight of w_1 is: a_1 = 10 / (10 + 5 + 1) = 0.625 (or 62.5%)
With softmax, the attention weight of w_1 is: a_1 = e^(w_1) / ( e^(w_1) + e^(w_2) + e^(w_3) ) = e^10 / (e^10 + e^5 + e^1) ≈ 22026.47 / 22177.60 ≈ 0.993185 (or 99.32%)
We can see that by using softmax for normalisation, the attention share of w_1 shoots up from 62.5% to more than 99%.
Not only is w_1 the largest number, its attention weight a_1 also occupies almost all of the softmax weights. This "winner-take-most" effect is often what we want in attention mechanisms: to really focus on the most important thing. In comparison, the attention weight a_3 drops from 0.0625 (6.25%) with a simple normalisation to 0.000123 (or 0.012%) with softmax.
Another advantage of softmax is that when some of the values are negative, each of the normalised values will still be positive and they all sum up nicely to 1. With the simple summation approach, a negative value stays negative after normalisation whenever the total is positive, and a negative weight cannot be read as a probability.
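The comparison above can be reproduced with a few lines of plain Python. The sketch also includes a made-up set of scores containing a negative value (my own example, not from the book) to illustrate the point about negative values.

```python
import math

def sum_normalise(scores):
    """Divide each score by the total of all scores."""
    total = sum(scores)
    return [s / total for s in scores]

def softmax(scores):
    """Divide each exponentiated score by the total of the exponentials."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [10, 5, 1]
print([round(a, 4) for a in sum_normalise(scores)])  # [0.625, 0.3125, 0.0625]
print([round(a, 4) for a in softmax(scores)])        # [0.9932, 0.0067, 0.0001]

# Hypothetical scores with a negative entry.
mixed = [2.0, -1.0, 1.0]
print(sum_normalise(mixed))                  # [1.0, -0.5, 0.5]
print([round(a, 4) for a in softmax(mixed)])  # all positive, summing to 1
```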
However, it appears that with softmax, we are trading off time complexity for better attention. As it turns out, not applying softmax could effectively bring down the time complexity of the self-attention mechanism from O(N^2) to O(N). I will explore this in future articles.
Generalising calculation of attention scores to all query tokens
In our previous example, we used "the" as the query token. We can now generalise by treating each token in turn as the query token, and generate the list of attention scores for each of these query tokens. The attention scores for every token are below.
Instead of hand calculating the attention scores, we can use python to do the same with just two lines of code:
attn_scores = inputs @ inputs.T
print(attn_scores)
The resulting attention scores are below, which is consistent with what we have in the diagram above.
tensor([[0.9400, 0.3100, 0.5300, 0.5900, 0.7000, 0.9800],
[0.3100, 0.3000, 0.3200, 0.3400, 0.2700, 0.4300],
[0.5300, 0.3200, 0.6100, 0.7300, 0.6500, 0.9100],
[0.5900, 0.3400, 0.7300, 0.8900, 0.7900, 1.0900],
[0.7000, 0.2700, 0.6500, 0.7900, 0.7800, 1.0600],
[0.9800, 0.4300, 0.9100, 1.0900, 1.0600, 1.4600]])
Next, we apply softmax on each row. As we want to normalise the attention scores for each row, we set dim=-1, which refers to the last dimension of the attn_scores matrix (for this two-dimensional tensor, that is the same as dim=1).
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)
And we get the corresponding attention weights:
tensor([[0.2115, 0.1126, 0.1404, 0.1490, 0.1664, 0.2201],
[0.1634, 0.1618, 0.1651, 0.1684, 0.1570, 0.1843],
[0.1491, 0.1209, 0.1616, 0.1822, 0.1682, 0.2181],
[0.1399, 0.1089, 0.1609, 0.1888, 0.1708, 0.2306],
[0.1610, 0.1047, 0.1531, 0.1761, 0.1744, 0.2307],
[0.1581, 0.0912, 0.1474, 0.1765, 0.1713, 0.2555]])
We can verify that each of the rows sum up to 1:
row_sums = torch.sum(attn_weights, dim=-1)
print(row_sums)
The result confirms that each row sums to 1:
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
In the final step, we use each of the attention weights to calculate the context vectors:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)
In the resulting output tensor, each row contains a three-dimensional context vector. [0.5927, 0.3357, 0.5370] is the z_1 shown in one of the earlier diagrams. There are a total of 6 rows, one context vector for each of the 6 tokens.
tensor([[0.5927, 0.3357, 0.5370], # the
[0.5747, 0.3521, 0.4870], # quick
[0.6135, 0.3486, 0.4983], # brown
[0.6276, 0.3487, 0.4995], # fox
[0.6212, 0.3434, 0.5133], # jumps
[0.6360, 0.3432, 0.5222]]) # over
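To show what the matrix multiplication is doing for a single row, the sketch below rebuilds z_1 by hand in plain Python, using the embeddings and the first row of attention weights printed above (rounded to four decimals, so the result agrees to about three decimals).

```python
inputs = [
    [0.3, 0.2, 0.9],  # the
    [0.1, 0.5, 0.2],  # quick
    [0.6, 0.4, 0.3],  # brown
    [0.8, 0.4, 0.3],  # fox
    [0.7, 0.2, 0.5],  # jumps
    [0.9, 0.4, 0.7],  # over
]
# Attention weights for the query "the" (first row of attn_weights above).
weights = [0.2115, 0.1126, 0.1404, 0.1490, 0.1664, 0.2201]

# z_1 is a weighted sum of the input vectors: each embedding is scaled by
# its attention weight, then the scaled vectors are added element-wise.
z_1 = [0.0, 0.0, 0.0]
for weight, vector in zip(weights, inputs):
    for d in range(len(vector)):
        z_1[d] += weight * vector[d]

print([round(v, 3) for v in z_1])  # [0.593, 0.336, 0.537]
```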
Bringing in the Q, K, V
Now that we know how to derive the context vectors from the token embeddings, we introduce the concepts of Q, K and V. Q stands for query, K stands for key, and V stands for value. These terms were inspired by databases. When users send a query to a database, for keys that match, the corresponding values for each of the keys are returned.
Here's the kicker -- when we "train a LLM", we are tuning the values (also known as weights) in these matrices (along with other components such as the token embedding matrix, feed-forward network matrices, layer normalisation parameters, and final output layer matrix), and when we "download a LLM", we are downloading these matrices (along with the other parameter-containing components). You may want to revisit the LLM architecture diagram mentioned above under "How does the attention mechanism fit into the LLM architecture" to understand what you are doing when you "download a LLM". The values in these matrices are known as weights because we use them to weigh the input embeddings when producing the queries, keys and values, from which the attention scores and attention weights are then computed.
Note that these "weights" matrices are not the same as "attention weights". The former are the result of training, whereas the latter are dynamically generated from prompts provided by users.
Each of the Q, K and V is a matrix of dimension d_in by d_out. d_in corresponds to the length of each token embedding; in our example it has a value of 3. d_out is usually of the same value as d_in, but for the purpose of learning, we set d_out as 4, to disambiguate between d_in and d_out.
x = inputs[0]
d_in = inputs.shape[1]
d_out = 4
We now initialise the QKV matrices using PyTorch:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
We set requires_grad=False for now as we are not training the LLM.
With x_1 representing the token "the", we now obtain the query, key and value vectors related to "the" by multiplying the embedding of "the" (dimension of 1 X 3) with each of the Q, K, V matrices (each of dimension 3 X 4):
query = x @ W_query
key = x @ W_key
value = x @ W_value
print("query:", query)
print("key:", key)
print("value:", value)
As an example, the query vector has the following value, which is of the correct dimension of 1 X 4.
query: tensor([0.2693, 0.9821, 0.3865, 0.8455])
key: tensor([0.8146, 0.7583, 0.7444, 1.1370])
value: tensor([0.9083, 0.8406, 1.0927, 0.7955])
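To see why a 1 X 3 embedding multiplied by a 3 X 4 matrix gives a 1 X 4 vector, the sketch below does the multiplication by hand in plain Python. The weight values and the helper name project are made up for readability; they are not the random values produced by torch.rand above.

```python
x = [0.3, 0.2, 0.9]  # embedding of "the", length d_in = 3

# Hypothetical 3 x 4 weight matrix (d_in rows, d_out columns).
W = [
    [1.0, 0.0, 0.5, 0.2],
    [0.0, 1.0, 0.5, 0.4],
    [0.0, 0.0, 1.0, 0.6],
]

def project(vector, matrix):
    """Return vector @ matrix: one output element per matrix column."""
    d_out = len(matrix[0])
    return [
        sum(vector[i] * matrix[i][j] for i in range(len(vector)))
        for j in range(d_out)
    ]

projected = project(x, W)
print(len(projected))                    # 4
print([round(v, 2) for v in projected])  # [0.3, 0.2, 1.15, 0.68]
```

Each output element mixes all three input numbers using one column of the matrix, which is why the output length follows d_out rather than d_in.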
Next, we calculate the keys and values by multiplying the tokens embedding of "the quick brown fox jumps over" with the K and V matrices, as below:
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
The shapes of keys and values are as below. The first dimension is 6 because there are a total of 6 tokens, while the second dimension of 4 is consistent with the value of d_out.
keys.shape: torch.Size([6, 4])
values.shape: torch.Size([6, 4])
Now, we can obtain the attention score of the query vector with each key vector using the following dot product:
attn_scores = query @ keys.T
print(attn_scores)
And we obtain the following attention scores with respect to the query "the":
tensor([2.2131, 1.2207, 1.5923, 1.7274, 1.7276, 2.5506])
Why do we scale attention scores with square root of d_k
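Recalling the softmax attention equation mentioned at the very beginning of this article, we had this part of the equation, where we divide the product of queries and keys by the square root of d_k, the dimension of the key vectors (this is d_out in our code, and equals d_in only when the two are set to the same value):
Q K^T / sqrt(d_k)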
The purpose of this "scaled-dot product attention" is to improve training performance by avoiding small gradients. Going back to the example of the w values, where w_1 has the largest value, because of the "winner-takes-most" effect, w_1 dominated the share of attention weights with a_1 = 99.32%. If w_1 is the wrong choice, the gradient (i.e. signal to "fix the mistake") will be nearly zero, and the model will not learn.
Instead, we now scale the values by dividing by the square root of d_k. Assume the value of d_k was 25, so the square root is 5:
w_1: 10/5 = 2
w_2: 5/5 = 1
w_3: 1/5 = 0.2
Applying softmax, the share of w_1 is now e^2 / (e^2 + e^1 + e^0.2) = 0.652 (or 65.2%). As a result, most of the attention is still focused on w_1, but now more attention can also be paid to w_2 and w_3 to account for the possibility that w_1 could be wrong.
For the sake of completeness, if we had simply divided by the dimension itself (25) instead of its square root, we would get the following:
w_1: 10/25 = 0.4
w_2: 5/25 = 0.2
w_3: 1/25 = 0.04
Applying softmax, the share of w_1 decreases even further to 0.397 (39.7%). The share of w_2 increases from 24.0% to 32.5%, which splits the attention between w_1 and w_2 and reduces the attention on w_1 to such a degree that not enough attention is placed on w_1 (when it should be, given that it has the largest value).
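The three cases (no scaling, dividing by the square root of the dimension, and dividing by the dimension itself) can be checked with a short plain-Python sketch, assuming a dimension of 25 as in the example.

```python
import math

def softmax(scores):
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [10, 5, 1]
dim = 25  # assumed dimension, as in the example above

unscaled = softmax(scores)
sqrt_scaled = softmax([s / math.sqrt(dim) for s in scores])
fully_scaled = softmax([s / dim for s in scores])

print([round(a, 3) for a in unscaled])      # [0.993, 0.007, 0.0]
print([round(a, 3) for a in sqrt_scaled])   # [0.652, 0.24, 0.108]
print([round(a, 3) for a in fully_scaled])  # [0.397, 0.325, 0.277]
```

The square-root scaling keeps w_1 clearly in front while leaving meaningful weight on the other two, whereas dividing by the full dimension flattens the distribution much more.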
Back to regular programming
Applying the concept of scaling by the square root of the key dimension (d_k in the code below), we can obtain the attention weights with "the" as the query with the following:
d_k = keys.shape[-1]
attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
print(attn_weights)
The resulting attention weights with respect to "the" are:
tensor([0.1963, 0.1195, 0.1439, 0.1540, 0.1540, 0.2324])
Just as we did before introducing QKV, we can now multiply the attention weights with the values (which are the input embeddings weighted by the W_value matrix) to obtain the context vector for "the":
context_vec = attn_weights @ values
print(context_vec)
Output:
tensor([0.8516, 0.7803, 0.9675, 0.9944])
A diagram visualising all the operations shown above is summarised below.
Generalising to calculation of context vectors for all queries
We can now generalise the calculation of the context vectors to all tokens in the input, and by doing so, we are essentially implementing the equation that I originally set out to explain:
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
For those who like something more visual, the diagram below visualises what the formula is actually doing:
We can now bring everything together and obtain the matrix Z of context vectors with the following Python class SelfAttention (source: Build a LLM from scratch, chapter 3.4). The following code initialises the QKV matrices with nn.Linear using random values. An alternative to using nn.Linear is nn.Parameter(torch.rand(...)), but nn.Linear is a better choice because it has an optimised weight initialisation scheme, contributing to more stable and effective model training. For more information on the advantage of nn.Linear in initialising weights, check out Appendix B.
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Trainable projection matrices for queries, keys and values
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Dot product of every query with every key
        attn_scores = queries @ keys.T
        # Scale by the square root of the key dimension, then normalise each row
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)
sa = SelfAttention(d_in, d_out)
print(sa(inputs))
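Output containing the matrix Z of context vectors is below. Note that some of the values are negative. This is expected because these are context vectors, not attention weights: nn.Linear initialises its weights with both positive and negative values, so the projected values need not be positive.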
tensor([[-0.2117, 0.1381, 0.5026, 0.2667],
[-0.2066, 0.1430, 0.4978, 0.2678],
[-0.2092, 0.1417, 0.5006, 0.2677],
[-0.2098, 0.1417, 0.5015, 0.2679],
[-0.2113, 0.1390, 0.5024, 0.2670],
[-0.2119, 0.1408, 0.5038, 0.2679]], grad_fn=<MmBackward0>)
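The class above leans on PyTorch for the matrix work. As a cross-check of the formula, the sketch below implements the same steps in plain Python with lists. All weight values, and the helper names matmul, transpose and self_attention, are hypothetical and chosen only to keep the example small.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]

def transpose(m):
    return [list(row) for row in zip(*m)]

def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(x, w_query, w_key, w_value):
    """Compute softmax(Q K^T / sqrt(d_k)) V for one input sequence x."""
    queries = matmul(x, w_query)
    keys = matmul(x, w_key)
    values = matmul(x, w_value)
    d_k = len(keys[0])  # dimension of each key vector
    scores = matmul(queries, transpose(keys))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, values), weights

# Two tokens, d_in = 3, d_out = 2; all numbers are hypothetical.
x = [[0.3, 0.2, 0.9], [0.1, 0.5, 0.2]]
w_q = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
w_k = [[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]
w_v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

context, weights = self_attention(x, w_q, w_k, w_v)
print(len(context), len(context[0]))            # 2 2
print([round(sum(row), 6) for row in weights])  # [1.0, 1.0]
```

There is one context vector per token, each of length d_out, and every row of attention weights sums to 1.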
What have we achieved and what is next
I have unpacked all the details going on in the softmax self-attention mechanism, explaining:
- Why we use attention
- What the "self" in "self-attention" means
- Why we use softmax normalisation
- How to bring together the Q, K, V matrices and the input matrix to calculate the attention weights, and finally the context vectors
- Why we scale the dot product matrix by the square root of the key dimension d_k
Up next, I am keen to explore the following topics:
- Low-rank adaptation (popularly known as LoRA)
- Explore the recent innovations in efficient kernel-based attentions, starting with linear attentions, studying the trade-off between quality of response and performance
- Google's paper on Titans: learning to memorise at test time
If you do happen to have any cool topics in mind and would like me to explore as well, please let me know! Who knows, maybe we can even explore these topics together.
Appendix A - Learning with LLM
This is more of a meta-point about how LLMs helped to speed up my learning. Whenever I came across a concept or mathematical proof that was skipped over, I found myself reaching out to Google Gemini to explain the missing details. Compared to searching the book's appendix or other literature for explanations, the immediate answers provided by the LLM helped me to maintain my train of thought and immediately move on to the next paragraph. I also saved significant time that I would otherwise have spent flipping through pages without any guarantee of finding the answers I was looking for.
In addition, LLMs seem to be able to synthesise and explain concepts more simply and clearly than what I could otherwise get from sources such as StackOverflow. For example, say I am interested to know the difference between max and argmax, and here is an answer which I found on StackExchange:
argmax returns a value from the domain of the function and max returns from the range of the function?
source
If I still do not understand (or I have missed a simpler explanation somewhere in the webpage), I could instead ask LLM to explain like I'm 5:
what is the difference between max and argmax? ELI5
And below is the reply I received from Google Gemini Pro 2.5, which subjectively is a much easier read for me.
Okay, imagine you have a group of friends and you want to find out two things about their heights:
Max (Maximum):
- This is like asking: "What is the tallest height anyone in this group is?"
- You'd measure everyone, look at all the height numbers (e.g., 4 feet, 5 feet, 4.5 feet), and pick the biggest number.
- Answer: "5 feet!" (This is the value of the height).
Argmax (Argument of the Maximum):
- This is like asking: "WHO is the tallest person?"
- You'd still find the tallest height (5 feet), but then you'd point to the person who is that tall.
- Answer: "It's Sarah!" (This is the person or item that has the maximum value).
So, in short:
- Max gives you the highest value itself.
- Argmax tells you which item (or which position in a list) has that highest value.
Example with numbers:
Let's say you have a list of scores from a game:
[10, 50, 20]
(Player A scored 10, Player B scored 50, Player C scored 20)
- Max: The highest score is 50.
- Argmax: The player who got the highest score is Player B (or if we use positions like in computers (0, 1, 2), it's position 1).
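In Python, the same distinction looks like this. The standard library has no built-in argmax, so this sketch builds one from max with a key function; torch.argmax and numpy.argmax provide it directly.

```python
scores = [10, 50, 20]  # Player A, Player B, Player C

# max: the highest value itself
highest = max(scores)
# argmax: the index at which the highest value occurs
position = max(range(len(scores)), key=lambda i: scores[i])

print(highest)   # 50
print(position)  # 1
```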
I have also been using Google Gemini as my personal math tutor, and it has been amazing in helping me to fill in the gaps in mathematical proofs which were skipped over because the authors deemed them trivial. LLMs trained with reinforcement learning do appear to be well suited to tackling mathematical proofs where the answer can be validated. However, I think some care needs to be taken to ensure that LLMs do not replace my habit of thinking; at the very least, I would attempt the proof myself after having seen the proof the LLM came up with.
Appendix B - Why is nn.Linear better at initialising weights that lead to more stable model training
(note: the following explanation is contributed by Google Gemini using AI Studio)
In model training, the start state of the QKV matrices has significant impact on the training effectiveness.
Say we want the output of a freshly initialised layer to roughly keep the characteristics of the input data. For example, if the input data has a standard deviation of roughly 1.0, we would like the data coming out of the layer to also have a standard deviation of about 1.0.
We can create a random set of input values with a standard deviation of about 1.0 using torch.randn, which works:
import torch
import torch.nn as nn
import math
# Input data: 1000 examples, 10 features each
d_in = 10
x = torch.randn(1000, d_in)
print(f"Input mean: {x.mean():.4f}, Input std dev: {x.std():.4f}")
# Expected output: Input mean: ~0.0000, Input std dev: ~1.0000
However, if we now use torch.rand to initialise the weights for a matrix multiplication (which is what we require in the attention mechanism), we run into trouble:
# Create a weight matrix with the "guessing" method
d_out = 20
W_rand = torch.rand(d_in, d_out)
# Perform the matrix multiplication
output_rand = x @ W_rand
print(f"Using torch.rand:")
print(f" - Weight matrix std dev: {W_rand.std():.4f}")
print(f" - Output mean: {output_rand.mean():.4f}")
print(f" - Output std dev: {output_rand.std():.4f} <-- Problematic!")
Output is below.
Using torch.rand:
- Weight matrix std dev: 0.2883
- Output mean: -0.0508
- Output std dev: 1.7807 <-- Problematic!
We can see that as the matrices scale in size, the "drift" in the standard deviation becomes even more pronounced. This is a problem as modern day LLMs have very large matrices with billions of parameters:
d_in_large = 512 # A more realistic input size for a transformer model
x_large = torch.randn(1000, d_in_large)
W_rand_large = torch.rand(d_in_large, d_out)
output_rand_large = x_large @ W_rand_large
print(f"\nUsing torch.rand with a larger input size ({d_in_large}):")
print(f" - Input std dev: {x_large.std():.4f}")
print(f" - Output std dev: {output_rand_large.std():.4f} <-- Exploding!")
Output:
Using torch.rand with a larger input size (512):
- Input std dev: 1.0011
- Output std dev: 13.3220 <-- Exploding!
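A back-of-envelope estimate (my own addition, not part of the Gemini explanation) shows where these numbers come from. Each output element is a sum of d_in products of an input value and a weight. Assuming inputs drawn from a standard normal distribution and weights drawn independently from a uniform distribution on [0, 1), each product has variance 1/3, so the output standard deviation grows like the square root of d_in / 3.

```python
import math

def expected_output_std(d_in):
    """Rough output std dev for x @ W, x ~ N(0, 1), W ~ Uniform(0, 1).

    Each output element is a sum of d_in terms x_i * w_i. With x_i and w_i
    independent and E[x_i] = 0, each term has variance
    E[x_i^2] * E[w_i^2] = 1 * 1/3, so the sum has variance d_in / 3.
    """
    return math.sqrt(d_in / 3)

print(round(expected_output_std(10), 2))   # 1.83
print(round(expected_output_std(512), 2))  # 13.06
```

These estimates are close to the 1.78 and 13.32 measured above; the remaining gap comes from measuring with one particular random weight matrix.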
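To address the problem of the "exploding" standard deviation, nn.Linear uses a smart initialisation from the family known as "Kaiming He initialisation" (paper available here). Essentially, it scales the initial random weights based on the number of inputs. By default nn.Linear applies a uniform variant of this scheme; the code below re-initialises the weights with nn.init.kaiming_normal_ (with a=1) so that the weight standard deviation targets 1 / sqrt(d_in), i.e.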
# Create a weight matrix with the "smart" method
linear_layer = nn.Linear(d_in, d_out, bias=False)
# Re-initialize the weights in-place using Kaiming Normal.
# This changes the weight distribution to what we need.
nn.init.kaiming_normal_(linear_layer.weight, a=1)
W_linear = linear_layer.weight.T # Get the weights in the same (d_in, d_out) shape
# Perform the matrix multiplication
output_linear = x @ W_linear
# Let's check the math: 1 / sqrt(10) is about 0.316
print(f"\nUsing nn.Linear (d_in={d_in}):")
print(f" - Target std dev for weights: {1 / math.sqrt(d_in):.4f}")
print(f" - Actual weight matrix std dev: {W_linear.std():.4f}")
print(f" - Output mean: {output_linear.mean():.4f}")
print(f" - Output std dev: {output_linear.std():.4f} <-- Just right!")
Output:
Using nn.Linear (d_in=10):
- Target std dev for weights: 0.3162
- Actual weight matrix std dev: 0.3058
- Output mean: 0.0010
- Output std dev: 0.9805 <-- Just right!
With larger input sizes:
linear_layer_large = nn.Linear(d_in_large, d_out, bias=False)
# Re-initialize the weights in-place using Kaiming Normal.
# This changes the weight distribution to what we need.
nn.init.kaiming_normal_(linear_layer_large.weight, a=1)
W_linear_large = linear_layer_large.weight.T
output_linear_large = x_large @ W_linear_large
print(f"\nUsing nn.Linear with a larger input size ({d_in_large}):")
print(f" - Target std dev for weights: {1 / math.sqrt(d_in_large):.4f}")
print(f" - Actual weight matrix std dev: {W_linear_large.std():.4f}")
print(f" - Output std dev: {output_linear_large.std():.4f} <-- Still just right!")
output:
Using nn.Linear with a larger input size (512):
- Target std dev for weights: 0.0442
- Actual weight matrix std dev: 0.0438
- Output std dev: 0.9932 <-- Still just right!
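The reason a weight standard deviation of 1 / sqrt(d_in) keeps the output standard deviation near 1 is that each output sums d_in terms, each with variance of about 1 / d_in. The sketch below checks this in plain Python using random.gauss rather than PyTorch; the sizes and the seed are arbitrary choices of mine.

```python
import math
import random
import statistics

random.seed(0)
d_in, d_out, n_samples = 100, 10, 200

# Weights drawn with mean 0 and std dev 1 / sqrt(d_in), the target used above.
W = [[random.gauss(0, 1 / math.sqrt(d_in)) for _ in range(d_out)]
     for _ in range(d_in)]
# Inputs with mean 0 and std dev 1.
X = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(n_samples)]

# Flattened result of X @ W: one value per (sample, output column) pair.
outputs = [
    sum(x_row[i] * W[i][j] for i in range(d_in))
    for x_row in X
    for j in range(d_out)
]
output_std = statistics.pstdev(outputs)
print(f"Output std dev: {output_std:.4f}")  # expected to be close to 1
```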