Right, so we’ve established that self-attention is the magic trick that lets every word in a sequence have a little meeting with every other word to figure out how much they should care about each other. But if that’s all we had, it would be a bit of a blunt instrument. It’s like only having one tool in your workshop—a hammer. Sure, you can attend to everything, but you’re probably going to treat every relationship like a nail.

This is where Multi-Head Attention comes in. Think of it as giving the model multiple sets of eyes, each looking for different kinds of relationships. One “head” might specialize in finding pronoun-antecedent relationships (what “it” refers to), another might be great at spotting syntactic connections, and another might key in on the emotional tone of certain words. By running these attention operations in parallel and then smooshing their outputs together, we allow the model to attend to information from different representation subspaces. It’s a way to parallelize pattern recognition, and it’s brutally effective.

The Core Idea: Multiple Queries, Keys, and Values

Instead of a single set of Query, Key, and Value projection matrices (W_Q, W_K, W_V), we have h heads, each with its own dedicated set of these matrices (W_Q_i, W_K_i, W_V_i for head i). This is the crucial part: each head learns to project the input embeddings into a different subspace.
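In practice, the h per-head matrices are usually stored side by side as one d_model-by-d_model matrix. The implementation later in this section does exactly this. The sketch below is PyTorch with arbitrary illustrative sizes, using plain matrices with no bias. It checks that slicing the output of one fused projection gives the same per-head queries as multiplying by each head's own W_Q_i.

```python
import torch

torch.manual_seed(0)

d_model, num_heads, seq_len = 16, 4, 5  # illustrative sizes, not from any real model
d_k = d_model // num_heads

x = torch.randn(seq_len, d_model)

# One separate projection matrix per head: W_Q_i has shape (d_model, d_k).
w_q_heads = [torch.randn(d_model, d_k) for _ in range(num_heads)]

# Per-head queries computed independently: Q_i = X @ W_Q_i.
q_separate = [x @ w for w in w_q_heads]

# The same matrices placed side by side form one (d_model, d_model) matrix.
w_q_fused = torch.cat(w_q_heads, dim=1)
q_fused = x @ w_q_fused  # (seq_len, d_model)

# Head i's queries are columns i*d_k : (i+1)*d_k of the fused result.
q_sliced = [q_fused[:, i * d_k:(i + 1) * d_k] for i in range(num_heads)]

for i in range(num_heads):
    assert torch.allclose(q_separate[i], q_sliced[i], atol=1e-4)
print("fused projection matches per-head projections")
```

nn.Linear stores its weight as (out_features, in_features) and adds a bias, so its internal layout differs slightly. The column-slicing idea is the same.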

Each head’s subspace is smaller than the model dimension, which keeps the total computational cost roughly the same as a single head working at full dimensionality. The typical choice is d_k = d_v = d_model / h. So if your model dimension is 512 and you have 8 heads, each query, key, and value vector will be 64-dimensional.
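The cost-parity claim is easy to check with a little arithmetic. This plain-Python sketch compares h heads of size d_model / h against one head of size d_model. It counts only the Q/K/V projection weights and the multiplies in the Q @ K.T score matrices, and ignores biases, softmax, the weighted sum over V, and W_O. The helper names are illustrative.

```python
def head_dims(d_model: int, num_heads: int) -> int:
    """Per-head size d_k = d_v = d_model / h (requires exact divisibility)."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    return d_model // num_heads


def attention_cost(d_model: int, num_heads: int, seq_len: int) -> dict:
    """Rough counts for the Q/K/V projections and the Q @ K.T scores.

    Ignores biases, softmax, the weighted sum over V, and W_O.
    """
    d_k = head_dims(d_model, num_heads)
    return {
        "d_k": d_k,
        # Each head has its own (d_model x d_k) W_Q, W_K and W_V.
        "qkv_params": num_heads * 3 * d_model * d_k,
        # Each head builds a (seq_len x seq_len) score matrix of d_k-long dot products.
        "score_mults": num_heads * seq_len * seq_len * d_k,
    }


multi = attention_cost(d_model=512, num_heads=8, seq_len=128)
single = attention_cost(d_model=512, num_heads=1, seq_len=128)
print(multi)   # d_k is 64 for 8 heads
print(single)  # d_k is 512 for one head
assert multi["qkv_params"] == single["qkv_params"] == 3 * 512 * 512
assert multi["score_mults"] == single["score_mults"] == 128 * 128 * 512
```

The h factor and the 1/h shrink in d_k cancel exactly, which is why splitting into heads is close to free.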

Here’s the process, step-by-step:

  1. For each head i, we take the input X and compute head-specific Queries, Keys, and Values: Q_i = X @ W_Q_i, K_i = X @ W_K_i, V_i = X @ W_V_i.
  2. For each head, we compute the scaled dot-product attention as we normally would: Output_i = softmax((Q_i @ K_i.T) / sqrt(d_k)) @ V_i.
  3. This gives us h different output matrices, each of shape [sequence_length, d_v].
  4. We concatenate all these head outputs together. If d_v is 64 and h is 8, the concatenated matrix is [sequence_length, 512] again.
  5. Finally, we project this concatenated output through a final linear layer W_O to allow the heads to mix their learned features together. The final output is Z = Concat(Output_1, ..., Output_h) @ W_O.
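Here are the five steps written out literally, one head at a time. This is a minimal PyTorch sketch for a single unbatched sequence. The random weight matrices and the sizes are illustrative stand-ins for what a real layer would learn.

```python
import math
import torch

torch.manual_seed(0)

seq_len, d_model, h = 6, 512, 8
d_k = d_v = d_model // h  # 64

X = torch.randn(seq_len, d_model)

# Random stand-ins for the learned per-head weights W_Q_i, W_K_i, W_V_i and for W_O.
W_Q = [torch.randn(d_model, d_k) / math.sqrt(d_model) for _ in range(h)]
W_K = [torch.randn(d_model, d_k) / math.sqrt(d_model) for _ in range(h)]
W_V = [torch.randn(d_model, d_v) / math.sqrt(d_model) for _ in range(h)]
W_O = torch.randn(h * d_v, d_model) / math.sqrt(h * d_v)

outputs = []
for i in range(h):
    # Step 1: head-specific projections.
    Q_i, K_i, V_i = X @ W_Q[i], X @ W_K[i], X @ W_V[i]
    # Step 2: scaled dot-product attention for this head.
    weights_i = torch.softmax((Q_i @ K_i.T) / math.sqrt(d_k), dim=-1)
    out_i = weights_i @ V_i
    # Step 3: each head's output is (seq_len, d_v).
    assert out_i.shape == (seq_len, d_v)
    outputs.append(out_i)

# Step 4: concatenate along the feature axis -> (seq_len, h * d_v) = (seq_len, 512).
concat = torch.cat(outputs, dim=-1)

# Step 5: mix the heads with the output projection.
Z = concat @ W_O
print(concat.shape, Z.shape)  # torch.Size([6, 512]) torch.Size([6, 512])
```

The loop is easy to read but issues many small matrix multiplies. The batched class in the next section does the same math with reshapes instead.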

Why This Works: A Committee of Experts

The “why” is the beautiful part. You can think of each head as a semi-independent feature detector. During training, they often (though not always) specialize in different syntactic and semantic relationships. Researchers analyzing models have found heads that specialize in things like:

  • Finding the direct object of a verb.
  • Attending to the word that ends a dependent clause.
  • Focusing on delimiter tokens like periods or commas.
  • Tracking certain types of entities.

In practice, this parallel investigation tends to work better than one big attention mechanism of the same total width. It’s the difference between asking one generalist “what’s important here?” and asking a committee of eight specialists the same question and then combining their answers.
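If you want to look for this kind of specialization yourself, summarize each head's attention weights first. The sketch assumes weights shaped (batch, num_heads, seq_len, seq_len), like the attn_weights returned by the MultiHeadAttention class in the next section. summarize_heads is a hypothetical helper. It reports each head's average entropy: near 0 means sharply focused, and log(seq_len) means uniform. It also reports which key every query attends to most. The random weights here only exercise the code and say nothing about real specialization.

```python
import torch

def summarize_heads(attn_weights: torch.Tensor):
    """Hypothetical helper: per-head summary of attention weights.

    attn_weights: (batch, num_heads, seq_len, seq_len), rows summing to 1.
    Returns (mean_entropy, top_key):
      mean_entropy: (num_heads,) average entropy of each query's distribution;
                    lower means the head concentrates on fewer keys.
      top_key:      (batch, num_heads, seq_len) index of the most-attended key
                    for every query position.
    """
    eps = 1e-12  # avoid log(0) for keys that received zero weight
    entropy = -(attn_weights * (attn_weights + eps).log()).sum(dim=-1)  # (batch, heads, seq_len)
    mean_entropy = entropy.mean(dim=(0, 2))  # average over batch and query positions
    top_key = attn_weights.argmax(dim=-1)
    return mean_entropy, top_key


torch.manual_seed(0)
# Stand-in weights: random scores pushed through softmax (a trained model would supply real ones).
fake_weights = torch.softmax(torch.randn(2, 8, 10, 10), dim=-1)
mean_entropy, top_key = summarize_heads(fake_weights)
for head, ent in enumerate(mean_entropy.tolist()):
    print(f"head {head}: mean entropy {ent:.3f}")
```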

Code It Like You Mean It

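Let’s break this down into code. The MultiHeadAttention class below projects the input for all heads at once, splits the result into heads, and hands them to a scaled_dot_product_attention helper defined right after the class. Note: This is a simplified, batched implementation for clarity.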

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Depth of each head's Q, K, V
        
        # These linear layers project into *all heads at once*.
        # We'll split the result later. It's more efficient this way.
        self.w_q = nn.Linear(d_model, d_model)  # Projects to (num_heads * d_k)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)  # Output projection
        
    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.size()
        
        # Linear projection and then split into heads
        # Shape: (batch_size, seq_len, d_model) -> (batch_size, seq_len, num_heads, d_k)
        q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.d_k)
        k = self.w_k(x).view(batch_size, seq_len, self.num_heads, self.d_k)
        v = self.w_v(x).view(batch_size, seq_len, self.num_heads, self.d_k)
        
        # Transpose to get dimensions: (batch_size, num_heads, seq_len, d_k)
        # This allows us to compute attention for each head in parallel
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # Compute scaled dot-product attention (we'll write this function)
        # attn_output shape: (batch_size, num_heads, seq_len, d_k)
        attn_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
        
        # Concatenate heads: transpose back to (batch_size, seq_len, num_heads, d_k)
        # then reshape to (batch_size, seq_len, d_model)
        concat_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        # Final output projection
        output = self.w_o(concat_output)
        
        return output, attn_weights

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v shapes: (batch_size, ..., seq_len, d_k)
    # '...' could be num_heads, for example
    
    # Compute attention scores
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # (..., seq_len, seq_len)
    
    # Scale scores
    d_k = q.size()[-1]
    scaled_scores = matmul_qk / (d_k ** 0.5)
    
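    # Apply mask (if provided) - crucial for the decoder to prevent cheating.
    # Convention: mask == 1 means "may attend", mask == 0 means "blocked";
    # the mask must broadcast to (..., seq_len, seq_len).
    # Fill with the dtype's most negative finite value: a hard-coded -1e9
    # does not fit in float16, which breaks mixed-precision training.
    if mask is not None:
        fill_value = torch.finfo(scaled_scores.dtype).min
        scaled_scores = scaled_scores.masked_fill(mask == 0, fill_value)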
    
    # Softmax to get attention weights
    attn_weights = F.softmax(scaled_scores, dim=-1)  # (..., seq_len, seq_len)
    
    # Weighted sum of values
    output = torch.matmul(attn_weights, v)  # (..., seq_len, d_k)
    
    return output, attn_weights
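Most shape bugs hide in the view/transpose choreography in forward. One cheap sanity check uses two hypothetical helpers, split_heads and merge_heads, that mirror the reshapes in the class above. It confirms that split-then-merge is a lossless round trip and that head i really receives features i*d_k up to (i+1)*d_k.

```python
import torch

def split_heads(t: torch.Tensor, num_heads: int) -> torch.Tensor:
    """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k), as in forward()."""
    batch_size, seq_len, d_model = t.shape
    d_k = d_model // num_heads
    return t.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)

def merge_heads(t: torch.Tensor) -> torch.Tensor:
    """(batch, num_heads, seq_len, d_k) -> (batch, seq_len, num_heads * d_k)."""
    batch_size, num_heads, seq_len, d_k = t.shape
    # transpose leaves the tensor non-contiguous, so .contiguous() is needed before .view()
    return t.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * d_k)

batch_size, seq_len, d_model, num_heads = 2, 5, 16, 4
x = torch.randn(batch_size, seq_len, d_model)

heads = split_heads(x, num_heads)
print(heads.shape)  # torch.Size([2, 4, 5, 4])

# The round trip must reproduce the input exactly (pure reshapes, no arithmetic).
assert torch.equal(merge_heads(heads), x)

# Head i sees the feature slice [i*d_k, (i+1)*d_k) of every token.
d_k = d_model // num_heads
for i in range(num_heads):
    assert torch.equal(heads[:, i], x[..., i * d_k:(i + 1) * d_k])
```

If you drop the .contiguous() call, .view() raises a RuntimeError on the transposed tensor. .reshape() is a drop-in alternative that copies only when it has to.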

Pitfalls and Best Practices

  • The Divisor Check: Always, always assert that d_model % num_heads == 0. If you don’t, the .view() calls in forward will fail with a cryptic shape error, and you’ll deserve the frustration. It’s the “did you turn it off and on again?” of transformer coding.
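  • Masking is Non-Negotiable: For decoder blocks, you must apply a causal mask so that position t can only attend to positions 0 through t. In the additive convention, that means adding a matrix of -inf above the diagonal (an upper-triangular mask) to the scores. With the mask == 0 convention used by scaled_dot_product_attention above, pass torch.tril(torch.ones(seq_len, seq_len)) as the mask instead. Forgetting this is the single easiest way to build a model that cheats on its next-token prediction task.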
  • Interpreting Heads: Don’t assume every head will be perfectly interpretable. While some specialize nicely, many become entangled and hard for humans to understand. The model cares about performance, not our need for neat explanations.
  • Efficiency: The implementation above uses view and transpose to handle the batching of heads. This is standard, but it can be tricky to get the dimensions just right. Use shape-debugging prints (print(x.shape)) liberally when you’re first building this. Running all heads in one batched tensor operation is typically much faster than looping over them in Python, so it’s worth the headache.
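The scaled_dot_product_attention above blocks positions where mask == 0. Under that convention, a causal mask is just the lower triangle of ones: query t may see keys 0 through t. The sketch below rebuilds that masked attention in a few self-contained lines and checks that no weight lands on a future token. Assuming PyTorch 2.0 or later, it also cross-checks the result against torch.nn.functional.scaled_dot_product_attention with is_causal=True.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_heads, seq_len, d_k = 2, 4, 6, 8
q = torch.randn(batch_size, num_heads, seq_len, d_k)
k = torch.randn(batch_size, num_heads, seq_len, d_k)
v = torch.randn(batch_size, num_heads, seq_len, d_k)

# 1 on and below the diagonal (allowed), 0 above it (future tokens).
# Shape (seq_len, seq_len) broadcasts over batch and heads.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

scores = q @ k.transpose(-2, -1) / d_k ** 0.5
# Most negative finite value of the dtype, rather than a hard-coded -1e9.
scores = scores.masked_fill(causal_mask == 0, torch.finfo(scores.dtype).min)
weights = torch.softmax(scores, dim=-1)
out = weights @ v

# Every weight strictly above the diagonal should be exactly zero.
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
assert weights.masked_select(future).max() == 0

# Cross-check against PyTorch's built-in kernel (PyTorch >= 2.0).
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out, reference, atol=1e-4)
print("causal mask OK")
```

Watch the boolean conventions when you switch to library code. In F.scaled_dot_product_attention, a True in attn_mask marks a position that may take part in attention. In nn.MultiheadAttention, a True marks a position that is not allowed to attend.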

The designers got this one right. Multi-head attention isn’t just an incremental improvement; it’s a fundamental architectural insight that makes the whole transformer tick. It’s a big part of why a model can work out that in the sentence “The animal didn’t cross the street because it was too tired,” “it” refers to “animal” and not “street,” while also keeping track of the negation and the causal clause. It’s not magic, but it’s the closest thing we’ve got in machine learning.