Understanding the attention mechanism

transformer
LLM
All about matrix multiplications.
Published

June 19, 2023

This post explains how the attention mechanism in modern language models works, and how it is accomplished through matrix multiplication.

Token interaction as matrix multiplication

In language modeling, a chunk of text can be tokenised and embedded into a data matrix. In the following (2, 3) data matrix, each row represents the embedding of one token:

\[ \mathbf{X} = \begin{pmatrix} x_{11} & x_{12} & x_{13} \\ x_{21} & x_{22} & x_{23} \end{pmatrix} \]
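To make “tokenised and embedded” concrete, here is a minimal sketch of how such a (2, 3) matrix could be produced. The toy vocabulary, the `embedding_table` name and the random values are hypothetical: real models use a learned tokenizer and learned embedding values. The point is only that each token id selects one row of a table, and the selected rows are stacked into X.

```python
import numpy as np

# Hypothetical toy vocabulary; a real tokenizer would split text into subword units.
vocab = {"the": 0, "art": 1, "of": 2, "writing": 3}
E = 3  # embedding dimension, matching the (2, 3) example above

# Embedding table: one row of E numbers per vocabulary entry.
# Random here; in a trained model these values are learned.
rng = np.random.default_rng(0)
embedding_table = rng.standard_normal((len(vocab), E))

tokens = ["the", "art"]                  # a two-token chunk of text
token_ids = [vocab[t] for t in tokens]   # tokenisation: words -> integer ids
X = embedding_table[token_ids]           # embedding: ids -> rows of the table

print(X.shape)  # (2, 3)
```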

However, to model natural language it is not enough to embed each token independently; tokens are also semantically related to each other. In the sentence “the art of writing is the art of discovering what you believe”, for example, the meaning lies not only in the individual words but also in how they are arranged together. With the tokens embedded as rows of the data matrix above, such interactions can easily be realized by left- and right-multiplying the data matrix by interaction matrices.

For example, left-multiplying it by a (2, 2) interaction matrix \[ \mathbf{L} = \begin{pmatrix} l_{11} & l_{12} \\ l_{21} & l_{22} \end{pmatrix} \]

will give us

\[ \mathbf{L} \mathbf{X} = \begin{pmatrix} l_{11} & l_{12} \\ l_{21} & l_{22} \end{pmatrix} \begin{pmatrix} x_{11} & x_{12} & x_{13} \\ x_{21} & x_{22} & x_{23} \end{pmatrix} = \begin{pmatrix} l_{11}x_{11} + l_{12}x_{21} & l_{11}x_{12} + l_{12}x_{22} & l_{11}x_{13} + l_{12}x_{23} \\ l_{21}x_{11} + l_{22}x_{21} & l_{21}x_{12} + l_{22}x_{22} & l_{21}x_{13} + l_{22}x_{23} \end{pmatrix}, \]

which is again a (2, 3) matrix, the same shape as the original data matrix. Each entry of the new matrix is a weighted sum of entries of the original matrix taken from different rows but from the same column. In this way we achieve the goal of making the different tokens interact with each other. Note, however, that there is no interaction between the different columns of the original matrix.
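A quick numerical check of both observations, as a sketch with arbitrary random matrices (the seed and the perturbation size are arbitrary choices): each entry of `L @ X` follows the row-mixing formula above, and perturbing one column of X leaves the other columns of the result untouched.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3))   # (tokens, embedding dims)
L = rng.standard_normal((2, 2))   # left interaction matrix

Y = L @ X

# Entry (i, j) of L @ X is sum_k L[i, k] * X[k, j]: it mixes rows (tokens),
# but only within column j.
i, j = 1, 2
assert np.isclose(Y[i, j], sum(L[i, k] * X[k, j] for k in range(2)))

# Changing column 0 of X leaves the other columns of L @ X unchanged.
X2 = X.copy()
X2[:, 0] += 10.0
Y2 = L @ X2
print(np.allclose(Y[:, 1:], Y2[:, 1:]))  # True
```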

Similarly, right-multiplying it by a (3, 2) interaction matrix

\[ \mathbf{R} = \begin{pmatrix} r_{11} & r_{12} \\ r_{21} & r_{22} \\ r_{31} & r_{32} \end{pmatrix} \]

will give us

\[ \mathbf{X} \mathbf{R} = \begin{pmatrix} x_{11} & x_{12} & x_{13} \\ x_{21} & x_{22} & x_{23} \end{pmatrix} \begin{pmatrix} r_{11} & r_{12} \\ r_{21} & r_{22} \\ r_{31} & r_{32} \end{pmatrix} = \begin{pmatrix} x_{11}r_{11} + x_{12}r_{21} + x_{13}r_{31} & x_{11}r_{12} + x_{12}r_{22} + x_{13}r_{32} \\ x_{21}r_{11} + x_{22}r_{21} + x_{23}r_{31} & x_{21}r_{12} + x_{22}r_{22} + x_{23}r_{32} \end{pmatrix}, \]

which is a (2, 2) matrix. It has the same number of rows as the original data matrix, but its number of columns is determined by the interaction matrix. Each entry of the new matrix is a weighted sum of entries of the original matrix taken from different columns but from the same row, and there is no interaction between the different rows. In this way we update the embedding of each token without making the tokens interact with each other.
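The mirror-image check for right multiplication, again a sketch with arbitrary random matrices: each entry of `X @ R` mixes columns within a single row, so changing the embedding of token 0 cannot affect the output for token 1.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((2, 3))   # (tokens, embedding dims)
R = rng.standard_normal((3, 2))   # right interaction matrix

Y = X @ R   # shape (2, 2)

# Entry (i, j) of X @ R is sum_k X[i, k] * R[k, j]: it mixes columns
# (embedding dims), but only within row i, i.e. within one token.
i, j = 0, 1
assert np.isclose(Y[i, j], sum(X[i, k] * R[k, j] for k in range(3)))

# Changing token 0 (row 0) leaves the output for token 1 unchanged.
X2 = X.copy()
X2[0] += 10.0
Y2 = X2 @ R
print(Y.shape, np.allclose(Y[1], Y2[1]))  # (2, 2) True
```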

Combining left and right multiplication, we can both make the tokens interact with each other and update the embedding of each token. In transformer language models, it is the repeated application of this operation that allows the model to “understand” the context of a sentence and build up a useful representation of it. It is the backbone of large language models.

The final interaction mechanism will look like this: \[ \mathbf{Y} = \mathbf{L} \mathbf{X} \mathbf{R}. \]
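Because matrix multiplication is associative, it does not matter whether the tokens are mixed first (by L) or the features first (by R). A small sketch with arbitrary random matrices, where `H` is an illustrative output width:

```python
import numpy as np

rng = np.random.default_rng(2)
T, E, H = 2, 3, 2
X = rng.standard_normal((T, E))
L = rng.standard_normal((T, T))   # token-mixing (left) interaction
R = rng.standard_normal((E, H))   # feature-mixing (right) interaction

# (L X) R and L (X R) are the same matrix, of shape (T, H).
Y1 = (L @ X) @ R
Y2 = L @ (X @ R)
print(Y1.shape, np.allclose(Y1, Y2))  # (2, 2) True
```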

Interaction matrices in practice

To see this in action, let’s first define a data matrix.

import numpy as np
np.random.seed(0)

T, E = 4, 5  # sequence length, embedding dimension
x = np.random.randn(T, E)  # input tensor
print(x)
[[ 1.76405235  0.40015721  0.97873798  2.2408932   1.86755799]
 [-0.97727788  0.95008842 -0.15135721 -0.10321885  0.4105985 ]
 [ 0.14404357  1.45427351  0.76103773  0.12167502  0.44386323]
 [ 0.33367433  1.49407907 -0.20515826  0.3130677  -0.85409574]]

Here each row is the embedding of one token as a 5-dimensional vector (E=5), and 4 tokens (T=4) are stacked from top to bottom, giving a (4, 5) data matrix.

Next, let’s start with left-multiplication interactions. The simplest interaction matrix of all is the identity matrix (ones on the diagonal, zeros elsewhere), under which each token interacts only with itself, i.e. there is no interaction at all.

L = np.eye(T)
print(L, L @ x, sep="\n")
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
[[ 1.76405235  0.40015721  0.97873798  2.2408932   1.86755799]
 [-0.97727788  0.95008842 -0.15135721 -0.10321885  0.4105985 ]
 [ 0.14404357  1.45427351  0.76103773  0.12167502  0.44386323]
 [ 0.33367433  1.49407907 -0.20515826  0.3130677  -0.85409574]]

As expected, the embedding matrix is the same as the original x. Next, we can make every token affect every other token equally by setting the interaction matrix to all ones.

L = np.ones((T, T)) 
print(L, L @ x, sep="\n")
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
[[1.26449236 4.29859821 1.38326024 2.57241707 1.86792399]
 [1.26449236 4.29859821 1.38326024 2.57241707 1.86792399]
 [1.26449236 4.29859821 1.38326024 2.57241707 1.86792399]
 [1.26449236 4.29859821 1.38326024 2.57241707 1.86792399]]

Now all the rows of the embedding matrix are identical, since each new row is the sum of all the rows of the original matrix. A slightly more interesting interaction matrix is the lower triangular matrix:

L = np.tril(np.ones((T, T))) 
print(L, L @ x, sep="\n")
[[1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [1. 1. 1. 0.]
 [1. 1. 1. 1.]]
[[1.76405235 0.40015721 0.97873798 2.2408932  1.86755799]
 [0.78677447 1.35024563 0.82738078 2.13767435 2.27815649]
 [0.93081804 2.80451913 1.5884185  2.25934936 2.72201972]
 [1.26449236 4.29859821 1.38326024 2.57241707 1.86792399]]

After left-multiplying by L, the first row (the embedding of the first token) is left as is, the new second row is the sum of the first two rows, the new third row is the sum of the first three, and so on. This is a one-way interaction mechanism: tokens that come earlier affect the tokens that come later, but not the other way around.
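Left-multiplying by a lower triangular matrix of ones is exactly a running (cumulative) sum down the sequence. A short check, re-seeding with 0 so that `x` is the same data matrix as above:

```python
import numpy as np

np.random.seed(0)
T, E = 4, 5
x = np.random.randn(T, E)   # the same data matrix as above

L = np.tril(np.ones((T, T)))
# Row t of L @ x sums rows 0..t of x: a running (cumulative) sum over tokens.
print(np.allclose(L @ x, np.cumsum(x, axis=0)))  # True
# The last row sums every token, matching the all-ones interaction matrix.
print(np.allclose((L @ x)[-1], x.sum(axis=0)))   # True
```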

Of course the interaction matrix can be any matrix, not just the ones we’ve seen above.

L = np.random.randn(T, T)
print(L, L @ x, sep="\n")
[[-2.55298982  0.6536186   0.8644362  -0.74216502]
 [ 2.26975462 -1.45436567  0.04575852 -0.18718385]
 [ 1.53277921  1.46935877  0.15494743  0.37816252]
 [-0.88778575 -1.98079647 -0.34791215  0.15634897]]
[[-5.26549961 -0.25232838 -1.78750815 -5.91561089 -3.48191029]
 [ 5.36941815 -0.68663938  2.51485007  5.18336211  3.82192147]
 [ 1.41643325  2.79971404  1.31812887  3.42037269  3.21165905]
 [ 0.37174317 -2.50954735 -0.86595236 -1.77836191 -2.75926583]]

We can no longer identify any structure in the relationships between the tokens since, as the name implies, we are now mixing the tokens up randomly.

One more thing to note: it is generally not a good idea to simply sum vectors up. When the number of vectors being summed gets large, as is often the case in modern deep learning (to accurately capture the meaning of a word we need context, and the longer the better), the sum can grow very large, which hurts the stability and efficiency of model training. So we’ll use weights instead, and make sure each row of L sums to one. This is easy if all the values are positive: for 1, 2, 2, 3 the sum is 8, and the weights are simply 1/8, 2/8, 2/8 and 3/8. But it is not so easy to get meaningful weights when the entries of L can be negative. What if the values are 1, -2, 2, 3? They may even sum to zero, like -1, 2, 2, -3, leaving us with a division-by-zero error. The solution is simple: first exponentiate all the entries so that they are all positive, then compute the weights from these positive numbers. This exponentiate-then-normalize operation is known as the softmax function.

L = np.random.randn(T, T)
L = np.exp(L)
L = L / L.sum(axis=1, keepdims=True)
print(L, L @ x, sep="\n")
[[0.41896738 0.40743534 0.08311088 0.09048641]
 [0.04488369 0.03095732 0.02325121 0.90090777]
 [0.16201575 0.17403762 0.07705739 0.58688924]
 [0.0689366  0.27987549 0.14140365 0.50978426]]
[[ 0.38306743  0.8108122   0.39307749  0.93524704  0.90934403]
 [ 0.35288226  1.4272138  -0.12788986  0.3822584  -0.6626072 ]
 [ 0.32265065  1.21910435  0.07046752  0.53820807 -0.09302326]
 [ 0.03856185  1.26078951  0.02813675  0.30239341 -0.12898112]]
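The exponentiate-then-normalize step can be packaged as a row-wise softmax and tried on the example values from the text. The helper name `softmax_rows` is hypothetical, and subtracting each row’s maximum before exponentiating is a standard numerical-stability trick that the code above does not use; it leaves the result unchanged because the common factor cancels in the ratio.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax: exponentiate, then normalise each row to sum to one."""
    # Shifting a row by its maximum multiplies every exp() in that row by the same
    # constant, which cancels in the ratio, but it keeps np.exp from overflowing.
    z = scores - scores.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

scores = np.array([[ 1.0,  2.0, 2.0,  3.0],
                   [ 1.0, -2.0, 2.0,  3.0],
                   [-1.0,  2.0, 2.0, -3.0]])   # the last row sums to zero
w = softmax_rows(scores)
print(w.round(3))
print(w.sum(axis=1))  # [1. 1. 1.]

# Without the shift, np.exp(1000.0) overflows to inf and the weights become nan.
print(softmax_rows(np.array([[1000.0, 1001.0]])))
```

Note that exponentiating changes the weights themselves: for 1, 2, 2, 3 the softmax gives roughly 0.07, 0.20, 0.20, 0.53 rather than 1/8, 2/8, 2/8, 3/8, so larger scores are favoured more strongly.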

Similarly, we can create a right interaction matrix R that mixes the columns, and thus updates the embedding of each token. We could apply the same softmax treatment to each column of R so that each column sums to one, but for reasons still beyond my knowledge this is not done in standard transformers, so here we’ll stick with common practice. When applying the right interaction matrix, we need to decide on the new embedding dimension H1. (It is named H1 because we are going to need another dimension, H2, as we’ll see shortly.)

H1 = 6
R = np.random.randn(E, H1)
print(R, x @ R, sep="\n")
[[-0.51080514 -1.18063218 -0.02818223  0.42833187  0.06651722  0.3024719 ]
 [-0.63432209 -0.36274117 -0.67246045 -0.35955316 -0.81314628 -1.7262826 ]
 [ 0.17742614 -0.40178094 -1.63019835  0.46278226 -0.90729836  0.0519454 ]
 [ 0.72909056  0.12898291  1.13940068 -1.23482582  0.40234164 -0.68481009]
 [-0.87079715 -0.57884966 -0.31155253  0.05616534 -1.16514984  0.90082649]]
[[-0.97371195 -3.41308712  0.05709096 -1.59755613 -2.3704341   0.04139219]
 [-0.56312213  0.61899371 -0.61014338 -0.67973328 -1.22017855 -1.50301919]
 [-1.15883076 -1.24459388 -2.22229344 -0.23431315 -2.33165625 -2.11086604]
 [-0.18257152 -0.31870854 -0.05685886 -0.92377577  0.11453969 -3.47271662]]

We end up with a new embedding matrix of dimension (T, H1).

Parameterizing query, key, and value

We have seen how the left and right interaction matrices update the token embeddings; now it’s time to decide how to populate them. Since we are talking about neural networks, we could simply treat L and R as parameters of the network and learn them from data. But then, once learned, the same matrices would be applied to every sentence, forcing all sentences to interact in the same way, which is clearly not what we want. Instead, the token interactions should be bespoke for each sentence: each sentence should have its own pattern of interaction between its tokens. The natural way to achieve this is to make the left interaction matrix not only contain learnable parameters but also be a function of the input tokens, while the right interaction matrix R remains a plain learned parameter.

We finally settle on the following interaction mechanism: \[ \begin{align*} \mathbf{Y} &= \mathbf{L} \mathbf{X} \mathbf{R} \\ \mathbf{L} &= \text{softmax} \left( \mathbf{Q} \mathbf{K}^\top \right) \\ \mathbf{Q} &= \mathbf{X} \mathbf{W}^Q \\ \mathbf{K} &= \mathbf{X} \mathbf{W}^K \\ \mathbf{V} &= \mathbf{X} \mathbf{W}^V = \mathbf{X} \mathbf{R} \end{align*} \]

The left interaction matrix L is built from the product of two matrices, the query matrix Q and the (transposed) key matrix K, each of which is in turn the product of the data matrix and a parameter matrix. (Ignore the softmax for now; as before, it is applied row by row to turn the raw scores into weights.) Since Q and K are calculated in exactly the same way, the difference in their names is largely a convention. Note that since Q and K (and V) are all calculated by right-multiplying by a parameter matrix, they are all updated embeddings of each token, with no interaction between tokens, as discussed earlier. However, by multiplying them together

\[ \mathbf{Q} \mathbf{K}^\top = \left( \mathbf{X} \mathbf{W}^Q \right) \left( \mathbf{X} \mathbf{W}^K \right)^\top = \mathbf{X} \mathbf{W}^Q {\mathbf{W}^K}^\top \mathbf{X}^\top, \]

we see that the parameter matrices sit between the two copies of the data matrix: to the right of the first X and to the left of the second (as \(\mathbf{X}^\top\)). The resulting (T, T) matrix compares every pair of tokens: entry (i, j) is the dot product of the query of token i with the key of token j. For these dot products to be defined, the query and key parameter matrices must have the same shape.
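A numerical check of the identity above and of the pairwise reading of its entries, as a sketch with arbitrary random matrices and the dimensions used in this post (E=5, H2=7):

```python
import numpy as np

rng = np.random.default_rng(3)
T, E, H2 = 4, 5, 7
X = rng.standard_normal((T, E))
W_Q = rng.standard_normal((E, H2))
W_K = rng.standard_normal((E, H2))   # same shape as W_Q, so query and key rows can be dotted

Q, K = X @ W_Q, X @ W_K
S = Q @ K.T   # (T, T) raw scores, before the softmax

# The identity Q K^T = X W^Q (W^K)^T X^T
print(np.allclose(S, X @ W_Q @ W_K.T @ X.T))  # True
# Entry (i, j) is the dot product of token i's query with token j's key.
i, j = 0, 2
print(np.isclose(S[i, j], Q[i] @ K[j]))  # True
```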

The value matrix \(\mathbf{V}\) is simply the right interaction \(\mathbf{X} \mathbf{R}\) rewritten, with \(\mathbf{W}^V\) playing the role of \(\mathbf{R}\).

We can now implement the whole process.

T, E, H1, H2 = 4, 5, 6, 7

X = np.random.randn(T, E)

W_Q = np.random.randn(E, H2)
W_K = np.random.randn(E, H2)
W_V = np.random.randn(E, H1)

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

L = np.exp(Q @ K.T)
L = L / L.sum(axis=1, keepdims=True)

print("Q shape:", Q.shape)
print("K shape:", K.shape)
print("V shape:", V.shape)
print("L shape:", L.shape)
print("L @ V shape:", (L @ V).shape)
print(L, V, L @ V, sep='\n')
Q shape: (4, 7)
K shape: (4, 7)
V shape: (4, 6)
L shape: (4, 4)
L @ V shape: (4, 6)
[[9.99996706e-01 3.28680000e-06 7.50610498e-09 2.85006607e-14]
 [9.78226751e-01 1.28560219e-03 2.02652246e-02 2.22422213e-04]
 [3.07277486e-05 1.29535875e-01 8.07295984e-01 6.31374132e-02]
 [2.82603173e-01 6.82619752e-03 7.10526096e-01 4.45339485e-05]]
[[ 0.48796236 -1.23724751  0.89411441  1.86818403  0.0708713   4.78357836]
 [ 2.45759555 -0.6918528   2.06088201  3.50938285 -0.60840164  3.90835192]
 [-0.94901715 -0.49210324 -0.95973845 -1.99359209 -0.69361156 -1.88892325]
 [-1.75111978 -2.41142737 -5.32539656 -2.93730234 -0.33078805 -0.89377003]]
[[ 0.48796882 -1.23724571  0.89411823  1.86818939  0.07086906  4.78357544]
 [ 0.46087579 -1.221707    0.85666231  1.79096535  0.05441627  4.64597066]
 [-0.55833713 -0.63918203 -0.84403913 -1.34022418 -0.65964259 -1.07493171]
 [-0.5197037  -0.70413238 -0.41540882 -0.86471954 -0.47696846  0.03636454]]
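Returning to the motivation for computing L from X rather than learning it directly: with one fixed set of parameters, different inputs now produce different interaction matrices. A sketch of that check, where `interaction_matrix` is a hypothetical helper implementing the row-wise softmax of \(\mathbf{Q}\mathbf{K}^\top\) from the equations above (with the usual max-subtraction added for numerical stability):

```python
import numpy as np

def interaction_matrix(X, W_Q, W_K):
    """Row-wise softmax of (X W_Q)(X W_K)^T, as in the equations above."""
    S = (X @ W_Q) @ (X @ W_K).T
    S = S - S.max(axis=1, keepdims=True)   # numerical stability; result unchanged
    L = np.exp(S)
    return L / L.sum(axis=1, keepdims=True)

rng = np.random.default_rng(4)
T, E = 4, 5
W_Q = rng.standard_normal((E, E))   # one fixed set of (learned) parameters ...
W_K = rng.standard_normal((E, E))

X_a = rng.standard_normal((T, E))   # ... applied to two different "sentences"
X_b = rng.standard_normal((T, E))

L_a = interaction_matrix(X_a, W_Q, W_K)
L_b = interaction_matrix(X_b, W_Q, W_K)
print(np.allclose(L_a, L_b))  # False: each input gets its own token interactions
```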

Note that we have used different embedding dimensions for the data matrix (E=5), the output of the right interaction matrix (H1=6), and the query/key embedding of the left interaction matrix (H2=7). In practice, transformer models stack multiple attention layers and add a residual connection around each one, so the embedding dimensions of the data matrix and of the right interaction output are usually the same (E=H1). In a single-head setup like this one, the query/key dimension is also commonly set equal to the data embedding dimension (E=H2).

T, E = 4, 5

X = np.random.randn(T, E)

W_Q = np.random.randn(E, E)
W_K = np.random.randn(E, E)
W_V = np.random.randn(E, E)

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

L = np.exp(Q @ K.T)
L = L / L.sum(axis=1, keepdims=True)

print("Q shape:", Q.shape)
print("K shape:", K.shape)
print("V shape:", V.shape)
print("L shape:", L.shape)
print("L @ V shape:", (L @ V).shape)
print(L, V, L @ V, sep='\n')
Q shape: (4, 5)
K shape: (4, 5)
V shape: (4, 5)
L shape: (4, 4)
L @ V shape: (4, 5)
[[4.17764387e-09 2.57135620e-06 9.99997419e-01 5.39037660e-09]
 [8.83647632e-01 3.61057169e-02 1.10946675e-05 8.02355563e-02]
 [3.27298619e-01 1.62202688e-02 4.33976229e-01 2.22504884e-01]
 [8.16727896e-01 2.61363556e-04 1.81996886e-01 1.01385435e-03]]
[[-0.38265795 -1.0078002  -1.28306344  0.87532503  1.31775286]
 [ 0.42510411 -0.71553386  2.17308255 -0.01319503 -0.38305383]
 [ 1.70990215 -0.46195207 -2.27786996 -0.59268283  1.09518226]
 [-1.44638454 -4.79764402  3.20740801 -0.21663068  2.96228223]]
[[ 1.70989882 -0.46195275 -2.27785848 -0.59268133  1.09517847]
 [-0.4388186  -1.30132189 -0.79799237  0.75561442  1.38829128]
 [ 0.30188115 -1.60943321 -0.65957438 -0.019133    1.55949079]
 [-0.00268587 -0.9122235  -1.45865913  0.60681286  1.27846849]]

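And this is essentially what (single-head) attention in language models looks like. Standard implementations additionally divide \(\mathbf{Q} \mathbf{K}^\top\) by the square root of the query/key dimension before the softmax.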
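The original transformer paper (Vaswani et al., 2017) scales the scores by \(1/\sqrt{d_k}\), where \(d_k\) is the query/key dimension, arguing that large dot products push the softmax into regions with very small gradients. Below is a minimal sketch of that scaled variant; the function name `scaled_attention` is hypothetical (not a library API), and the max-subtraction is only for numerical stability.

```python
import numpy as np

def scaled_attention(X, W_Q, W_K, W_V):
    """Single-head scaled dot-product attention (an illustrative sketch)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = K.shape[1]                              # query/key dimension
    S = Q @ K.T / np.sqrt(d_k)                    # scaled scores, shape (T, T)
    S = S - S.max(axis=1, keepdims=True)          # stability shift; result unchanged
    L = np.exp(S)
    L = L / L.sum(axis=1, keepdims=True)          # row-wise softmax
    return L @ V, L

rng = np.random.default_rng(5)
T, E = 4, 5
X = rng.standard_normal((T, E))
W_Q, W_K, W_V = (rng.standard_normal((E, E)) for _ in range(3))

Y, L = scaled_attention(X, W_Q, W_K, W_V)
print(Y.shape, L.shape)   # (4, 5) (4, 4)
```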

Masked attention

Most modern large language models (GPT-2 and its many successors, the so-called “decoder-only transformers”) are essentially next-word-prediction machines. That is, given the words so far, we want the model to predict what the next word should be, and it must not peek at the words that come after. Information flows only one way, like rivers to the sea.

We have already seen how such one-way, forward-only interaction can be achieved: by making L a lower triangular matrix. We’ll create a mask matrix of the same shape as L; after applying the mask, the entries on and below the diagonal of L are left intact, while the entries above the diagonal are set to negative infinity. Since \(\exp(-\infty) = 0\), after exponentiation the entries above the diagonal of the left interaction matrix become zero, which prevents later tokens from affecting earlier ones. Because the mask is applied before normalization, each row is still normalized over the allowed entries only.

T, E = 4, 5

X = np.random.randn(T, E)

W_Q = np.random.randn(E, E)
W_K = np.random.randn(E, E)
W_V = np.random.randn(E, E)

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

L = Q @ K.T
mask = np.tril(np.ones(L.shape), k=0)  # Create a lower triangular mask
L = np.where(mask == 1, L, -np.inf)    # Apply the mask to L, setting upper right entries to -inf
L = np.exp(L)
L = L / L.sum(axis=1, keepdims=True)

print(L, V, L @ V, sep='\n')
[[1.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [8.73730868e-01 1.26269132e-01 0.00000000e+00 0.00000000e+00]
 [8.75161032e-02 4.22552421e-01 4.89931476e-01 0.00000000e+00]
 [2.44179030e-04 2.07785297e-03 1.42193742e-02 9.83458594e-01]]
[[-0.64112576  3.93575956 -0.12979446  2.44301203  0.22577084]
 [ 1.43055231 -1.6645132   0.44924118 -1.49425456  0.68901062]
 [ 0.0815931   2.13348808  1.5848282   1.0688065   1.61374919]
 [ 0.28371997  0.39761655  0.89286344 -0.20351501  0.41637134]]
[[-0.64112576  3.93575956 -0.12979446  2.44301203  0.22577084]
 [-0.37953677  3.22861798 -0.05668014  1.94585679  0.28426373]
 [ 0.58834954  0.68636122  0.95492606  0.10604396  1.10152821]
 [ 0.28300297  0.4188787   0.90153125 -0.18745914  0.43391727]]

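We can see that L is indeed lower triangular, and each row still sums to one. Consequently, the first row of L @ V is just the first row of V: the first token can only attend to itself.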
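One way to convince yourself that the mask really enforces one-way information flow is to change the last token and check that the outputs for all earlier tokens stay the same. A sketch, where `causal_attention` is a hypothetical wrapper around the masked computation above (plus the max-subtraction for numerical stability) and the seed is arbitrary:

```python
import numpy as np

def causal_attention(X, W_Q, W_K, W_V):
    """Masked attention as in the code above, with the row max subtracted for stability."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    S = Q @ K.T
    mask = np.tril(np.ones(S.shape), k=0)        # keep entries on and below the diagonal
    S = np.where(mask == 1, S, -np.inf)          # block attention to later tokens
    S = S - S.max(axis=1, keepdims=True)         # the diagonal is never masked, so max is finite
    L = np.exp(S)
    L = L / L.sum(axis=1, keepdims=True)
    return L @ V

rng = np.random.default_rng(6)
T, E = 4, 5
X = rng.standard_normal((T, E))
W_Q, W_K, W_V = (rng.standard_normal((E, E)) for _ in range(3))

X_changed = X.copy()
X_changed[-1] = rng.standard_normal(E)   # replace the last token only

Y = causal_attention(X, W_Q, W_K, W_V)
Y_changed = causal_attention(X_changed, W_Q, W_K, W_V)
print(np.allclose(Y[:-1], Y_changed[:-1]))  # True: earlier tokens never see the last one
```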

Conclusion

And voilà, we have now gone through the single most important component in modern large language models.