Creating the self-attention mechanism from scratch
Lewis Won


What will you gain by reading this

By the end of this article, you will understand the why and how of the softmax self-attention mechanism, i.e.:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

This is the beginning of a series of articles on LLMs that I plan to write. I chose to begin with the self-attention mechanism because it "serves as the cornerstone of every LLM based on the transformer architecture" (Sebastian Raschka, Build a Large Language Model (From Scratch), section 3.3).

A large part of the content of this article is derived from Sebastian Raschka's Build a Large Language Model (From Scratch). If you do not have access to his book, you can also watch his video here, or read his article here.

There is nothing new in this article that is not already covered by Sebastian's book, so if you prefer to read from the OG, please read his book directly (highly recommended). My value add is that I have tried to explain the concepts in a way that makes intuitive sense to me, and I have expanded on details that Sebastian gave a lighter treatment.

Motivation

With the rapid development of LLMs in recent months, I thought a quick, fun and meaningful way to stay engaged and relevant would be to learn how to build the key components of an LLM from scratch. I believe the best way to retain knowledge is to be hands-on and document what I have learnt.

I began my journey of deepening my understanding of LLMs by watching Andrej Karpathy's Deep Dive into LLMs like ChatGPT and reading Sebastian Raschka's Build a Large Language Model (From Scratch). In the process, I realised that LLMs are an excellent learning aid. Whenever I came across a term I did not fully understand while reading Sebastian's book, I would ask an LLM to explain it, and I could quickly understand (or at least I thought I did) not just the concept, but also how it fits into the broader literature and the problems it was meant to solve.

Documenting my AI-assisted learning process is also an attempt to validate my hypothesis that with LLMs, any concept can be learnt quickly even without much formal background in the topic. Seems foolhardy? I thought so too, but there is no harm in trying! This means that when I read the original papers and implement the code, I also seek the help of LLMs to understand the concepts and even the code. However, for the purpose of learning, I personally go through every word and line of code written here, and only put down in writing what I understood.

If you see any errors, please let me know. I will fix these errors and also mention what was fixed, which I think is a way to document understanding and growth. If the topic of learning to learn with LLMs interests you, I have included a segment about this in Appendix A.

How does the attention mechanism fit into the LLM architecture

While the attention mechanism is a key component powering LLMs, it is only one of the many components that make up an LLM. Below is a diagram of the key components of an LLM. If the diagram looks too complicated, just know that the attention mechanism is a key component of an LLM, and understanding how the attention mechanism works means understanding a large part of the inner workings of an LLM. In fact, the original creators of the self-attention mechanism boldly claimed that attention is all you need. I will go through each of the other components in future articles.

Architecture diagram of a transformer-based LLM

Why attention?

Before the rise of LLMs, recurrent neural networks (RNNs) were the workhorses used for translation. RNNs were used to overcome the limitations of word-by-word translation, so as to take into account differences in grammatical structure between the original and translated languages. Without going too deep into how RNNs work, the attention mechanism helped to overcome the key limitation of RNNs, namely that RNNs cannot "directly access earlier hidden states from the encoder during the decoding phase", which could lead to a loss of context, especially problematic for complex sentences with references cutting across the sentence (source: Build a Large Language Model (From Scratch), chapter 3.1).

Self-attention overcame the limitations of RNNs by allowing "each position in the input sequence to consider the relevancy of, or attend to, all other positions in the same sequence when computing the representation of a sequence" (source: Building LLM from Scratch, chapter 3.2).

Reproduced below is the standard Softmax Attention. Scary looking? I know, I felt the same when I first saw it. I will unpack every bit of the equation slowly, explaining both the "why" and the "how".

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Why the "self" in "self-attention"?

I think Sebastian explained it best in his book, so I shall quote him directly:

In self-attention, the “self” refers to the mechanism’s ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image.

This is in contrast to traditional attention mechanisms, where the focus is on the relationships between elements of two different sequences, such as in sequence-to-sequence models where the attention might be between an input sequence and an output sequence...

To give a more concrete example, say we have a task to translate from English to French. The language model has already translated part of the input sentence to French, and now it needs to predict the next word.

Input sentence: "The animal didn't cross the street because it was too tired."
Output sentence: "L'animal n'a pas traversé la rue car il était trop ______?"

An encoder-decoder attention mechanism has an encoder that processes the English input sentence, producing a key and value vector for each token in the input sentence. Think of keys as labels for each token in the input sentence, and values as the content behind each key. The decoder state is currently "L'animal...trop", which forms the query. The query from the decoder is then matched against each key in the encoder, and the values of the best-matching keys guide the next French word to be appended to the output.

With a self-attention mechanism, by contrast, there is only a single input, i.e.

Translate from English to French: "The animal didn't cross the street because it was too tired" French: "L'animal n'a pas traversé la rue car il était trop".

The LLM has one and only one job, which is to predict the next token after "trop".

Below is a table summarising the differences between attention and self-attention.

Feature | Modern Self-Attention Model (Decoder-Only) | Older Attention Model (Encoder-Decoder)
Model Structure | A single, unified stack of Transformer blocks. | Two separate stacks: an Encoder and a Decoder.
Input Format | One continuous sequence containing instructions, source, and target. | Two separate sequences fed into the Encoder and Decoder.
Attention Scope | At any step, a token can attend to any previous token in the combined sequence (both source and target). | The decoder attends to its own previous tokens AND separately attends to the encoder's final output.
How it "Links" | By learning patterns within the single combined sequence. | Through a dedicated "cross-attention" mechanism.

How does a self-attention mechanism work?

For any text, an LLM understands it as a list of tokens, and each token is represented as a vector. A token, unlike a word, is a sequence of characters grouped together as a useful semantic unit for processing, so the word don't may be represented as the tokens do and n't. Sentences are broken into tokens by tokenisers, and different LLMs may use different tokenisers. A vector is a list of numbers.
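
If you would like to see tokenisation in action, here is a small optional illustration, assuming the tiktoken package is installed; the exact splits depend on the tokeniser used:

import tiktoken

enc = tiktoken.get_encoding("gpt2")                # the GPT-2 BPE tokeniser
ids = enc.encode("The quick brown fox jumps over the lazy dog, don't you know?")
print(ids)                                         # a list of integer token ids
print([enc.decode([i]) for i in ids])              # the text piece each token id maps to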

For example, given the text "the quick brown fox jumps over" and assuming each token is embedded as a vector of length 3, the text can be represented as tokens, and each token as an embedding. Taking token x(1) (i.e. the) as an example, our aim is to calculate attention weights that weigh the importance of x(1) with respect to all the other tokens. For example, a_12 is the attention weight of x(1) with respect to x(2). We can then compute the context vector z(1) for the token x(1) as a combination of all input vectors, weighted with respect to input element x(1). The method to compute this combination will become clearer in the next few sections.

From token embeddings to context vector

We now translate the description above into code. Below is the Python code representing what is captured in the diagram:

import torch
inputs = torch.tensor(
  [[0.3, 0.2, 0.9], # the     (x^1)
   [0.1, 0.5, 0.2], # quick   (x^2)
   [0.6, 0.4, 0.3], # brown   (x^3)
   [0.8, 0.4, 0.3], # fox     (x^4)
   [0.7, 0.2, 0.5], # jumps   (x^5)
   [0.9, 0.4, 0.7]] # over    (x^6)
)

Assuming the is the query token, we take the dot product of the with each of the other tokens (including the itself). A worked example of the dot product is shown in the first blue box. Because we are multiplying a 1×3 matrix with a 3×1 matrix for each pair of token embeddings, each dot product yields a scalar.

Dot product to get attention score

query = inputs[0]                            # the first token, "the", acts as the query
attn_scores = torch.empty(inputs.shape[0])   # one attention score per input token
for i, x_i in enumerate(inputs):
    attn_scores[i] = torch.dot(x_i, query)   # dot product of each token with the query
print(attn_scores)

Applying the dot product gives the following attention scores, each representing the relationship of a token with the query token the:

tensor([0.9400, 0.3100, 0.5300, 0.5900, 0.7000, 0.9800])

Source: Building LLM from Scratch, chapter 3.3

We can then normalise the attention scores to obtain attention weights that sum to 1, so that each attention weight represents a probability. This is useful for ease of interpretation and for "maintaining training stability of LLM" (source: Building LLM from Scratch, chapter 3.3.1). The softmax normalisation is commonly used.

attn_weights = torch.softmax(attn_scores, dim=-1)
print("Attention weights:", attn_weights)
print("Sum:", attn_weights.sum())

The attention weights after applying softmax normalisation on the attention scores are below.

Attention weights: tensor([0.2115, 0.1126, 0.1404, 0.1490, 0.1664, 0.2201])
Sum: tensor(1.0000)

Why Softmax?

A natural question to ask is why not simply divide each attention score by the total sum of all scores to obtain the attention weight, i.e.

attentionWeight_1 = \frac{attentionScore_1}{\sum_{i=1}^{n} attentionScore_i}

Instead of the more complicated softmax:

attentionWeight_1 = \frac{e^{attentionScore_1}}{\sum_{i=1}^{n} e^{attentionScore_i}}

This is because softmax "is better at managing extreme values and offers more favorable gradient properties during training" (source: Building LLM from Scratch, 3.3.1). Sounds abstract, so let us work with a simple example to understand the implications:

Imagine there are three attention scores:
w_1: 10
w_2: 5
w_3: 1

With the first approach of normalisation, the attention weight of w_1 is: a_1 = 10 / (10 + 5 + 1) = 0.625 (or 62.5%)

With softmax, the attention weight of w_1 is: a_1 = e^(w_1) / ( e^(w_1) + e^(w_2) + e^(w_3) ) = e^10 / (e^10 + e^5 + e^1) ≈ 22026.47 / 22177.60 ≈ 0.993185 (or 99.32%)

We can see that by using softmax for normalisation, the attention share of w_1 shoots up from 62.5% to more than 99%.

Not only is w_1 the largest score, its attention weight a_1 also takes up almost all of the softmax weight. This "winner-takes-most" effect is often what we want in attention mechanisms: to focus strongly on the most important thing. In comparison, the attention weight a_3 drops from 0.0625 (6.25%) with simple normalisation to about 0.000123 (0.0123%) with softmax.

Another advantage of softmax is that even when some of the scores are negative, the normalised values are all positive and sum up nicely to 1, whereas with the simple summation approach a negative score remains negative after normalisation.
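
Here is a minimal numerical check of the two normalisation schemes, reusing the hypothetical scores above:

import torch

scores = torch.tensor([10.0, 5.0, 1.0])

print(scores / scores.sum())            # simple normalisation: tensor([0.6250, 0.3125, 0.0625])
print(torch.softmax(scores, dim=-1))    # softmax: roughly tensor([9.932e-01, 6.693e-03, 1.226e-04])

# With negative scores, simple normalisation produces nonsensical "weights"
# (the largest score even ends up with a negative weight here),
# while softmax always yields positive values that sum to 1.
neg = torch.tensor([2.0, -1.0, -3.0])
print(neg / neg.sum())                  # tensor([-1.0000, 0.5000, 1.5000])
print(torch.softmax(neg, dim=-1))       # all positive, sums to 1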

However, it appears that with softmax, we are trading off time complexity for better attention. As it turns out, not applying softmax could effectively bring down the time complexity of the self-attention mechanism from O(N^2) to O(N). I will explore this in future articles.

Generalising calculation of attention scores to all query tokens

In the previous example, we used the as the query token. We can now generalise to have each token as the query token, and generate a list of attention scores for each of these query tokens. The resulting attention scores for all tokens are below.

Generalising calculation of attention scores to all tokens

Instead of hand-calculating the attention scores, we can use Python to do the same with just two lines of code:

attn_scores = inputs @ inputs.T
print(attn_scores)

The resulting attention scores are below, which is consistent with what we have in the diagram above.

tensor([[0.9400, 0.3100, 0.5300, 0.5900, 0.7000, 0.9800],
        [0.3100, 0.3000, 0.3200, 0.3400, 0.2700, 0.4300],
        [0.5300, 0.3200, 0.6100, 0.7300, 0.6500, 0.9100],
        [0.5900, 0.3400, 0.7300, 0.8900, 0.7900, 1.0900],
        [0.7000, 0.2700, 0.6500, 0.7900, 0.7800, 1.0600],
        [0.9800, 0.4300, 0.9100, 1.0900, 1.0600, 1.4600]])

Next, we apply softmax on each row. As we want to normalise the attention scores within each row, we set dim=-1, which refers to the last dimension of the attn_scores matrix (for a two-dimensional matrix, this is equivalent to dim=1).

attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

And we get the corresponding attention weights:

tensor([[0.2115, 0.1126, 0.1404, 0.1490, 0.1664, 0.2201],
        [0.1634, 0.1618, 0.1651, 0.1684, 0.1570, 0.1843],
        [0.1491, 0.1209, 0.1616, 0.1822, 0.1682, 0.2181],
        [0.1399, 0.1089, 0.1609, 0.1888, 0.1708, 0.2306],
        [0.1610, 0.1047, 0.1531, 0.1761, 0.1744, 0.2307],
        [0.1581, 0.0912, 0.1474, 0.1765, 0.1713, 0.2555]])

We can verify that each of the rows sum up to 1:

row_sums = torch.sum(attn_weights, dim=-1)
print(row_sums)

The result confirms that each row sums to 1:

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In the final step, we use each of the attention weights to calculate the context vectors:

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

In the resulting output tensor, each row contains a three-dimensional context vector. [0.5927, 0.3357, 0.5370] is the z_1 shown in one of the earlier diagrams. There are 6 rows in total, one context vector for each of the 6 tokens.

tensor([[0.5927, 0.3357, 0.5370],   # the
        [0.5747, 0.3521, 0.4870],   # quick
        [0.6135, 0.3486, 0.4983],   # brown
        [0.6276, 0.3487, 0.4995],   # fox
        [0.6212, 0.3434, 0.5133],   # jumps
        [0.6360, 0.3432, 0.5222]])  # over

Bringing in the Q, K, V

Now that we know how to derive the context vectors from the token embeddings, we introduce the concepts of Q, K and V. Q stands for query, K stands for key, and V stands for value. These terms were inspired by databases: when a user sends a query to a database, the values corresponding to the matching keys are returned.

Here's the kicker -- when we "train an LLM", we are tuning the values (also known as weights) in these query, key and value weight matrices (along with other components such as the token embedding matrix, feed-forward network matrices, layer normalisation parameters, and the final output layer matrix), and when we "download an LLM", we are downloading these matrices (along with the other parameter-containing components). You may want to revisit the LLM architecture diagram above, under "How does the attention mechanism fit into the LLM architecture", to understand what you are doing when you "download an LLM". The values in these matrices are known as weights because we weigh the attention scores with these values, and normalise the weighted values to obtain the attention weights.

Note that these "weights" matrices are not the same as "attention weights". The former are the result of training, whereas the latter are dynamically generated from prompts provided by users.

Each of the query, key and value weight matrices is of dimension d_in by d_out. d_in corresponds to the length of each token embedding, which in our example is 3. d_out usually has the same value as d_in, but for the purpose of learning we set d_out to 4, to disambiguate between d_in and d_out.

x = inputs[0]              # embedding of the query token, "the"
d_in = inputs.shape[1]     # input embedding dimension (3)
d_out = 4                  # output dimension of the query, key and value vectors

We now initialise the query, key and value weight matrices using PyTorch:

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

We set requires_grad=False for now as we are not training the LLM.

With x_1 representing the token the, we now obtain the query, key and value vectors for the by multiplying the embedding of the (dimension 1 × 3) with each of the W_query, W_key and W_value matrices (each of dimension 3 × 4):

query = x @ W_query
key = x @ W_key
value = x @ W_value
print("query:", query)
print("key:", key)
print("value:", value)

The resulting vectors are below; each has the expected dimension of 1 × 4.

query: tensor([0.2693, 0.9821, 0.3865, 0.8455])
key: tensor([0.8146, 0.7583, 0.7444, 1.1370])
value: tensor([0.9083, 0.8406, 1.0927, 0.7955])

Next, we calculate the keys and values for all tokens by multiplying the token embeddings of "the quick brown fox jumps over" with the W_key and W_value matrices, as below:

keys = inputs @ W_key 
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

The shapes of keys and values are below. The first dimension is 6 because there are 6 tokens, while the second dimension of 4 is consistent with the value of d_out.

keys.shape: torch.Size([6, 4])
values.shape: torch.Size([6, 4])

Now, we can obtain the attention score of the query vector with each of the key vectors with the following dot products:

attn_scores = query @ keys.T     
print(attn_scores)

And we obtain the following attention scores with respect to the query the:

tensor([2.2131, 1.2207, 1.5923, 1.7274, 1.7276, 2.5506])

Why do we scale attention scores with the square root of d_k

Recalling the softmax attention equation mentioned at the very beginning of this article, we divide the product of the queries and keys by the square root of d_k, the dimension of the key vectors (d_out in our example):

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

The purpose of this "scaled-dot product attention" is to improve training performance by avoiding small gradients. Going back to the example of the w values, where w_1 has the largest value, because of the "winner-takes-most" effect, w_1 dominated the share of attention weights with a_1 = 99.32%. If w_1 is the wrong choice, the gradient (i.e. signal to "fix the mistake") will be nearly zero, and the model will not learn.

Instead, we scale the scores by dividing by the square root of d_k. Suppose d_k were 25, so its square root is 5:

w_1: 10/5 = 2
w_2: 5/5 = 1
w_3: 1/5 = 0.2

Applying softmax, the share of w_1 is now e^2 / (e^2 + e^1 + e^0.2) ≈ 0.652 (or 65.2%). Most of the attention is still focused on w_1, but more attention can now also be paid to w_2 and w_3, to account for the possibility that w_1 could be wrong.

For the sake of completeness, if we had simply divided by d_k instead of its square root, we would get the following:

w_1: 10/25 = 0.4
w_2: 5/25 = 0.2
w_3: 1/25 = 0.04

Applying softmax, the share of w_1 decreases even further to about 0.397 (39.7%), while the share of w_2 increases from 24.0% to about 32.5%. The attention is now split too evenly between w_1 and w_2, so not enough attention is placed on w_1, even though it should receive the most, given that it has the largest score.
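
Here is a quick sketch of the three variants, reusing the hypothetical scores from above and a made-up dimension of d_k = 25:

import torch

scores = torch.tensor([10.0, 5.0, 1.0])
d_k = 25  # hypothetical dimension, chosen so that its square root is 5

print(torch.softmax(scores, dim=-1))             # unscaled: ~[0.993, 0.007, 0.000], winner takes almost all
print(torch.softmax(scores / d_k**0.5, dim=-1))  # divide by sqrt(d_k): ~[0.652, 0.240, 0.108]
print(torch.softmax(scores / d_k, dim=-1))       # divide by d_k: ~[0.397, 0.325, 0.277], too evenly spread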

Back to regular programming

Applying this scaling by the square root of d_k, we can obtain the attention weights with the as the query as follows:

d_k = keys.shape[-1]
attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
print(attn_weights)

The resulting attention weights with respect to the are:

tensor([0.1963, 0.1195, 0.1439, 0.1540, 0.1540, 0.2324])

Just as we did before introducing Q, K and V, we can now multiply the attention weights with the values (the input embeddings projected by the W_value matrix) to obtain the context vector for the:

context_vec = attn_weights @ values
print(context_vec)

Output:

tensor([0.8516, 0.7803, 0.9675, 0.9944])

The diagram below summarises all the operations shown above.

From attention scores to attention weights to softmax

Generalising to calculation of context vectors for all queries

We can now generalise the calculation of the context vectors of all tokens in the input, and by doing so, we are essentially implementing the equation that I originally set out to explain:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

For those who like something more visual, the diagram below visualises what the formula is actually doing:

Visual diagram of the softmax self-attention mechanism

We can now bring everything together and obtain the matrix Z of context vectors with the following Python class SelfAttention (source: Build a Large Language Model (From Scratch), chapter 3.4). The following code initialises the query, key and value weight matrices with nn.Linear using random values. An alternative to nn.Linear is nn.Parameter(torch.rand(...)), but nn.Linear is the better choice because it has an optimised weight initialisation scheme, contributing to more stable and effective model training. For more information on the advantages of nn.Linear for initialising weights, check out Appendix B.

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)
sa = SelfAttention(d_in, d_out)
print(sa(inputs))

The output containing the matrix Z of context vectors is below. Note that some values are negative because these are context vectors, not attention weights (which are always non-negative).

tensor([[-0.2117,  0.1381,  0.5026,  0.2667],
        [-0.2066,  0.1430,  0.4978,  0.2678],
        [-0.2092,  0.1417,  0.5006,  0.2677],
        [-0.2098,  0.1417,  0.5015,  0.2679],
        [-0.2113,  0.1390,  0.5024,  0.2670],
        [-0.2119,  0.1408,  0.5038,  0.2679]], grad_fn=<MmBackward0>)

What have we achieved and what is next

I have unpacked all the details going on in the softmax self-attention mechanism, explaining:

  • Why we use attention
  • What the "self" in "self-attention" means
  • Why we use softmax normalisation
  • How to bring together the Q, K, V matrices and the input matrix to calculate the attention weights, and finally the context vectors
  • Why we scale the dot product by the square root of d_k

Up next, I am keen to explore the following topics:

  • Low-rank adaptation (popularly known as LoRA)
  • Explore the recent innovations in efficient kernel-based attentions, starting with linear attentions, studying the trade-off between quality of response and performance
  • Google's paper on Titans: learning to memorise at test time

If you do happen to have any cool topics in mind and would like me to explore as well, please let me know! Who knows, maybe we can even explore these topics together.

Appendix A - Learning with LLM

This is more of a meta-point about how LLMs helped to speed up my learning. Whenever I came across a concept or mathematical proof that was skipped over, I found myself reaching out to Google Gemini to explain the missing details. Compared to searching the book's appendix or other references for the explanations, the immediate answers provided by the LLM helped me maintain my train of thought and move on to the next paragraph right away. I also saved significant time that would otherwise be spent flipping through pages with no guarantee of finding the answers I was looking for.

In addition, LLMs seem to be able to synthesise and explain concepts much more simply and clearly than sources such as StackOverflow. For example, say I am interested in the difference between max and argmax, and here is an answer I found on StackExchange:

argmax returns a value from the domain of the function and max returns from the range of the function?
source

If I still do not understand (or I have missed a simpler explanation somewhere on the webpage), I can instead ask the LLM to explain like I'm 5:

what is the difference between max and argmax? ELI5

And below is the reply I received from Google Gemini Pro 2.5, which subjectively is a much easier read for me.

Okay, imagine you have a group of friends and you want to find out two things about their heights:

  1. Max (Maximum):

    • This is like asking: "What is the tallest height anyone in this group is?"
    • You'd measure everyone, look at all the height numbers (e.g., 4 feet, 5 feet, 4.5 feet), and pick the biggest number.
    • Answer: "5 feet!" (This is the value of the height).
  2. Argmax (Argument of the Maximum):

    • This is like asking: "WHO is the tallest person?"
    • You'd still find the tallest height (5 feet), but then you'd point to the person who is that tall.
    • Answer: "It's Sarah!" (This is the person or item that has the maximum value).

So, in short:

  • Max gives you the highest value itself.
  • Argmax tells you which item (or which position in a list) has that highest value.

Example with numbers:

Let's say you have a list of scores from a game:
[10, 50, 20]
(Player A scored 10, Player B scored 50, Player C scored 20)

  • Max: The highest score is 50.
  • Argmax: The player who got the highest score is Player B (or if we use positions like in computers (0, 1, 2), it's position 1).
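
The same toy example in PyTorch (my own quick check, not part of Gemini's reply):

import torch

scores = torch.tensor([10, 50, 20])   # Player A, B, C
print(torch.max(scores))              # tensor(50) -> the highest value itself
print(torch.argmax(scores))           # tensor(1)  -> the position (Player B) holding the highest value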

I have also been using Google Gemini as my personal maths tutor, and it has been amazing at helping me fill in the gaps in mathematical proofs that were skipped because the authors deemed them trivial. LLMs trained with reinforcement learning do appear to be well suited to tackling mathematical proofs where the answer can be validated. However, I think some care needs to be taken to ensure that the LLM does not replace my habit of thinking, or at least that I attempt the proof myself after having seen the proof produced by the LLM.

Appendix B - Why is nn.Linear better at initialising weights that lead to more stable model training

(note: the following explanation is contributed by Google Gemini using AI Studio)

In model training, the initial state of the query, key and value weight matrices has a significant impact on training effectiveness.

Say we want our initial weights to roughly preserve the characteristics of the input data. For example, if the input data has a standard deviation of roughly 1.0, we would like the outputs produced with our initialised weights to also have a standard deviation of about 1.0.

We can create random input values with a standard deviation of about 1.0 using torch.randn:

import torch
import torch.nn as nn
import math

# Input data: 1000 examples, 10 features each
d_in = 10
x = torch.randn(1000, d_in)

print(f"Input mean: {x.mean():.4f}, Input std dev: {x.std():.4f}")
# Expected output: Input mean: ~0.0000, Input std dev: ~1.0000

However, if we now create the weight matrix with torch.rand and use it in a matrix multiplication (which is what the attention mechanism requires), we run into trouble:

# Create a weight matrix with the "guessing" method
d_out = 20
W_rand = torch.rand(d_in, d_out)

# Perform the matrix multiplication
output_rand = x @ W_rand

print(f"Using torch.rand:")
print(f"  - Weight matrix std dev: {W_rand.std():.4f}")
print(f"  - Output mean:           {output_rand.mean():.4f}")
print(f"  - Output std dev:        {output_rand.std():.4f}  <-- Problematic!")

Output is below.

Using torch.rand:
  - Weight matrix std dev: 0.2883
  - Output mean:           -0.0508
  - Output std dev:        1.7807  <-- Problematic!

We can see that as the matrices grow in size, the "drift" in the standard deviation becomes even more pronounced. This is a problem, as modern LLMs have very large matrices with billions of parameters:

d_in_large = 512 # A more realistic input size for a transformer model
x_large = torch.randn(1000, d_in_large)
W_rand_large = torch.rand(d_in_large, d_out)
output_rand_large = x_large @ W_rand_large

print(f"\nUsing torch.rand with a larger input size ({d_in_large}):")
print(f"  - Input std dev:      {x_large.std():.4f}")
print(f"  - Output std dev:     {output_rand_large.std():.4f}  <-- Exploding!")

Output:

Using torch.rand with a larger input size (512):
  - Input std dev:      1.0011
  - Output std dev:     13.3220  <-- Exploding!

To address this problem of "exploding" standard deviation, nn.Linear uses a smart initialisation scheme known as Kaiming (He) initialisation (paper available here). Essentially, it scales the initial random weights based on the number of inputs, i.e.

\text{standard deviation} = \frac{1}{\sqrt{\text{number of inputs}}}

# Create a weight matrix with the "smart" method
linear_layer = nn.Linear(d_in, d_out, bias=False)

# Re-initialize the weights in-place using Kaiming Normal.
# This changes the weight distribution to what we need.
nn.init.kaiming_normal_(linear_layer.weight, a=1)

W_linear = linear_layer.weight.T # Get the weights in the same (d_in, d_out) shape

# Perform the matrix multiplication
output_linear = x @ W_linear

# Let's check the math: 1 / sqrt(10) is about 0.316
print(f"\nUsing nn.Linear (d_in={d_in}):")
print(f"  - Target std dev for weights: {1 / math.sqrt(d_in):.4f}")
print(f"  - Actual weight matrix std dev: {W_linear.std():.4f}")
print(f"  - Output mean:                  {output_linear.mean():.4f}")
print(f"  - Output std dev:               {output_linear.std():.4f}  <-- Just right!")

Output:

Using nn.Linear (d_in=10):
  - Target std dev for weights: 0.3162
  - Actual weight matrix std dev: 0.3058
  - Output mean:                  0.0010
  - Output std dev:               0.9805  <-- Just right!

With larger input sizes:

linear_layer_large = nn.Linear(d_in_large, d_out, bias=False)

# Re-initialize the weights in-place using Kaiming Normal.
# This changes the weight distribution to what we need.
nn.init.kaiming_normal_(linear_layer_large.weight, a=1)

W_linear_large = linear_layer_large.weight.T

output_linear_large = x_large @ W_linear_large

print(f"\nUsing nn.Linear with a larger input size ({d_in_large}):")
print(f"  - Target std dev for weights: {1 / math.sqrt(d_in_large):.4f}")
print(f"  - Actual weight matrix std dev: {W_linear_large.std():.4f}")
print(f"  - Output std dev:               {output_linear_large.std():.4f}  <-- Still just right!")

output:

Using nn.Linear with a larger input size (512):
  - Target std dev for weights: 0.0442
  - Actual weight matrix std dev: 0.0438
  - Output std dev:               0.9932  <-- Still just right!
