17.3 LSTM: Forget Gate, Input Gate, Output Gate, and Cell State
Right, so you’ve hit the wall with the basic RNN. You’ve watched it valiantly try to remember what happened more than three steps ago in a sequence, only to see its memory either vanish into nothingness or explode into a chaotic mess of NaNs. This is the infamous vanishing/exploding gradient problem, and it’s why simple RNNs are, frankly, a poor fit for any real-world task that depends on long-range context.
The Long Short-Term Memory network, or LSTM, is the brilliant, slightly over-engineered solution to this problem. It’s an RNN with a more complex internal cell structure. Instead of just a simple tanh layer, it has a carefully regulated memory system, complete with gates. Think of it less like a neuron and more like a tiny, efficient bureaucracy inside each cell, with forms to fill out in triplicate for any memory operation. It’s convoluted, but it works.
The key to the LSTM’s success is its cell state, denoted as $C_t$. This is the network’s long-term memory highway. Information can flow down this highway relatively unchanged if we let it. The entire gate mechanism exists to carefully add, remove, or modulate the information traveling on this road. The gates themselves are just layers that output numbers between 0 and 1 (thanks to a sigmoid activation), controlling how much of something should pass through.
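To make the “valve” idea concrete, here’s a minimal sketch of a sigmoid gate acting as an element-wise soft mask. The pre-activation values are made up for illustration. In a real LSTM they come out of a learned linear layer.

```python
import torch

# Made-up pre-activation values for a 4-dimensional gate.
logits = torch.tensor([-6.0, 0.0, 2.0, 6.0])
gate = torch.sigmoid(logits)  # every entry lands strictly between 0 and 1

signal = torch.full((4,), 10.0)  # whatever the gate is guarding
gated = gate * signal            # element-wise: each dimension has its own valve

print(gate)   # approximately [0.0025, 0.5000, 0.8808, 0.9975]
print(gated)  # approximately [0.0247, 5.0000, 8.8080, 9.9753]
```

A gate never makes a hard yes/no decision. It scales each dimension independently, which is what keeps the whole mechanism differentiable.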
The Forget Gate: “Do I need to remember this?”
The first thing our LSTM cell does is decide what to throw away from its long-term memory. It looks at the new input $x_t$ and the previous hidden state $h_{t-1}$ (the short-term memory) and makes a decision.
$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$
This forget vector $f_t$ is multiplied element-wise with the previous cell state $C_{t-1}$. A value of 1 means “keep everything!” A value of 0 means “forget it completely.” Usually, it’s somewhere in between. This is how the network learns to reset its memory when a new sentence starts or a new time series pattern begins.
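Here’s a minimal sketch of the forget gate on its own, assuming PyTorch and arbitrary sizes. The name forget_layer is hypothetical. It plays the role of $W_f$ and $b_f$ acting on the concatenation $[h_{t-1}, x_t]$. Zeroing the weights and forcing the bias to ±20 is purely illustrative, to show the two extremes described above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, input_size = 4, 3

# Hypothetical layer standing in for W_f and b_f; it acts on [h_{t-1}, x_t].
forget_layer = nn.Linear(hidden_size + input_size, hidden_size)

h_prev = torch.randn(1, hidden_size)  # previous hidden state h_{t-1}
x_t = torch.randn(1, input_size)      # current input x_t
c_prev = torch.randn(1, hidden_size)  # previous cell state C_{t-1}
hx = torch.cat((h_prev, x_t), dim=1)  # [h_{t-1}, x_t]

f_t = torch.sigmoid(forget_layer(hx))  # one value in (0, 1) per memory slot
kept = f_t * c_prev                    # what survives of C_{t-1}

# Illustrative extremes: ignore the inputs and force the bias very high or very low.
with torch.no_grad():
    forget_layer.weight.zero_()
    forget_layer.bias.fill_(20.0)
    f_keep = torch.sigmoid(forget_layer(hx))  # ~1: "keep everything!"
    forget_layer.bias.fill_(-20.0)
    f_wipe = torch.sigmoid(forget_layer(hx))  # ~0: "forget it completely"

print(f_t)
print(kept)
print(f_keep * c_prev)  # essentially c_prev
print(f_wipe * c_prev)  # essentially zeros
```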
The Input Gate: “What new information should we store?”
Simultaneously, the cell decides what new information is going to be stored in the long-term memory. This is a two-part process.
First, the input gate layer ($i_t$) decides which values we’ll update (again, a vector of numbers between 0 and 1).
$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$
Second, a tanh layer creates a vector of new candidate values, $\tilde{C}_t$, that could be added to the state.
$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$
We then combine these two. The cell state isn’t just overwritten; it’s updated. We multiply the candidate values by the input gate’s decisions, so we only add the parts we’ve deemed important.
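A similar sketch for the input side, again with hypothetical layer names (input_layer for $W_i, b_i$ and candidate_layer for $W_C, b_C$). The sigmoid decides how much to write, the tanh decides what to write, and their product is the increment that gets added to the cell state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, input_size = 4, 3

# Hypothetical layers standing in for (W_i, b_i) and (W_C, b_C).
input_layer = nn.Linear(hidden_size + input_size, hidden_size)
candidate_layer = nn.Linear(hidden_size + input_size, hidden_size)

h_prev = torch.randn(1, hidden_size)
x_t = torch.randn(1, input_size)
hx = torch.cat((h_prev, x_t), dim=1)  # [h_{t-1}, x_t]

i_t = torch.sigmoid(input_layer(hx))       # how much to write, in (0, 1)
c_tilde = torch.tanh(candidate_layer(hx))  # what to write, in (-1, 1)
write = i_t * c_tilde                      # the term i_t * C~_t added to the cell state

print(i_t)
print(c_tilde)
print(write)
```

Using tanh for the candidate lets an update push a memory slot either up or down. The input gate can only shrink that push toward zero, never flip or amplify it.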
Updating the Cell State
Now we have all the pieces to update our long-term memory highway, the cell state.
$$C_t = f_t * C_{t-1} + i_t * \tilde{C}_t$$
Look at that. It’s beautifully direct. The old state is scaled by the forget gate (what we decided to keep), and then we add the new candidate values, scaled by the input gate (what we decided to update). This additive update is the secret sauce that mitigates the vanishing gradient problem. Along this direct path, $\partial C_t / \partial C_{t-1}$ is simply $f_t$ (element-wise), with no weight matrix and no squashing nonlinearity in the way, so when the forget gate sits near 1 the gradient flows backwards almost unchanged. (There are also indirect paths through the gates, which depend on $h_{t-1}$, but the direct path is what keeps long-range gradients alive.) The network can learn to hold the forget gate near 1 for the memory dimensions that need to persist, defaulting to “remember,” and to pull it toward 0 only when forgetting is actually useful.
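You can check the gradient claim with autograd. This sketch holds the gate values fixed to isolate the direct path from $C_{t-1}$ to $C_t$. In a full LSTM the gates also depend on $h_{t-1} = o_{t-1} * \tanh(C_{t-1})$, which adds indirect paths. It confirms that the Jacobian is $\mathrm{diag}(f_t)$, then chains 100 updates with a constant forget value to show how the product of forget gates behaves. The helper grad_through_steps and its constant write term are hypothetical, for illustration only.

```python
import torch

torch.manual_seed(0)
hidden_size = 4

# Gate values for a single step, held fixed to isolate the direct C_{t-1} -> C_t path.
f_t = torch.sigmoid(torch.randn(hidden_size))
i_t = torch.sigmoid(torch.randn(hidden_size))
c_tilde = torch.tanh(torch.randn(hidden_size))
c_prev = torch.randn(hidden_size)

def cell_update(c):
    # C_t = f_t * C_{t-1} + i_t * C~_t
    return f_t * c + i_t * c_tilde

jac = torch.autograd.functional.jacobian(cell_update, c_prev)
print(torch.allclose(jac, torch.diag(f_t)))  # True: dC_t/dC_{t-1} = diag(f_t)

def grad_through_steps(f_value, steps=100):
    # Hypothetical helper: run `steps` cell updates with a constant forget value
    # and return dC_T/dC_0, which should equal f_value ** steps.
    c0 = torch.tensor(1.0, requires_grad=True)
    c = c0
    for _ in range(steps):
        c = f_value * c + 0.1  # the write term does not depend on c
    c.backward()
    return c0.grad.item()

print(grad_through_steps(0.99))  # about 0.366
print(grad_through_steps(0.5))   # about 7.9e-31
```

With the forget gate at 0.99 the gradient after 100 steps is still about 0.37. At 0.5 it collapses to roughly $8 \times 10^{-31}$, which is the vanishing gradient problem all over again.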
The Output Gate: “What are we going to output?”
Finally, we need to decide what the hidden state ($h_t$) will be. Remember, the hidden state is the “short-term memory” that gets passed to the next cell and is often used as the output for this timestep. We don’t want to output the entire raw cell state; it’s our internal memory. We want to output a filtered version of it.
First, we run an output gate, which looks at the input and previous hidden state to decide what parts of the cell state we’re going to output.
$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$
Then, we push the cell state $C_t$ through a tanh (to squash the values between -1 and 1) and multiply it by the output gate’s decision.
$$h_t = o_t * \tanh(C_t)$$
This $h_t$ becomes our output for this step and, together with $C_t$, is passed on to the next time step.
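A quick sketch of the output side, with made-up values. Because the cell state is a running sum, it can drift well outside $[-1, 1]$ (the scaling by 5 below mimics a memory that has accumulated many writes), but the tanh followed by the output gate keeps every element of $h_t$ strictly inside $(-1, 1)$.

```python
import torch

torch.manual_seed(0)
hidden_size = 4

# Made-up values: scale the cell state up to mimic a memory that has accumulated many writes.
c_t = 5.0 * torch.randn(1, hidden_size)
o_t = torch.sigmoid(torch.randn(1, hidden_size))  # output gate, in (0, 1)

h_t = o_t * torch.tanh(c_t)  # squash with tanh, then filter with the gate

print(c_t)
print(h_t)
```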
Here’s what this looks like in code, because seeing the bureaucracy in action is half the battle.
```python
import torch
import torch.nn as nn

# Let's define a single LSTM cell from scratch to see the gears turn.
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # One big linear layer for all gates. We'll split the result later.
        # This computes W_figo * [x, h] + b_figo for all four gates at once.
        self.linear = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h_prev, c_prev = state
        # Combine input and previous hidden state (the equations write [h, x];
        # the order doesn't matter as long as it stays consistent)
        combined = torch.cat((x, h_prev), dim=1)
        # Compute all gate pre-activations at once
        gate_activations = self.linear(combined)
        # Split the result into the four chunks: forget, input, candidate, output
        f, i, g, o = torch.chunk(gate_activations, 4, dim=1)
        # Apply activations
        forget_gate = torch.sigmoid(f)
        input_gate = torch.sigmoid(i)
        candidate_cell_state = torch.tanh(g)
        output_gate = torch.sigmoid(o)
        # Update cell state: forget old stuff, add new candidate stuff
        c_t = (forget_gate * c_prev) + (input_gate * candidate_cell_state)
        # Produce new hidden state from the updated, filtered cell state
        h_t = output_gate * torch.tanh(c_t)
        return (h_t, c_t)

# Example usage:
batch_size = 2
input_size = 3
hidden_size = 5
x_t = torch.randn(batch_size, input_size)  # Current input
h_prev = torch.randn(batch_size, hidden_size)  # Previous hidden state
c_prev = torch.randn(batch_size, hidden_size)  # Previous cell state
cell = LSTMCell(input_size, hidden_size)
h_t, c_t = cell(x_t, (h_prev, c_prev))
print(f"New hidden state shape: {h_t.shape}")
print(f"New cell state shape: {c_t.shape}")
```
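If you want to sanity-check these equations against PyTorch itself, the sketch below recomputes one step by hand from the parameters of the built-in nn.LSTMCell and compares the results. One detail worth knowing when debugging: PyTorch stacks its gate weights in the order input, forget, candidate, output, not the forget-first order used in the class above. The manual_step function is a hypothetical helper written only for this check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, input_size, hidden_size = 2, 3, 5

ref = nn.LSTMCell(input_size, hidden_size)  # PyTorch's built-in cell

def manual_step(x, h_prev, c_prev, cell):
    # Hypothetical helper: one LSTM step computed from nn.LSTMCell's own parameters.
    # PyTorch's gate order is (input, forget, candidate, output).
    gates = (x @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)
    i, f, g, o = torch.chunk(gates, 4, dim=1)
    c_t = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
    h_t = torch.sigmoid(o) * torch.tanh(c_t)
    return h_t, c_t

x_t = torch.randn(batch_size, input_size)
h_prev = torch.randn(batch_size, hidden_size)
c_prev = torch.randn(batch_size, hidden_size)

h_ref, c_ref = ref(x_t, (h_prev, c_prev))
h_man, c_man = manual_step(x_t, h_prev, c_prev, ref)
print(torch.allclose(h_ref, h_man, atol=1e-5), torch.allclose(c_ref, c_man, atol=1e-5))
```

The same ordering applies to nn.LSTM’s weight_ih_l0 and weight_hh_l0, so the forget gate’s parameters are the second block of hidden_size rows. That’s the slice to look at when you go inspecting forget gate behavior.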
In practice, you’ll almost always use the highly optimized nn.LSTM layer, which runs this cell over every timestep of the sequence for you (and can stack several LSTM layers via num_layers). But understanding this cell-level operation is non-negotiable for debugging why your model might be failing to learn. The most common pitfalls are initializing the hidden state incorrectly and forgetting to detach it between batches when doing truncated backpropagation through time (BPTT). Also, remember that the cell state is where the true long-term memory lives: if your model is struggling to remember, look at the forget gate activations. If they hover near 0, the cell is wiping its memory at almost every step.
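Here’s a minimal sketch of truncated BPTT with nn.LSTM, assuming one long toy sequence split into chunks of 10 steps. The data, the linear head, and the chunk length are all hypothetical. The important line is the last one: the hidden and cell states are carried into the next chunk, but detached so that backpropagation stops at the chunk boundary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, chunk_len = 3, 8, 10

lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
head = nn.Linear(hidden_size, 1)  # hypothetical regression head
optimizer = torch.optim.SGD(list(lstm.parameters()) + list(head.parameters()), lr=0.01)

# Hypothetical data: one long sequence (batch of 1) with a toy regression target.
seq = torch.randn(1, 50, input_size)
target = torch.randn(1, 50, 1)

state = None  # nn.LSTM treats None as zero-initialized (h_0, c_0)
losses = []
for start in range(0, seq.size(1), chunk_len):
    x_chunk = seq[:, start:start + chunk_len]
    y_chunk = target[:, start:start + chunk_len]

    out, (h, c) = lstm(x_chunk, state)
    loss = nn.functional.mse_loss(head(out), y_chunk)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Carry the memory forward but cut the autograd graph, so the next
    # backward() stops at this chunk boundary (truncated BPTT).
    state = (h.detach(), c.detach())
```

Drop the .detach() calls and the second backward() will typically fail with a RuntimeError, because it tries to backpropagate into the previous chunk’s graph, which has already been freed.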