Right, so you’ve made it past the encoder. Good. That was the warm-up. Now we get to the real party trick of the Transformer: the decoder. This is where the model actually becomes a generative model, where it takes all that juicy contextual understanding from the encoder and uses it to produce something new, one token at a time. It’s a beautiful, slightly unhinged process of creative constraint.

The decoder stack looks suspiciously like the encoder stack—it’s built from layers of self-attention and feed-forward networks—but it has two absolutely critical modifications that prevent it from cheating. And I mean really prevent it. Because if it could cheat, it would be useless.

The Straightjacket of Masked Self-Attention

Think about the task: translating “I am a cat” into French. When generating the output sequence “Je suis un chat”, the model must generate the tokens one-by-one. When it’s generating the word “suis”, it should only be allowed to look at the words that came before it (“Je”) and the entire input sentence (“I am a cat”). It should NOT be allowed to peek at the word it’s about to generate (“suis”) or the words that come after (“un chat”). That would be a spectacularly useless model.

This is where masked self-attention comes in. It’s the straightjacket we lovingly force upon the decoder during training. The mechanism is brilliantly simple: we compute the attention scores as usual, but then we apply a mask that sets the values for all future positions to negative infinity before the softmax step.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Suppose we have a sequence of length 4
seq_len = 4
d_k = 64  # dimension of key/query

# Random queries and keys (for illustration)
queries = torch.randn(seq_len, d_k)
keys = torch.randn(seq_len, d_k)

# Compute the attention scores the normal way
scores = torch.matmul(queries, keys.transpose(-2, -1)) / (d_k ** 0.5)
print("Raw Scores:\n", scores)

# Now, the magic mask: -inf strictly above the diagonal, 0 everywhere else.
mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
print("Mask:\n", mask)

# Apply the mask by adding it to the scores
masked_scores = scores + mask
print("Masked Scores:\n", masked_scores)

# Apply softmax. The -inf values become zero.
weights = F.softmax(masked_scores, dim=-1)
print("Final Attention Weights (notice the upper triangle is zeroed out):\n", weights)

Why negative infinity? Because softmax exponentiates every score before normalizing, and e^(-inf) is exactly 0, so a masked position contributes nothing to the sum and ends up with a weight of zero. This elegantly ensures that when calculating the output for position i, the attention weights for any position j > i are exactly zero. The model is forced to attend only to positions up to and including i, building its prediction for the next word based on what it’s already produced and the original input. It’s the core of its auto-regressive nature.
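If you want to convince yourself the mask really does block the future, here is a minimal sketch. It assumes a stripped-down, single-head attention with no learned projections (the input `x` plays the role of queries, keys, and values at once), and the helper name `causal_self_attention` is made up for this illustration. The idea: scramble the last token and check that the outputs for the earlier positions don't budge.

```python
import torch
import torch.nn.functional as F

def causal_self_attention(x):
    """Toy single-head self-attention with a causal mask; x is [seq_len, d]."""
    seq_len, d = x.shape
    # Simplifying assumption: no learned projections, x serves as Q, K and V.
    scores = x @ x.transpose(-2, -1) / d ** 0.5
    # -inf strictly above the diagonal hides every future position.
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    weights = F.softmax(scores + mask, dim=-1)
    return weights @ x, weights

torch.manual_seed(0)
x = torch.randn(4, 8)
out, weights = causal_self_attention(x)

# Replace the final token with something completely different.
x_changed = x.clone()
x_changed[3] = torch.randn(8)
out_changed, _ = causal_self_attention(x_changed)

# Positions 0..2 never saw token 3, so their outputs should be unchanged.
future_is_hidden = torch.allclose(out[:3], out_changed[:3])
# Every weight above the diagonal should be exactly zero.
upper_triangle_zero = bool((torch.triu(weights, diagonal=1) == 0).all())
print("Earlier outputs unaffected by the future token:", future_is_hidden)
print("Upper triangle of the weights is all zeros:", upper_triangle_zero)
```

Only the last row is allowed to react to the change, which is exactly the behavior the mask is supposed to enforce.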

Cross-Attention: The Decoder’s Hotline to the Encoder

Okay, so the decoder can attend to its own previous output. Big deal. That alone would just give us a fancy language model generating a sentence in a vacuum. The whole point is to condition that generation on the input sequence. This is where cross-attention acts as the direct hotline, the conference call between the decoder and the encoder.

In a decoder layer, after the masked self-attention sub-layer has done its job, the cross-attention sub-layer takes over. Here’s the slightly weird part that everyone gets wrong at first:

  • The Queries come from the decoder’s previous sub-layer (the output of the masked self-attention).
  • The Keys and Values come from the encoder’s final output.

Let me say that again. The decoder uses its own current state as the query, and goes looking for relevant information in the encoder’s memory. It’s asking: “Based on what I, the decoder, have produced so far (‘Je’), what part of the input sequence (‘I’, ‘am’, ‘a’, ‘cat’) is most relevant for me to look at next?”

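# Assume we have encoder output (the "memory"). For simplicity this toy example
# reuses seq_len and d_k from above, so d_model == d_k and both sequences have
# length 4; in practice the source and target lengths usually differ.
encoder_output = torch.randn(seq_len, d_k)  # [src_len, d_model]

# And the output from the decoder's masked self-attention layer
decoder_hidden = torch.randn(seq_len, d_k)  # [tgt_len, d_model]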

# Projections for Q, K, V
W_q = nn.Linear(d_k, d_k)  # for decoder_hidden (Queries)
W_k = nn.Linear(d_k, d_k)  # for encoder_output (Keys)
W_v = nn.Linear(d_k, d_k)  # for encoder_output (Values)

Q = W_q(decoder_hidden)  # Queries come from the decoder
K = W_k(encoder_output)  # Keys come from the encoder
V = W_v(encoder_output)  # Values come from the encoder

# Now compute cross-attention exactly like normal attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
cross_attn_output = torch.matmul(attn_weights, V)

# This cross_attn_output is then passed (after the usual residual connection
# and layer normalization) to the feed-forward network

This is the moment of alignment. It’s how the model learns to connect “chat” to “cat” and “suis” to “am”. The decoder isn’t just generating; it’s generating while constantly cross-referencing the original source material.
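One thing worth seeing with your own eyes is that the two sequences don't need to be the same length. Here is a small sketch using PyTorch's built-in `nn.MultiheadAttention`; the sizes (`d_model = 64`, 4 heads, 4 source tokens, 2 target tokens) are arbitrary choices for illustration, and the tensors are random stand-ins rather than real encoder or decoder states.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 64, 4
src_len, tgt_len = 4, 2  # e.g. "I am a cat" vs. "Je suis" generated so far

cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

# Random stand-ins with a batch dimension of 1.
encoder_output = torch.randn(1, src_len, d_model)  # the encoder "memory"
decoder_hidden = torch.randn(1, tgt_len, d_model)  # decoder state after masked self-attention

# Queries from the decoder; keys and values from the encoder.
attn_out, attn_weights = cross_attn(
    query=decoder_hidden, key=encoder_output, value=encoder_output
)

print(attn_out.shape)      # torch.Size([1, 2, 64]): one vector per target position
print(attn_weights.shape)  # torch.Size([1, 2, 4]): one row over source tokens per target position
```

Each row of `attn_weights` (averaged over heads by default) is a distribution over the source tokens. Notice there is no causal mask here: the decoder is free to look at the whole input sentence.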

The Subtle Pitfalls You Will Encounter

This architecture is genius, but it’s not without its quirks.

  1. The Teacher Forcing Trap: During training, we use the true previous tokens as input to the decoder (this is called teacher forcing). It’s stable and efficient. But during inference, the model uses its own previous, potentially incorrect, predictions. This mismatch between training and inference (often called exposure bias) can cause a model to go off the rails quickly if it wasn’t trained robustly. Scheduled sampling attacks the mismatch directly by occasionally feeding the model its own predictions during training, while beam search only softens the symptoms at decoding time by keeping several candidate sequences alive instead of committing to a single one. Neither removes the fundamental tension.

  2. The Information Bottleneck: The entire burden of source-target alignment is placed on cross-attention. For very long sequences, this can be a lot to ask. At every decoder step, it has to learn how to distribute its attention weights across all of the encoder hidden states. If the relevant information is diffuse or complex, the model can struggle. Note that later encoder-decoder models (like T5) keep this same basic structure, so the quality of the alignment still depends on how rich a representation the encoder builds.

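  3. The Curse of Auto-regressive Slowness: You can’t parallelize inference across output positions. You must generate one token at a time, waiting for step i to finish before you can compute step i+1, because step i+1 takes the token from step i as input. Training doesn’t have this problem: with teacher forcing and the causal mask, all positions are computed in a single pass. This is why these models can feel slow to generate from compared to, say, a BERT-style encoder that processes everything in one go.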
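To make the first and third pitfalls concrete, here is a rough sketch contrasting the two regimes on a single `nn.TransformerDecoderLayer`. Everything in it is an assumption for illustration: the sizes are tiny, the tensors are random stand-ins for embedded tokens, and `causal_mask` is a helper written for this example. It feeds the same known inputs both ways to show that one masked pass computes exactly what a step-by-step loop would. In real inference you can't take the shortcut, because the input at step i+1 is the model's own prediction from step i.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, src_len, tgt_len = 16, 5, 4

# dropout=0.0 and eval() make the layer deterministic.
layer = nn.TransformerDecoderLayer(
    d_model=d_model, nhead=2, dim_feedforward=32, dropout=0.0, batch_first=True
)
layer.eval()

def causal_mask(n):
    """Hypothetical helper: -inf strictly above the diagonal, 0 elsewhere."""
    return torch.triu(torch.full((n, n), float("-inf")), diagonal=1)

memory = torch.randn(1, src_len, d_model)  # stand-in for the encoder output
tgt = torch.randn(1, tgt_len, d_model)     # stand-in for embedded ground-truth tokens

with torch.no_grad():
    # Training-style (teacher forcing): every position in one parallel pass.
    parallel_out = layer(tgt, memory, tgt_mask=causal_mask(tgt_len))

    # Inference-style: grow the prefix one position at a time and keep
    # only the newest output from each step.
    steps = []
    for t in range(1, tgt_len + 1):
        prefix = tgt[:, :t]
        out = layer(prefix, memory, tgt_mask=causal_mask(t))
        steps.append(out[:, -1])
    sequential_out = torch.stack(steps, dim=1)

outputs_match = torch.allclose(parallel_out, sequential_out, atol=1e-5)
print("Parallel pass matches the step-by-step loop:", outputs_match)
```

The loop runs the layer `tgt_len` times where the parallel pass runs it once, which is the slowness in miniature. It also shows why teacher forcing is so cheap: when the inputs are known in advance, the mask lets you compute every step simultaneously.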

The designers knew these trade-offs. The masked self-attention is a necessary evil for auto-regressive generation, and the cross-attention mechanism is the most direct way to link the two sequences. It’s not the only way, but it’s a brilliantly effective one that, frankly, changed everything.