18.5 The Encoder Stack: Self-Attention + FFN + LayerNorm
Right, so you’ve got your input embeddings and you’ve added positional encoding. Now the real party starts: the Encoder Stack. This isn’t just one layer; it’s a series of identical layers stacked on top of each other. And each one is a beautifully engineered little machine with two main workhorses and one crucial piece of organizational glue: Self-Attention, a Feed-Forward Network (FFN), and Layer Normalization. Don’t let the simplicity fool you—this is where the magic of context gets woven into your data.
Let’s break down why this trio is so effective. The self-attention mechanism lets every token in your sequence look at every other token and figure out who it should be paying attention to. It’s the “read” operation. The FFN then takes that aggregated information and applies a non-linear transformation to it. It’s the “think” operation. And the LayerNorm? It’s the exasperated parent constantly telling the other two to calm down. It keeps activations in a reasonable, stable range so that signals and gradients don’t explode or vanish as they pass through the stack. They’re a team.
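Concretely, LayerNorm works on each token separately. For each token, it takes that token’s embed_size features and shifts them to zero mean. It then scales them to unit variance, using a small epsilon to avoid dividing by zero. Finally, it applies a learned per-feature scale and shift. Here is a minimal by-hand version, checked against PyTorch’s nn.LayerNorm. The toy sizes and the deliberately badly scaled input are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size = 8
x = torch.randn(2, 3, embed_size) * 5 + 10  # badly scaled activations: mean ~10, std ~5

ln = nn.LayerNorm(embed_size)  # learned scale (weight) starts at 1, shift (bias) at 0

# Normalize each token's feature vector independently (over the last dimension only)
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, as LayerNorm uses
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias

print(torch.allclose(manual, ln(x), atol=1e-5))  # True
print(manual.mean(dim=-1))  # ~0 for every token
```

Note that nothing is averaged across tokens or across the batch, only across each token’s own features. This is what makes LayerNorm indifferent to batch size and sequence length.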
The Self-Attention Mechanism (The Partygoer)
Imagine you’re at a crowded party. Your brain doesn’t just process each person in isolation; it instantly connects the guy talking about “gradient descent” to the woman rolling her eyes—ah, a fellow data scientist! That’s self-attention. It allows each word (or token) to go, “Hey, for the current task, which other words in this sentence are my best friends?”
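Technically, it does this with three learned linear projections that create three vectors for each token: a Query (what I’m looking for), a Key (what I contain), and a Value (what I actually offer). The attention score between token i and token j is the dot product of Query_i and Key_j. A high score means “these two are relevant to each other.” We divide every score by √d_k, the square root of the key dimension (the per-head dimension, in the multi-head case). This matters because dot products of long vectors grow large, and large inputs push softmax into a saturated, nearly one-hot regime where its gradients all but vanish. Softmax then turns each token’s row of scores into attention weights that add up to one. The output for each token is the weighted sum of all the Value vectors. In one line: Attention(Q, K, V) = softmax(QKᵀ / √d_k) V.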
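To make the Query/Key/Value recipe concrete, here is a minimal single-head sketch. The matrices W_q, W_k and W_v are random stand-ins for learned weights, and the sizes are arbitrary. What matters is the sequence of operations: project, score, scale by √d_k, softmax, and take the weighted sum of values.

```python
import torch

torch.manual_seed(0)

seq_len, d_model, d_k = 5, 16, 8  # arbitrary toy sizes

x = torch.randn(seq_len, d_model)  # one sequence of token embeddings

# Random stand-ins for the learned projection matrices
W_q = torch.randn(d_model, d_k)
W_k = torch.randn(d_model, d_k)
W_v = torch.randn(d_model, d_k)

Q = x @ W_q  # (seq_len, d_k): what each token is looking for
K = x @ W_k  # (seq_len, d_k): what each token contains
V = x @ W_v  # (seq_len, d_k): what each token offers

# scores[i, j] = Q_i . K_j, scaled by sqrt(d_k)
scores = Q @ K.T / d_k ** 0.5            # (seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
out = weights @ V                        # (seq_len, d_k): weighted sum of values

print(weights.sum(dim=-1))  # all ones, up to float rounding
print(out.shape)            # torch.Size([5, 8])
```

Row i of weights is token i’s attention distribution over the whole sequence. Row i of out is the mix of Value vectors that this distribution selects.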
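Here’s the code for multi-head self-attention. The embedding is split into heads smaller pieces, and each head runs the recipe above on its own slice. The mask argument matters most in the decoder, where it hides future tokens. In the encoder, it’s typically only used to ignore padding, so you can pass None for now.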
```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "Embedding size needs to be divisible by heads"

        # Learned projections producing the Value, Key and Query vectors.
        # Each maps the full embedding; the result is split into heads in forward().
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding into self.heads pieces of size head_dim
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # Dot product of every query with every key, per head
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape:    (N, key_len, heads, head_dim)
        # energy shape:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(d_k), where d_k is the per-head dimension (not embed_size)
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # Positions where mask == 0 get a huge negative score, so softmax gives them ~0 weight
            energy = energy.masked_fill(mask == 0, -1e20)

        attention = torch.softmax(energy, dim=-1)  # normalize over the key dimension

        # Weighted sum of the value vectors, then concatenate the heads again
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)
```
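In an encoder, the mask argument mostly matters for padding. When sequences in a batch have different lengths, the padded positions should receive no attention. Below is a self-contained sketch that builds such a mask from hypothetical sequence lengths and applies it with the same masked_fill trick used in forward. The shape (N, 1, 1, key_len) is chosen so that it broadcasts against the (N, heads, query_len, key_len) energy tensor. The random energy is only a stand-in for real scaled scores.

```python
import torch

N, heads, seq_len = 2, 4, 6
lengths = torch.tensor([6, 3])  # hypothetical: 2nd sequence has 3 real tokens + 3 padding

# mask[n, 0, 0, k] == 1 where key position k is a real token, 0 where it is padding
positions = torch.arange(seq_len)                   # (seq_len,)
mask = positions[None, :] < lengths[:, None]        # (N, seq_len), bool
mask = mask[:, None, None, :].int()                 # (N, 1, 1, seq_len)

energy = torch.randn(N, heads, seq_len, seq_len)    # stand-in for scaled Q.K scores
energy = energy.masked_fill(mask == 0, -1e20)       # padded keys get a huge negative score
attention = torch.softmax(energy, dim=-1)

print(attention[1, 0, 0])  # the last three weights are 0
```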
The key takeaway (no pun intended) is that this mechanism is blind to word order. Technically, it’s permutation-equivariant: shuffle the input tokens and you get exactly the same outputs, shuffled the same way. It’s just a fancy set of weighted sums. This is why the positional encoding we added earlier is so critical. Without it, the model would effectively see every sentence as a bag of words, and that absurdity is precisely what the sinusoidal fix exists to prevent.
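You can check the order-blindness claim directly. This sketch uses PyTorch’s built-in nn.MultiheadAttention rather than the SelfAttention class above, so that it runs on its own. No positional encoding is added, and the permutation is an arbitrary example. Reordering the input tokens simply reorders the outputs in the same way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
mha.eval()

x = torch.randn(1, 5, 16)              # (batch, seq_len, embed), no positional encoding
perm = torch.tensor([3, 0, 4, 1, 2])   # an arbitrary reordering of the 5 tokens

out, _ = mha(x, x, x)
out_perm, _ = mha(x[:, perm], x[:, perm], x[:, perm])

# Shuffling the inputs shuffles the outputs identically: same values, new order
print(torch.allclose(out_perm, out[:, perm], atol=1e-5))  # True
```

If you add a positional encoding to x before attention, this equality breaks, and that is exactly the point of adding one.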
The Feed-Forward Network (The Hermit)
Right after the party, the data goes home for some quiet time. The Feed-Forward Network is a simple, per-position multilayer perceptron. It’s applied to each token independently and identically. Wait, if it’s independent, what’s the point? Didn’t we just do all that work to mix information?
Yes, and that’s the genius. The self-attention gathered the context. The FFN now allows each token to process that gathered context in a high-dimensional, non-linear way. Think of it as each token going, “Okay, based on everyone I just talked to, what should I become now?” It’s a transformation function that operates on the already context-enriched representation.
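It’s almost always the same recipe. A linear layer expands the dimension (usually by a factor of 4), a non-linear activation follows, and a second linear layer projects back down to the original embed_size. The original paper used ReLU. BERT- and GPT-style models use GELU, a smoother variant that tends to work slightly better in practice. This expand-then-contract shape gives each position plenty of representational power. The inner layer is wider than the model, which makes it the opposite of a bottleneck.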
```python
class FeedForward(nn.Module):
    def __init__(self, embed_size, expansion_factor=4):
        super().__init__()
        hidden_size = embed_size * expansion_factor
        self.linear1 = nn.Linear(embed_size, hidden_size)  # expand
        self.linear2 = nn.Linear(hidden_size, embed_size)  # project back down
        self.gelu = nn.GELU()  # smoother alternative to ReLU, used in BERT/GPT-style models
        # Optional: you might see dropout here in practice

    def forward(self, x):
        # x: (N, seq_len, embed_size); the same weights are applied at every position
        out = self.linear1(x)
        out = self.gelu(out)
        out = self.linear2(out)
        return out
```
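The FFN acts on the last dimension only. Running it over a whole sequence therefore gives the same result as running it on each token separately. The sketch below checks this with an nn.Sequential equivalent of FeedForward, redefined here so the snippet runs on its own. It also counts the parameters for embed_size=512, the model width of the original base Transformer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size, expansion_factor = 512, 4
ffn = nn.Sequential(
    nn.Linear(embed_size, embed_size * expansion_factor),  # expand: 512 -> 2048
    nn.GELU(),
    nn.Linear(embed_size * expansion_factor, embed_size),  # project back: 2048 -> 512
)

x = torch.randn(2, 10, embed_size)  # (batch, seq_len, embed_size)

whole = ffn(x)  # all positions at once
per_token = torch.stack([ffn(x[:, t]) for t in range(x.shape[1])], dim=1)  # one position at a time

print(torch.allclose(whole, per_token, atol=1e-5))  # True
print(sum(p.numel() for p in ffn.parameters()))     # 2099712
```

That is roughly 2.1M parameters per layer, about twice the roughly 1.05M in the four 512×512 attention projections. A large share of a transformer’s capacity sits in these “hermit” layers.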
Layer Normalization (The Therapist)
Here’s where the original “Attention Is All You Need” paper made a, let’s say, questionable choice. It applied LayerNorm after each sub-layer (the attention or the FFN), once the sub-layer’s output had been added back to its input: x = LayerNorm(x + Sublayer(x)). This is called Post-LayerNorm (Post-LN), and it’s the classic setup. The problem? It makes training deep transformers notoriously unstable. At initialization, the gradients of the layers near the output can be large, so training usually needs a careful learning-rate warm-up to avoid blowing up early on.
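For contrast, here is a sketch of the Post-LN arrangement, where each residual sum is normalized after the addition. To keep it self-contained, it uses PyTorch’s nn.MultiheadAttention and a plain nn.Sequential FFN with ReLU, as in the original paper. It does not use the classes from this section. PostLNEncoderLayer is a hypothetical name. Compare it with the Pre-LN TransformerEncoderLayer later in this section.

```python
import torch
import torch.nn as nn


class PostLNEncoderLayer(nn.Module):
    """Post-LN layer in the original paper's style: normalize AFTER the residual add."""

    def __init__(self, embed_size, heads, expansion_factor=4, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_size, heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_size, embed_size * expansion_factor),
            nn.ReLU(),  # the original paper used ReLU
            nn.Linear(embed_size * expansion_factor, embed_size),
        )
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_out))     # add, THEN normalize
        x = self.norm2(x + self.dropout(self.ffn(x)))  # add, THEN normalize
        return x


layer = PostLNEncoderLayer(embed_size=32, heads=4)
y = layer(torch.randn(2, 7, 32))
print(y.shape)  # torch.Size([2, 7, 32])
```

In this layout, LayerNorm sits directly on the residual path. Every gradient flowing back through the stack therefore has to pass through a normalization at each layer.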
The community, in its wisdom, largely said “nah”. Most modern implementations (GPT-2 and its descendants, for example) use Pre-LayerNorm (Pre-LN) instead. You normalize the input before throwing it into the attention or FFN sub-layer. The sub-layer’s output is then added back to the original, un-normalized input via a residual connection: x = x + Sublayer(LayerNorm(x)). This is dramatically more stable. It’s like giving your brilliant but erratic friend (the attention mechanism) some Adderall before the exam. Because nothing sits directly on the residual path, gradients get a clean identity route through the whole deep stack.
The difference is no joke. Pre-LN often lets you train without any learning rate warm-up at all. It’s a clear case of the literature improving on the original design.
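For reference, here is the kind of warm-up that Post-LN training relies on. This sketch follows the schedule from the original paper. The learning rate grows linearly for warmup_steps steps, then decays with the inverse square root of the step number. The defaults d_model=512 and warmup_steps=4000 are that paper’s base settings. The function name transformer_lr is hypothetical, and in practice you would plug a function like this into your optimizer’s scheduler.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate at a given (1-based) step: linear warm-up, then inverse-sqrt decay."""
    step = max(step, 1)  # avoid step ** -0.5 blowing up at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for s in (1, 1000, 4000, 16000, 64000):
    print(s, round(transformer_lr(s), 6))
```

The two branches of min meet at step == warmup_steps, which is where the learning rate peaks.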
```python
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_size, heads, expansion_factor=4, dropout=0.1):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.ffn = FeedForward(embed_size, expansion_factor)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-LayerNorm (the modern, stable choice): normalize, run the sub-layer, add back
        normalized_input = self.norm1(x)
        attention_out = self.attention(normalized_input, normalized_input, normalized_input, mask)
        x = x + self.dropout(attention_out)  # residual connection

        normalized_x = self.norm2(x)
        ffn_out = self.ffn(normalized_x)
        x = x + self.dropout(ffn_out)  # another residual connection
        return x
```
See that? We normalize first, run the sub-layer, then add the original, un-normalized input (x) back. This pattern appears twice per layer (once for attention, once for the FFN) and is the absolute core of the transformer encoder. Stack 6 (as in the original Transformer) or 12 (as in BERT-base) of these TransformerEncoderLayer modules. The result is a powerhouse that can model intricate relationships in your data while keeping its gradients in check.
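If you’d rather not hand-roll the stack, PyTorch ships the same recipe. This sketch assumes a PyTorch version whose nn.TransformerEncoderLayer supports norm_first=True (Pre-LN) and activation="gelu". It stacks 6 layers and adds one final LayerNorm, as Pre-LN models such as GPT-2 do. In Pre-LN, the residual stream leaving the last layer has never been normalized. The toy width of 64 and the padding pattern are arbitrary.

```python
import torch
import torch.nn as nn

embed_size, heads, num_layers = 64, 4, 6  # toy width; 6 layers as in the original encoder

layer = nn.TransformerEncoderLayer(
    d_model=embed_size,
    nhead=heads,
    dim_feedforward=4 * embed_size,  # the usual 4x expansion
    dropout=0.1,
    activation="gelu",
    batch_first=True,   # inputs are (batch, seq_len, embed_size)
    norm_first=True,    # Pre-LN, like the hand-written layer in this section
)
encoder = nn.TransformerEncoder(
    layer,
    num_layers=num_layers,          # deep-copies `layer` num_layers times
    norm=nn.LayerNorm(embed_size),  # final LayerNorm after the last Pre-LN layer
    enable_nested_tensor=False,     # the nested-tensor fast path doesn't apply to norm_first
)

x = torch.randn(2, 10, embed_size)
padding_mask = torch.zeros(2, 10, dtype=torch.bool)
padding_mask[1, 7:] = True  # NOTE: here True means "ignore", the opposite of mask == 0 above

out = encoder(x, src_key_padding_mask=padding_mask)
print(out.shape)  # torch.Size([2, 10, 64])
```

Watch the mask convention when switching between the hand-written layer and the built-in one. PyTorch’s key-padding mask marks positions to ignore, whereas the SelfAttention code masks positions where mask == 0.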