The Attention Mechanism

Transformers in Brief

Author

Mark Andrews

Abstract

We introduce the attention mechanism that sits at the heart of the transformer architecture. Starting from intuition, we derive scaled dot-product attention, implement it directly in PyTorch, wrap it in a self-attention module, and assemble a transformer block. The goal is understanding what attention does and how it is computed, not building a complete model.

What attention does

A standard feedforward layer processes each input element on its own, and a convolutional layer sees only a small local neighbourhood. Neither gives an element direct knowledge of distant elements in the input. Attention removes this constraint. It allows each element to gather information from every other element, weighted by how relevant each one is.

In image recognition, the input to a transformer is a set of non-overlapping image patches. Attention lets the representation of each patch be informed by every other patch in the image, regardless of distance. Consider face recognition: the eyes, nose, mouth, and jawline are spread across the image, and recognising a face requires relating all of those regions simultaneously. The distance between the eyes matters, as do the proportions between features. None of these are local relationships, and a convolutional layer at any single scale cannot capture them directly. Convolutional networks only allow nearby pixels to interact directly, so long-range dependencies require stacking many layers. Attention provides a direct path between any two patches in a single step.

Queries, keys, and values

The attention mechanism assigns each input element three roles: a query, a key, and a value. All three are produced from the input by learned linear projections.

Think of the query as what an element is looking for, and the key as what it advertises. The dot product of query \(i\) with key \(j\) gives a relevance score: how much element \(j\) matters to element \(i\). After the scores are converted to weights by softmax, the output for element \(i\) is a weighted sum of all the value vectors, where the weights express relevance.

For \(n\) input vectors arranged as rows of a matrix \(X \in \mathbb{R}^{n \times d}\), the three projections are:

\[Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V\]

where \(W_Q\) and \(W_K\) are learned weight matrices of shape \(d \times d_k\), and \(W_V\) is of shape \(d \times d_v\). In practice \(d_v = d_k\) in almost all standard implementations, so the output of attention has the same width as the query and key vectors. Scaled dot-product attention is then:

\[\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V\]
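Written out for a single element, the matrix formula reads as follows. This is only a restatement of the equation above: \(q_i\), \(k_j\), and \(v_j\) denote rows of \(Q\), \(K\), and \(V\), and \(\alpha_{ij}\) is notation introduced here for the attention weights.

\[\alpha_{ij} = \frac{\exp\!\left(q_i \cdot k_j / \sqrt{d_k}\right)}{\sum_{l=1}^{n} \exp\!\left(q_i \cdot k_l / \sqrt{d_k}\right)}, \qquad \text{output}_i = \sum_{j=1}^{n} \alpha_{ij}\, v_j\]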

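The division by \(\sqrt{d_k}\) keeps the dot products from growing large as the dimension increases: if the components of the queries and keys are independent with unit variance, each dot product has variance \(d_k\). Large scores would push the softmax into saturated regions where gradients are near zero.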
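A quick numerical check of the scaling argument can make it concrete. The sketch below assumes queries and keys with independent standard-normal components (an idealisation, not something a trained model guarantees); the helper name dot_product_variance is made up for this illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

def dot_product_variance(d_k, n_samples=10_000):
    """Empirical variance of q.k before and after dividing by sqrt(d_k)."""
    q = torch.randn(n_samples, d_k)   # assumed: unit-variance components
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)        # one dot product per sample
    return dots.var().item(), (dots / d_k ** 0.5).var().item()

results = {d_k: dot_product_variance(d_k) for d_k in (4, 64, 1024)}
for d_k, (raw, scaled) in results.items():
    print(f"d_k={d_k:5d}  raw variance={raw:8.1f}  scaled variance={scaled:.2f}")
```

Under these assumptions the raw variance should come out close to \(d_k\), while the scaled variance stays close to one whatever the dimension.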

Implementing attention

import torch
import torch.nn as nn

The formula translates directly into a short function.

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores  = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (n, n) score matrix
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V, weights

To see it work, create a small set of five input vectors each of length eight and pass them through with no separate projection (Q, K, V all equal to the input itself).

X = torch.randn(5, 8)

out, weights = scaled_dot_product_attention(X, X, X)
out.shape     # (5, 8): one output vector per input element
torch.Size([5, 8])
weights.round(decimals=2)   # (5, 5): weight that element i gives to element j
tensor([[0.4000, 0.1300, 0.1500, 0.1000, 0.2100],
        [0.1900, 0.2900, 0.2200, 0.2600, 0.0500],
        [0.1100, 0.1100, 0.6600, 0.0900, 0.0300],
        [0.0100, 0.0100, 0.0100, 0.9700, 0.0000],
        [0.0500, 0.0100, 0.0100, 0.0100, 0.9300]])

Each row sums to one. The \((i, j)\) entry is the weight placed on element \(j\)’s value vector when forming element \(i\)’s output.
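The large diagonal entries in the output above have a simple explanation. With Q and K both equal to the input, the score of an element with itself is its squared length divided by \(\sqrt{d_k}\), which is never negative, whereas scores between different random vectors scatter around zero. The following sketch checks both the row sums and this diagonal identity on a fresh random input (the values will differ from those printed above).

```python
import torch

torch.manual_seed(0)

X = torch.randn(5, 8)
d_k = X.shape[-1]

scores = X @ X.T / d_k ** 0.5              # (5, 5), with Q = K = X
weights = torch.softmax(scores, dim=-1)

row_sums = weights.sum(dim=-1)             # one sum per row of weights
print(row_sums)

# The diagonal score is |x_i|^2 / sqrt(d_k), so it can never be negative
diag = scores.diagonal()
squared_norms = (X ** 2).sum(dim=-1) / d_k ** 0.5
print(torch.allclose(diag, squared_norms, atol=1e-5))
```

Once learned projections are used, Q and K differ and the diagonal no longer has this built-in advantage.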

A self-attention module

In a real model the Q, K, V projections are learned parameters. Wrapping them in an nn.Module makes a reusable self-attention layer.

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        return scaled_dot_product_attention(Q, K, V)

The input x has shape (batch, n, embed_dim), where n is the number of elements (patches, tokens, or whatever the inputs are).

attn = SelfAttention(embed_dim=16)

x = torch.randn(2, 6, 16)   # 2 samples, 6 elements each, 16-dimensional
out, weights = attn(x)

out.shape       # (2, 6, 16): output has the same shape as input
weights.shape   # (2, 6, 6): attention weights, one 6x6 matrix per sample
torch.Size([2, 6, 6])

The weight matrix for each sample shows how every element attends to every other element.
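One way to read such a weight matrix is sketched below. To keep the example self-contained it uses three stand-in projections with the same shapes as those in SelfAttention; they are randomly initialised and untrained, so the attention pattern itself is meaningless and only the indexing is of interest.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim = 16
# Stand-ins for the learned projections (assumed untrained, for illustration)
projections = [nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(3)]
q_proj, k_proj, v_proj = projections

x = torch.randn(2, 6, embed_dim)
scores = q_proj(x) @ k_proj(x).transpose(-2, -1) / embed_dim ** 0.5
weights = torch.softmax(scores, dim=-1)    # (2, 6, 6)

row = weights[0, 2]                        # how element 2 of sample 0 spreads its attention
most_attended = weights.argmax(dim=-1)     # (2, 6): index of the top-weighted element

# Three bias-free 16x16 projections: 3 * 16 * 16 = 768 learned parameters
n_params = sum(p.numel() for proj in projections for p in proj.parameters())
print(row, most_attended, n_params)
```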

Multi-head attention

Single-head attention computes one set of relevance scores. Multi-head attention runs several attention operations in parallel, each in a lower-dimensional subspace of the embedding. Different heads can simultaneously capture different kinds of relationship among the elements. Their outputs are concatenated and projected back to the original dimension.

PyTorch’s nn.MultiheadAttention implements this. The three arguments to forward are the query, key, and value inputs, which are all the same tensor for self-attention. By default the returned attention weights are averaged over the heads.

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

x = torch.randn(2, 6, 16)
out, weights = mha(x, x, x)

out.shape       # (2, 6, 16)
weights.shape   # (2, 6, 6)
torch.Size([2, 6, 6])

With num_heads=4 and embed_dim=16, each head works in a subspace of dimension \(16/4 = 4\).
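The head split can be sketched by hand. The code below is an illustration of the idea rather than a reproduction of PyTorch’s internals: it starts from hypothetical, already-projected Q, K, and V tensors and omits the final output projection that follows the concatenation.

```python
import torch

torch.manual_seed(0)

batch, n, embed_dim, num_heads = 2, 6, 16, 4
head_dim = embed_dim // num_heads          # 16 / 4 = 4

# Hypothetical already-projected queries, keys, and values
Q = torch.randn(batch, n, embed_dim)
K = torch.randn(batch, n, embed_dim)
V = torch.randn(batch, n, embed_dim)

def split_heads(t):
    # (batch, n, embed_dim) -> (batch, num_heads, n, head_dim)
    return t.view(batch, n, num_heads, head_dim).transpose(1, 2)

Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)

scores = Qh @ Kh.transpose(-2, -1) / head_dim ** 0.5   # (2, 4, 6, 6)
weights = torch.softmax(scores, dim=-1)                # one 6x6 matrix per head
heads_out = weights @ Vh                               # (2, 4, 6, 4)

# Concatenate the heads back into the embedding dimension
concat = heads_out.transpose(1, 2).reshape(batch, n, embed_dim)   # (2, 6, 16)
mean_weights = weights.mean(dim=1)                                # (2, 6, 6)
print(concat.shape, mean_weights.shape)
```

Averaging the per-head weights over the head dimension gives a single matrix per sample, which is the shape nn.MultiheadAttention returns by default.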

A transformer block

A transformer block combines multi-head attention with three additional components.

A position-wise feedforward network is a small two-layer MLP applied identically and independently at every element position. It allows each element’s representation to be further transformed after the attention step.

Residual connections add the input of each sublayer directly to its output: if a sublayer computes \(f(x)\), the result passed forward is \(x + f(x)\). This gives gradients a direct path back through many stacked blocks and is a large part of what makes deep transformers trainable.

Layer normalisation normalises each element’s representation across the embedding dimension, keeping activations in a stable range throughout training.
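Two of these properties are easy to check numerically. The sketch below uses freshly initialised layers with sizes chosen only for illustration: it confirms that the feedforward network treats each position independently, and that layer normalisation leaves each element with roughly zero mean and unit variance across the embedding dimension.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(2, 6, 16)                  # (batch, n, embed_dim)

# Position-wise: running one position alone gives the same result
ff = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
full = ff(x)                               # all six positions at once
single = ff(x[:, 3])                       # position 3 on its own
print(torch.allclose(full[:, 3], single, atol=1e-5))

# Residual connection: the sublayer output is added to its input
residual = x + full                        # same shape as x

# Layer norm: statistics are taken over the last (embedding) dimension
norm = nn.LayerNorm(16)
y = norm(residual)
means = y.mean(dim=-1)                     # close to 0 for every element
variances = y.var(dim=-1, unbiased=False)  # close to 1 for every element
print(means.abs().max(), variances.min(), variances.max())
```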

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.attn  = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ff    = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)      # attention with residual + norm
        x = self.norm2(x + self.ff(x))    # feedforward with residual + norm
        return x

block = TransformerBlock(embed_dim=16, num_heads=4, ff_dim=64)
x = torch.randn(2, 6, 16)
block(x).shape   # (2, 6, 16): shape is preserved through the block
torch.Size([2, 6, 16])

The output has exactly the same shape as the input, so blocks can be stacked arbitrarily.

The transformer architecture

A complete transformer stacks \(N\) such blocks. The input elements are first converted to fixed-length vectors (the embedding step), then passed through the stack, and finally a linear layer maps to the output.

input elements  →  embeddings
                →  transformer block 1
                →  transformer block 2
                →  ...
                →  transformer block N
                →  linear layer  →  output

For language models, input elements are word-piece tokens and the output is a probability distribution over the vocabulary. For image models (Vision Transformer, ViT), the image is divided into fixed-size patches, each patch is flattened and projected to a vector, and the resulting sequence of patch vectors is fed into exactly this stack. The attention mechanism then lets every patch interact with every other patch without the locality constraints of a convolutional network.
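The patch step for image models can be sketched in a few lines. The image size, patch size, and embedding width below are hypothetical values picked for illustration, and the linear projection is untrained; real ViT implementations differ in detail.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
batch, channels, height, width = 2, 3, 32, 32
patch, embed_dim = 8, 16

images = torch.randn(batch, channels, height, width)

# Cut each image into non-overlapping patch x patch tiles
tiles = images.unfold(2, patch, patch).unfold(3, patch, patch)
# tiles: (batch, channels, 4, 4, patch, patch)
tiles = tiles.permute(0, 2, 3, 1, 4, 5)    # (batch, 4, 4, channels, patch, patch)

# Flatten each tile: 4 * 4 = 16 patches of length 3 * 8 * 8 = 192
patches = tiles.reshape(batch, -1, channels * patch * patch)

# Project every flattened patch to the embedding width
to_embedding = nn.Linear(channels * patch * patch, embed_dim)
tokens = to_embedding(patches)             # (2, 16, 16)
print(patches.shape, tokens.shape)
```

The resulting tokens tensor has the (batch, n, embed_dim) layout that the transformer blocks expect.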

Stacking multiple blocks in a small example:

blocks = nn.Sequential(*[TransformerBlock(embed_dim=32, num_heads=4, ff_dim=128)
                          for _ in range(3)])

x = torch.randn(4, 10, 32)   # batch=4, 10 elements, embed_dim=32
blocks(x).shape               # (4, 10, 32)
torch.Size([4, 10, 32])

Each block refines the representations by attending across the elements and transforming them through the feedforward network. After \(N\) blocks, each element’s representation carries information about its global context.
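Putting the pieces of the diagram together for the language-model case gives something like the sketch below. The class name TinyTransformer and all sizes are hypothetical. To stay self-contained it uses PyTorch’s built-in nn.TransformerEncoderLayer in place of the TransformerBlock defined earlier, and it deliberately leaves out details a working language model would need, such as positional information and masking.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyTransformer(nn.Module):
    """Hypothetical minimal model: embeddings -> N blocks -> linear layer."""

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_blocks):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=ff_dim, batch_first=True,
        )
        self.blocks = nn.TransformerEncoder(layer, num_layers=num_blocks)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, token_ids):
        x = self.embed(token_ids)          # (batch, n, embed_dim)
        x = self.blocks(x)                 # shape preserved through the stack
        return self.head(x)                # (batch, n, vocab_size) scores

model = TinyTransformer(vocab_size=100, embed_dim=32, num_heads=4,
                        ff_dim=128, num_blocks=3)
model.eval()                               # switch off dropout for the demo

token_ids = torch.randint(0, 100, (4, 10)) # batch=4, 10 tokens each
logits = model(token_ids)
probs = torch.softmax(logits, dim=-1)      # distribution over the vocabulary
print(logits.shape)
```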