Right, so you want to understand Recurrent Neural Networks. Let’s start with the classic version, the one that’s conceptually simple but practically a bit of a diva: the Vanilla RNN. It’s called “vanilla” not because it’s plain, but because it’s the fundamental flavor that all the fancy ones (LSTM, GRU) are desperately trying to improve upon. Think of it as the Icarus of neural networks—beautiful in its ambition, but it has a nasty habit of flying too close to the sun and having its wings melt. We’ll get to that.

The core idea is brilliant and, in hindsight, obvious: give a neural network a memory. A standard feedforward network is stateless. You show it a picture, it gives you an answer. Show it the same picture again, it gives the same answer. It has no concept of what came before. This is useless for sequences—words in a sentence, stock prices over time, notes in a melody. The meaning of any piece of data is utterly dependent on its context.

The Vanilla RNN solves this by being stateful. It has an internal state (often called the hidden state, denoted h_t) that it updates over time. This state is a vector that acts as a sort of running summary of everything the network has seen so far. It’s the network’s short-term memory.

The Unrolled Computation Graph

This is the single most important concept to grasp. The “recurrent” part means the same little chunk of neural network gets applied over and over again, once for each step in the sequence. We “unroll” this process to visualize it.

Imagine you have a sequence of three inputs: x_0, x_1, x_2. The RNN doesn’t see them all at once. It processes them one by one, and its hidden state gets passed along like a baton in a relay race.

  1. At time step t=0: It takes the first input x_0 and an initial hidden state h_{-1} (usually initialized to a vector of zeros). It squishes them together through a function (a linear layer followed by a non-linearity like tanh) to produce a new hidden state, h_0. It might also produce an output y_0 based on h_0.

    • Computation: h_0 = tanh(W_{xh} * x_0 + W_{hh} * h_{-1} + b_h)
  2. At time step t=1: It takes the next input x_1 and the previous hidden state h_0. It does the exact same computation, but now the context from x_0 is baked into h_0, so the output is informed by both x_1 and x_0.

    • Computation: h_1 = tanh(W_{xh} * x_1 + W_{hh} * h_0 + b_h)
  3. At time step t=2: Same deal. It takes x_2 and h_1 (which contains information from x_0 and x_1) to produce h_2.

    • Computation: h_2 = tanh(W_{xh} * x_2 + W_{hh} * h_1 + b_h)

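Unrolling this means drawing the process out as a chain in which each cell takes the current input together with the previous hidden state: (x_0, h_{-1}) -> [RNN Cell] -> (h_0, y_0), then (x_1, h_0) -> [RNN Cell] -> (h_1, y_1), then (x_2, h_1) -> [RNN Cell] -> (h_2, y_2). It is the same cell, with the same weights, at every step.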

The magic, and the problem, lies in the W_{hh} * h_{prev} term. This is how the network maintains its state. The weight matrix W_hh determines how much of the past state to keep and how to mix it with the new input.
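Collapsing the three steps into one rule may help. This is only a restatement of the computations above for a general step t. The output projection W_{hy} and bias b_y are assumed names, since the text only says that an output y_t may be produced from h_t.

```latex
\begin{aligned}
h_t &= \tanh\left(W_{xh}\, x_t + W_{hh}\, h_{t-1} + b_h\right), \qquad t = 0, 1, 2, \dots \\
y_t &= W_{hy}\, h_t + b_y \qquad \text{(optional; } W_{hy},\, b_y \text{ are assumed names)}
\end{aligned}
```

Nothing on the right-hand side carries a t subscript except x_t and h_{t-1}: the weights are shared across every step.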

Here’s what this looks like in raw PyTorch, building it from scratch so you feel the recurrence. We’ll create an RNN cell and manually loop over a sequence.

import torch
import torch.nn as nn

# Hyperparameters
input_size = 10   # e.g., size of a word embedding
hidden_size = 20  # size of the hidden state vector
sequence_length = 5
batch_size = 1

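# Our weights. In practice, you'd use nn.Linear (or the built-in nn.RNNCell), but this is for clarity.
# Naming note: W_ih is the W_{xh} from the equations above. Inputs are stored as row
# vectors, so the code computes x @ W_ih and W_ih has shape (input_size, hidden_size).
W_ih = nn.Parameter(torch.randn(input_size, hidden_size))  # Input to hidden
W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size)) # Hidden to hidden
b_h = nn.Parameter(torch.randn(hidden_size))               # Hidden bias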

# Initialize hidden state (typically zeros)
h_prev = torch.zeros(batch_size, hidden_size)

# Create a dummy input sequence of shape (sequence_length, batch_size, input_size)
inputs = torch.randn(sequence_length, batch_size, input_size)

# Manual loop through the sequence (this is the "unrolling")
hidden_states = []
for i in range(sequence_length):
    # Get input at this time step
    x_t = inputs[i] # Shape: [batch_size, input_size]

    # The core RNN equation in row-vector form: h_t = tanh(x_t W_ih + h_prev W_hh + b_h)
    h_t = torch.tanh(x_t @ W_ih + h_prev @ W_hh + b_h)

    # Store the hidden state for this step
    hidden_states.append(h_t)

    # The hidden state becomes the previous state for the next iteration
    h_prev = h_t

# Now hidden_states is a list of tensors, each [batch_size, hidden_size]
print(f"Final hidden state shape: {hidden_states[-1].shape}")
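As a sanity check, the same recurrence can be compared against PyTorch's built-in nn.RNNCell. This is a minimal sketch, not part of the original walkthrough: it reuses the sizes from above and reads the cell's own parameters (weight_ih, weight_hh, bias_ih, bias_hh) so both paths share the same weights. Note that nn.RNNCell keeps two bias vectors and stores its weights as (hidden_size, input_size), hence the transposes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, sequence_length, batch_size = 10, 20, 5, 1

# Built-in cell; tanh is its default non-linearity.
cell = nn.RNNCell(input_size, hidden_size, nonlinearity="tanh")
inputs = torch.randn(sequence_length, batch_size, input_size)

h_builtin = torch.zeros(batch_size, hidden_size)
h_manual = torch.zeros(batch_size, hidden_size)

with torch.no_grad():
    for x_t in inputs:  # iterate over the time dimension
        # Path 1: let the library do the update.
        h_builtin = cell(x_t, h_builtin)
        # Path 2: the hand-written update, using the cell's own parameters.
        h_manual = torch.tanh(
            x_t @ cell.weight_ih.T + cell.bias_ih
            + h_manual @ cell.weight_hh.T + cell.bias_hh
        )

max_diff = (h_builtin - h_manual).abs().max().item()
print(f"Max abs difference between built-in and manual: {max_diff:.2e}")
```

If the two paths agree to within floating-point noise, the manual loop really is what the library cell computes, one step at a time.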

The Vanishing Gradient Problem (Where It All Goes Wrong)

Okay, you’ve seen the beautiful theory. Now let’s get to the ugly truth. Remember that W_hh matrix? The same matrix is applied at every single time step, so anything that travels across many steps gets multiplied by it again and again. To see why this is a disaster, let’s think about backpropagation.

When you calculate how the loss depends on something early in the sequence (say, the hidden state h_0, and through it the use of W_hh at t=0), you have to use the chain rule. h_0 only reaches the loss through h_1, which only reaches it through h_2, and so on, all the way to the end. This means you get a long product with one factor per time step, and each factor contains the derivative of the tanh function and, crucially, the W_hh matrix itself.
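Written out, the chain rule gives something like the following. This is a sketch that uses the column-vector form of the update, h_t = tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h), and assumes for simplicity that the loss depends only on the last hidden state. The symbols L (loss), T (index of the last step), and \sigma_{\max} (largest singular value) are introduced here for illustration, and h_t^2 is an element-wise square.

```latex
\begin{aligned}
\frac{\partial h_t}{\partial h_{t-1}} &= \operatorname{diag}\left(1 - h_t^{2}\right) W_{hh} \\
\frac{\partial L}{\partial h_0} &= \frac{\partial L}{\partial h_T}\,
  \frac{\partial h_T}{\partial h_{T-1}}\,
  \frac{\partial h_{T-1}}{\partial h_{T-2}} \cdots
  \frac{\partial h_1}{\partial h_0} \\
\left\lVert \frac{\partial h_T}{\partial h_0} \right\rVert_2 &\le
  \prod_{t=1}^{T} \left\lVert \operatorname{diag}\left(1 - h_t^{2}\right) \right\rVert_2
  \left\lVert W_{hh} \right\rVert_2
  \le \sigma_{\max}(W_{hh})^{T}
\end{aligned}
```

The last line is where the trouble shows: there are T factors, each containing W_{hh} and a tanh derivative that is at most 1, so the bound is a power of a single number (the exponent T there is a power, not a transpose).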

The derivative of tanh is never greater than 1.0 (it equals 1 only at zero and falls toward 0 as the unit saturates). If W_hh is also “small”, loosely meaning its eigenvalues have magnitude below 1 and more precisely that its largest singular value is below 1, this long chain of multiplications causes the gradient to shrink exponentially as it travels backwards through time. It vanishes, becoming effectively zero. The network stops learning from distant steps. Long-range dependencies (e.g., a verb at the end of a sentence agreeing with a subject at the start) become practically impossible to learn because the signal from the loss can’t travel back that far.
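To put numbers on that shrinkage, here is a small experiment. It is a sketch under stated assumptions, not a benchmark: W_hh is deliberately rescaled so its largest singular value is 0.5, the sequence length of 30 is arbitrary, and the "loss" is just the sum of the final hidden state. It then prints how much gradient reaches the hidden state at a few earlier steps.

```python
import torch

torch.manual_seed(0)
input_size, hidden_size, sequence_length = 10, 20, 30

W_ih = torch.randn(input_size, hidden_size) * 0.1
W_hh = torch.randn(hidden_size, hidden_size)
# Assumption for the demo: force the largest singular value of W_hh to 0.5.
W_hh = 0.5 * W_hh / torch.linalg.matrix_norm(W_hh, ord=2)
b_h = torch.zeros(hidden_size)

inputs = torch.randn(sequence_length, 1, input_size)
h = torch.zeros(1, hidden_size, requires_grad=True)

hidden_states = []
for x_t in inputs:
    h = torch.tanh(x_t @ W_ih + h @ W_hh + b_h)
    h.retain_grad()  # keep d(loss)/d(h_t) so we can inspect it after backward()
    hidden_states.append(h)

# Toy loss that depends only on the final hidden state.
loss = hidden_states[-1].sum()
loss.backward()

# Norm of the gradient that arrives at each time step.
grad_norms = [h_t.grad.norm().item() for h_t in hidden_states]
for t in (0, 10, 20, sequence_length - 1):
    print(f"t={t:2d}  ||dL/dh_t|| = {grad_norms[t]:.3e}")
```

With these settings the gradient norm at t=0 is bounded by roughly 0.5 raised to the 29th power (times a constant), so the earliest steps receive essentially no learning signal even though the sequence is only 30 steps long.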

Conversely, if W_hh has eigenvalues with magnitude greater than 1, the gradient can explode, becoming impossibly large and causing training to destabilize. While exploding gradients can be hacked away with gradient clipping, the vanishing gradient problem is a fundamental flaw in the architecture. It’s the reason vanilla RNNs are famously terrible at learning long-range dependencies. They have the memory of a goldfish.
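For reference, the gradient clipping hack is a one-liner in PyTorch. The sketch below uses torch.nn.utils.clip_grad_norm_ on a toy setup; the model sizes, the linear head, the random data, and the max_norm of 1.0 are all placeholder assumptions, not recommendations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=10, hidden_size=20)  # vanilla RNN, tanh by default
head = nn.Linear(20, 1)                      # hypothetical output layer
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

inputs = torch.randn(50, 4, 10)  # (seq_len, batch, input_size)
targets = torch.randn(4, 1)      # dummy regression targets

optimizer.zero_grad()
outputs, h_n = rnn(inputs)       # h_n: (num_layers, batch, hidden_size)
loss = nn.functional.mse_loss(head(h_n[-1]), targets)
loss.backward()

# Rescale all gradients together if their combined L2 norm exceeds max_norm.
# The call returns the norm measured *before* clipping.
max_norm = 1.0
total_norm_before = torch.nn.utils.clip_grad_norm_(params, max_norm)
total_norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in params))
print(f"grad norm before: {total_norm_before.item():.3f}, "
      f"after: {total_norm_after.item():.3f}")

optimizer.step()
```

Clipping only caps the size of the update; it does nothing for gradients that have already shrunk to zero, which is why it helps with explosion but not with vanishing.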

This inherent weakness is the entire reason the LSTM and GRU were invented. They are clever architectural workarounds designed specifically to create a “highway” that lets gradients flow through time without vanishing nearly as fast. The vanilla RNN is the problem; those other guys are the elegant, complex, and frankly brilliant solutions. But you have to understand the problem first to appreciate the solution.