17.5 Bidirectional RNNs
Right, so you’ve got vanilla RNNs, LSTMs, and GRUs under your belt. You understand that they process sequences step-by-step, like a person reading a sentence from left to right. This is great, until you realize a massive flaw: the word you’re trying to understand right now is often best explained by the words that come after it.
Think about it. In the sentence “The food was terrible and absolutely…”, you can probably guess the next word is something like “disgusting.” Your model, processing left-to-right, has all the context it needs. But what about the sentence “The reviews were terrible, but the food turned out to be fantastic”? The clause after “but” completely changes how “terrible” should be read. A standard RNN processing the sequence left-to-right has already committed to its representation of “terrible” by the time the “but” arrives, and it can’t go back and revise it. It’s like judging a joke halfway through the setup, before the punchline lands. This is where we stop being polite and start getting real: we go bidirectional.
The core idea is brilliantly simple, almost “why didn’t I think of that?” material. You train not one, but two separate RNN layers on your input sequence. The first layer, the good ol’ faithful, reads the sequence from start to end (forward). The second, the rebellious one, reads it from end to start (backward). At each time step, we combine (usually concatenate) the outputs from both the forward and backward layers. This gives the downstream layers a complete view of the entire sequence for every single element—the full context, both past and future.
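If you want to see the mechanics without any framework in the way, here is a minimal NumPy sketch. It assumes a plain tanh RNN cell standing in for an LSTM, and the helper names (simple_rnn, bidirectional_rnn, init_params) are made up for illustration; this is not how Keras implements it internally, just the same idea in miniature.

```python
import numpy as np

def simple_rnn(x, W_x, W_h, b):
    """Run a tanh RNN over x of shape (timesteps, input_dim); return every hidden state."""
    h = np.zeros(W_h.shape[0])
    states = []
    for x_t in x:
        h = np.tanh(x_t @ W_x + h @ W_h + b)
        states.append(h)
    return np.stack(states)  # (timesteps, hidden_dim)

def bidirectional_rnn(x, fwd_params, bwd_params):
    """Two independent RNNs, one per direction, concatenated per timestep."""
    h_fwd = simple_rnn(x, *fwd_params)
    # Feed the reversed sequence to the backward RNN, then flip its outputs
    # back so that index t in both arrays lines up with x[t].
    h_bwd = simple_rnn(x[::-1], *bwd_params)[::-1]
    return np.concatenate([h_fwd, h_bwd], axis=-1)  # (timesteps, 2 * hidden_dim)

def init_params(rng, input_dim, hidden_dim):
    # Hypothetical random initialisation, only for this demo.
    return (rng.normal(size=(input_dim, hidden_dim)) * 0.5,
            rng.normal(size=(hidden_dim, hidden_dim)) * 0.5,
            np.zeros(hidden_dim))

rng = np.random.default_rng(0)
timesteps, input_dim, hidden_dim = 6, 4, 3
x = rng.normal(size=(timesteps, input_dim))
fwd_params = init_params(rng, input_dim, hidden_dim)
bwd_params = init_params(rng, input_dim, hidden_dim)

out = bidirectional_rnn(x, fwd_params, bwd_params)
print(out.shape)  # (6, 6)

# Change only the LAST input, then compare the outputs at the FIRST timestep.
x_changed = x.copy()
x_changed[-1] += 1.0
out_changed = bidirectional_rnn(x_changed, fwd_params, bwd_params)
forward_same = np.array_equal(out[0, :hidden_dim], out_changed[0, :hidden_dim])
backward_same = np.array_equal(out[0, hidden_dim:], out_changed[0, hidden_dim:])
print(forward_same, backward_same)
```

The forward half at the first timestep has only seen the first input, so editing the last input leaves it alone. The backward half at that same position has already read the whole sequence from the far end, so it should move. That difference is the “future context” a one-directional RNN never gets.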
Why Concatenation is Your Default Choice
You’ll see merge_mode='concat' everywhere, and for good reason. It’s the most expressive option. By sticking the forward hidden state [h_f] and the backward hidden state [h_b] together into one big vector [h_f, h_b], you’re giving the next layer all the raw information to work with. It can learn to weight the forward and backward contexts differently for its specific task. Summing or averaging, while computationally cheaper, forces an immediate, fixed combination of the states, which can lose information. You didn’t go through the trouble of running two RNNs just to mush their outputs together prematurely. Concatenate first, ask questions later.
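To make the “concatenation loses nothing” argument concrete, here is a small NumPy sketch. The vectors are random stand-ins for one timestep’s hidden states, and the dictionary keys mirror the merge_mode strings Keras accepts ('concat', 'sum', 'ave', 'mul'); the stacked-identity weight matrix is just an illustrative choice, not something a trained layer is guaranteed to learn.

```python
import numpy as np

rng = np.random.default_rng(0)
units = 4
h_f = rng.normal(size=units)  # forward hidden state at one timestep
h_b = rng.normal(size=units)  # backward hidden state at the same timestep

merged = {
    "concat": np.concatenate([h_f, h_b]),  # keeps both states intact
    "sum": h_f + h_b,
    "ave": (h_f + h_b) / 2,
    "mul": h_f * h_b,
}
for mode, vec in merged.items():
    print(mode, vec.shape)

# A downstream linear layer fed the concatenated vector can reproduce the sum
# by using two stacked identity matrices as its weights...
W_sum = np.vstack([np.eye(units), np.eye(units)])  # (2 * units, units)
recovered_sum = merged["concat"] @ W_sum
print(np.allclose(recovered_sum, merged["sum"]))

# ...but it can also weight the two directions differently, which a fixed
# sum cannot express.
W_mostly_forward = np.vstack([0.9 * np.eye(units), 0.1 * np.eye(units)])
weighted = merged["concat"] @ W_mostly_forward
```

Going the other way is impossible: once you only have h_f + h_b, there is no way to tell how much came from each direction.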
Here’s how you wield this power in TensorFlow. It’s almost embarrassingly straightforward, which is a credit to the API designers (for once).
import tensorflow as tf

# Let's build a Bidirectional LSTM for sentiment analysis
model = tf.keras.Sequential([
    # Variable-length sequences of word IDs; declaring the input lets summary() report shapes
    tf.keras.Input(shape=(None,)),
    tf.keras.layers.Embedding(input_dim=10000, output_dim=128),  # Embed our words
    # The magic happens here. Wrap any RNN layer in Bidirectional.
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),
    # Note: The output is now 128-dim because 64 (forward) + 64 (backward) = 128.
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),  # Returns only the final output (64-dim)
    tf.keras.layers.Dense(1, activation='sigmoid')  # Classify positive/negative
])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
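Run that summary and look at the output shapes (Keras can only report them once the model knows its input shape, for example via a tf.keras.Input layer at the top of the stack). You’ll see the first BiLSTM layer outputs a tensor of (None, None, 128), which is (batch, timesteps, 64_fw + 64_bw). It’s the clearest proof of what’s happening under the hood.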
The I/O Shape Headache (And How to Avoid It)
This is the part where the brilliant idea meets the frustrating reality of tensor shapes. The most common pitfall is forgetting the return_sequences argument. A Bidirectional wrapper will faithfully pass through the argument you set on its inner RNN layer.
- If return_sequences=True, you get an output for every timestep. This is mandatory for any stacked RNN layers (so the next RNN has a sequence to process) and for tasks like named entity recognition where you need a prediction for every word.
- If return_sequences=False (the default), you get a single vector per sequence: the forward layer’s final output (after it has read the last timestep) concatenated with the backward layer’s final output (after it has read all the way back to the first timestep). This is what you want for sequence-level classification, like our sentiment model above.
Mess this up, and you’ll get a shape error so cryptic it might make you question your life choices. Always, always sketch out the shape of your data as it flows through the model.
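Sketching shapes doesn’t have to mean pen and paper. Below is a tiny pure-Python shape tracer; bidirectional_output_shape is a hypothetical helper written for this section (it is not part of Keras), and it assumes the merge modes 'concat', 'sum', 'ave' and 'mul' with the usual (batch, timesteps, features) layout.

```python
def bidirectional_output_shape(input_shape, units, return_sequences=False, merge_mode="concat"):
    """Shape produced by Bidirectional(RNN(units, return_sequences=...)) on a 3D input."""
    if len(input_shape) != 3:
        raise ValueError(
            f"RNN layers expect (batch, timesteps, features), got {input_shape}"
        )
    if merge_mode not in ("concat", "sum", "ave", "mul"):
        raise ValueError(f"unsupported merge_mode: {merge_mode!r}")
    batch, timesteps, _ = input_shape
    # Only concatenation doubles the feature dimension.
    features = 2 * units if merge_mode == "concat" else units
    if return_sequences:
        return (batch, timesteps, features)
    return (batch, features)

# Trace the sentiment model: the Embedding output is (batch, timesteps, 128).
embedded = (None, None, 128)
after_first = bidirectional_output_shape(embedded, units=64, return_sequences=True)
after_second = bidirectional_output_shape(after_first, units=32)
print(after_first)   # (None, None, 128)
print(after_second)  # (None, 64)

# The classic mistake: forgetting return_sequences=True before stacking.
stack_error = None
try:
    flat = bidirectional_output_shape(embedded, units=64)  # 2D: (None, 128)
    bidirectional_output_shape(flat, units=32)             # next RNN has no sequence
except ValueError as err:
    stack_error = str(err)
print(stack_error)
```

The failing case is the one to internalize: the first layer collapses the time axis, and the second layer is left with nothing to iterate over.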
When Not to Be Bidirectional
Hold on, don’t just make everything bidirectional. It’s not free. It literally doubles the number of parameters and compute requirements for the RNN layers. More importantly, it breaks causality. You cannot use bidirectional RNNs for real-time prediction or any kind of sequential generation (the one-token-at-a-time way ChatGPT produces text). Why? Because to predict the next word, the backward layer would need to have already read the words that come after it, which is a neat trick if you can manage it but generally frowned upon. Bidirectional models are for analysis of complete sequences: encoding the full source sentence in machine translation, text classification, speech recognition on a recorded clip. If you’re generating the sequence as you go, you’re stuck moving forward, like the rest of us.
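The “doubles the parameters” claim is easy to check by hand. The sketch below uses the standard LSTM parameter count (four gates, each with input weights, recurrent weights and a bias) and plugs in the layer sizes from the sentiment model above; lstm_param_count is a helper name invented for this example.

```python
def lstm_param_count(input_dim, units):
    # Four gates (input, forget, cell candidate, output), each with
    # input weights, recurrent weights and a bias vector.
    return 4 * (input_dim * units + units * units + units)

# First recurrent layer of the sentiment model: 128-dim embeddings into LSTM(64).
one_direction = lstm_param_count(input_dim=128, units=64)
both_directions = 2 * one_direction
print(one_direction)    # 49408
print(both_directions)  # 98816

# Concatenation also widens the input of whatever comes next:
# LSTM(32) fed 128 features (BiLSTM output) vs 64 (single LSTM output).
next_after_bidirectional = lstm_param_count(input_dim=128, units=32)
next_after_single = lstm_param_count(input_dim=64, units=32)
print(next_after_bidirectional)  # 20608
print(next_after_single)         # 12416
```

So the wrapper itself costs exactly twice the inner layer, and with concatenation the following layer pays a little extra too, because it now has twice as many input features to weight.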