77.4 Subplots: plt.subplots() and GridSpec
Alright, let’s talk about subplots. This is where you stop making polite, single-graph conversations with Matplotlib and start building the dashboard of your dreams (or, more commonly, the multi-panel figure your reviewer #2 demanded). The core idea is simple: you want to carve up your figure canvas (fig) into a grid and populate each cell with an axes object. The trick is doing it without pulling your hair out.
The Workhorse: plt.subplots()
For 90% of what you’ll do, plt.subplots() is your best friend. It’s a one-stop shop that returns a figure and, for any grid bigger than 1x1, a NumPy array of axes objects in one go. The beauty is in its simplicity.
import matplotlib.pyplot as plt
import numpy as np
# Let's create some nonsense data to plot
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.tan(x) / 10 # Scaled down, but tan(x) is still a drama queen that shoots off to infinity near its asymptotes; set_ylim below does the real taming
# Create a 2x2 grid of subplots
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 8)) # figsize is (width, height) in inches
# Now, axs is a 2D array. Index it like you would any NumPy array.
axs[0, 0].plot(x, y1, color='teal')
axs[0, 0].set_title('The Trusty Sine Wave')
axs[0, 1].plot(x, y2, color='crimson')
axs[0, 1].set_title('The Reliable Cosine Wave')
axs[1, 0].plot(x, y3, color='purple')
axs[1, 0].set_title('The Unstable Tangent (Scaled!)')
axs[1, 0].set_ylim(-2, 2) # Crucial for taming tantrum-prone functions
# What about that pesky empty subplot in the bottom right? We can turn it off.
axs[1, 1].set_axis_off() # Handy for when your grid doesn't divide neatly into your content.
# A super common gotcha: overlapping labels. Let's fix it.
fig.tight_layout() # This automatically adjusts spacing between subplots. Use it. Love it.
plt.show()
Why is axs a 2D array? Because you told it to be with nrows=2, ncols=2. If you ask for a single row (nrows=1, ncols=4), axs becomes a 1D array. This is intuitive until you forget and try to index axs[1, 0] on a 1D array, at which point Python will throw a very justified error at you. Always check the shape of your axs array if you’re unsure (print(axs.shape)).
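Here is a quick sketch that checks the shapes described above. It uses the non-interactive Agg backend (an assumption, so it runs without a display). It also shows axs.flat, which visits every Axes regardless of whether the array is 1D or 2D.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt

# 2x2 grid -> 2D array of Axes, indexed as axs[row, col]
fig_grid, axs_grid = plt.subplots(nrows=2, ncols=2)
print(axs_grid.shape)

# A single row -> 1D array, indexed as axs[col]
fig_row, axs_row = plt.subplots(nrows=1, ncols=4)
print(axs_row.shape)

# axs.flat walks every Axes in row-major order, whatever the shape,
# so the same loop works for both layouts.
for i, ax in enumerate(axs_grid.flat):
    ax.set_title(f"panel {i}")

# 2D indexing on a 1D array is exactly the mistake described above
try:
    axs_row[1, 0]
except IndexError as err:
    print("IndexError:", err)

plt.close(fig_grid)
plt.close(fig_row)
```

If you don't care about grid position and just want to fill every panel, looping over axs.flat sidesteps the 1D-versus-2D question entirely.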
When subplots() Isn’t Enough: Enter GridSpec
plt.subplots() is brilliant for simple, uniform grids. But what if you need one plot to be twice the width of the others? Or one plot that stretches across an entire row beneath the rest? This is where matplotlib.gridspec.GridSpec flexes its muscles. It separates the grid specification from the creation of the axes, giving you surgical control.
Think of GridSpec as the blueprint for your figure. You define the overall grid, then tell axes exactly which cells to occupy.
from matplotlib.gridspec import GridSpec
# Reuses plt, np, x, y1, y2 and y3 from the previous example
fig = plt.figure(figsize=(12, 6))
# Define a 2x3 grid: column 0 is twice as wide, row 1 is three times as tall
gs = GridSpec(nrows=2, ncols=3, figure=fig, width_ratios=[2, 1, 1], height_ratios=[1, 3])
# Now create axes by slicing the grid blueprint.
# This ax occupies a single cell: row 0, column 0 (the double-width column).
ax_main = fig.add_subplot(gs[0, 0])
ax_main.plot(x, y1, lw=3)
ax_main.set_title('Main Event (Width Ratio = 2)')
# This takes the first row, second column (normal width).
ax_small1 = fig.add_subplot(gs[0, 1])
ax_small1.plot(x, y2, color='crimson')
# This takes the first row, third column.
ax_small2 = fig.add_subplot(gs[0, 2])
ax_small2.plot(x, y3, color='purple')
ax_small2.set_ylim(-2, 2)
# This ax spans the entire second row and all three columns.
ax_bottom = fig.add_subplot(gs[1, :]) # The colon means "all columns"
ax_bottom.scatter(x, y1 + np.random.randn(100)/10, alpha=0.5, s=10) # Add some noise for a scatter
ax_bottom.set_title('A Spanned Subplot')
fig.tight_layout() # Still your friend, even with complex grids.
plt.show()
The key power here is width_ratios and height_ratios. You’re no longer locked into a prison of equally-sized boxes. You can make a subplot span multiple rows/columns with slice notation (gs[1, :] for a whole row, gs[:, 2] for a whole column, gs[0:2, 1:3] for a 2x2 block).
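To see the slice notation in isolation, here is a sketch on a 3x3 grid. The layout is chosen purely for illustration and is not taken from the example above. Each Axes reports the rows and columns it occupies via get_subplotspec(), and the code also checks that width_ratios really makes column 0 twice as wide. It assumes Matplotlib 3.2 or later, where SubplotSpec exposes rowspan and colspan, and it uses the headless Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(9, 6))
gs = GridSpec(nrows=3, ncols=3, figure=fig, width_ratios=[2, 1, 1])

ax_block = fig.add_subplot(gs[0:2, 0:2])  # 2x2 block: rows 0-1, columns 0-1
ax_column = fig.add_subplot(gs[:, 2])     # the whole right-hand column
ax_strip = fig.add_subplot(gs[2, 0:2])    # bottom row, left two columns

for ax, name in [(ax_block, "gs[0:2, 0:2]"),
                 (ax_column, "gs[:, 2]"),
                 (ax_strip, "gs[2, 0:2]")]:
    ax.set_title(name)
    spec = ax.get_subplotspec()  # the slice of the grid this Axes occupies
    print(name, "rows", list(spec.rowspan), "cols", list(spec.colspan))

# width_ratios=[2, 1, 1]: the cell in column 0 should be twice as wide as column 1
w0 = gs[0, 0].get_position(fig).width
w1 = gs[0, 1].get_position(fig).width
print(f"column 0 / column 1 width: {w0 / w1:.2f}")

plt.close(fig)
```

Slices use ordinary Python semantics, so gs[0:2, ...] covers rows 0 and 1 but not row 2. Keeping that in mind is the easiest way to avoid accidentally overlapping two Axes.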
Best Practices and Pitfalls
Embrace tight_layout() (or constrained_layout=True): I’ve mentioned it twice because it’s that important. Without it, your titles, axis labels, and tick labels will happily overlap between subplots. It’s Matplotlib’s way of saying, “I’ve drawn everything exactly where you told me to, even if it’s stupid.” Use fig.tight_layout() or plt.subplots(constrained_layout=True) to automate the sanity-preserving part of layout management. They are alternative layout engines, so pick one per figure rather than stacking them.
Label Your Axes, You Maniac: In a multi-panel figure, never assume the reader can magically infer which y-axis belongs to which plot. Use ax.set_ylabel() on every relevant subplot. Your readers (and your future self) will thank you.
Share Your Axes (When It Makes Sense): Plotting the same x data? Use plt.subplots(sharex=True). This links the x-axes, so zooming in one zooms the others. It also automatically hides redundant tick labels, making things cleaner. The same exists for sharey.
The squeeze Parameter: This is a weird one. By default, plt.subplots(squeeze=True) “squeezes” away extra dimensions. A 1x1 grid comes back as a single Axes object rather than an array, and a single row or column comes back as a 1D array. If you’re writing a function that expects to always get an array back, set squeeze=False to be safe; axs is then always 2D, even for a 1x1 grid. It’s a classic “it works in testing but breaks in production” edge case.
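The last few tips combine naturally in reusable plotting code. The sketch below defines plot_panels, a hypothetical helper (not a Matplotlib function) that assumes its input is a dict mapping a label to an (x, y) pair. Because of squeeze=False, it can always index axs[row, 0], whether it receives one series or ten. It also uses sharex=True so the stacked panels zoom together, and it labels every y-axis. The Agg backend is an assumption so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

def plot_panels(series):
    """Hypothetical helper: one stacked panel per (x, y) series, sharing the x-axis."""
    n = len(series)
    # squeeze=False keeps axs 2D with shape (n, 1), even when n == 1
    fig, axs = plt.subplots(nrows=n, ncols=1, sharex=True, squeeze=False,
                            constrained_layout=True)
    for row, (label, (xs, ys)) in enumerate(series.items()):
        ax = axs[row, 0]
        ax.plot(xs, ys)
        ax.set_ylabel(label)  # every panel gets its own y label
    # with sharex=True, only the bottom panel keeps its x tick labels
    axs[-1, 0].set_xlabel("x")
    return fig, axs

x = np.linspace(0, 10, 100)
fig_one, axs_one = plot_panels({"sin": (x, np.sin(x))})
fig_three, axs_three = plot_panels({
    "sin": (x, np.sin(x)),
    "cos": (x, np.cos(x)),
    "sin*cos": (x, np.sin(x) * np.cos(x)),
})

# Shared x-axes: changing the limits on one panel moves the others too
axs_three[0, 0].set_xlim(2, 4)
print(axs_three[2, 0].get_xlim())

plt.close("all")
```

Try dropping squeeze=False from the plt.subplots call. The one-series case then gets back a bare Axes, and axs[row, 0] fails with a TypeError, which is exactly the edge case described above.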