80.5 Custom Modules with nn.Module
Right, so you’ve graduated from nn.Sequential and are ready to build something that doesn’t look like a straight line. Welcome to nn.Module, your new best friend and the absolute bedrock of any non-trivial model in PyTorch. Think of it as your own personal LEGO box. nn.Sequential gives you pre-built, boring little cars. nn.Module gives you the bricks, the weird angled pieces, and even that one-piece cockpit window you can never find. It’s how you build the Millennium Falcon instead of a go-kart.
The core idea is laughably simple: you create a class that inherits from nn.Module, you define your building blocks in its __init__ method, and you spell out how data flows through them in the forward method. PyTorch’s genius is that it handles the backward pass for you automatically via autograd, so you only have to define the forward pass. It’s like only having to build the domino chain; PyTorch knocks them over and records exactly how each one fell.
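To make the "you only write forward" point concrete, here's a minimal sketch. TinyRegressor is a hypothetical one-layer module made up for illustration. The point is that calling loss.backward() fills in .grad for every registered parameter, even though we never wrote a single line of backward code.

```python
import torch
import torch.nn.functional as F
from torch import nn

class TinyRegressor(nn.Module):
    """Hypothetical minimal module: a single linear layer."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        # Only the forward pass is written by hand
        return self.fc(x)

model = TinyRegressor()
x = torch.randn(8, 3)
target = torch.randn(8, 1)
loss = F.mse_loss(model(x), target)

print(model.fc.weight.grad)  # None: no backward pass has run yet
loss.backward()              # autograd walks the recorded graph in reverse
print(model.fc.weight.grad.shape)  # torch.Size([1, 3])
```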
Your First Custom Module: A Fancy Perceptron
Let’s not get ahead of ourselves. Here’s the “Hello, World!” of nn.Module: a simple multi-layer perceptron with a twist. We’ll add a skip connection, because why not? It’s 2023.
```python
import torch
from torch import nn

class FancyMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()  # This is non-negotiable. Do it. Always.
        # Define your submodules. These are your LEGO bricks.
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        # This is our twist: a skip connection from input to before the final layer
        self.skip = nn.Linear(input_size, hidden_size)

    def forward(self, x):
        # Remember the input for the skip connection
        identity = x
        # Standard forward pass
        out = self.layer1(x)
        out = self.relu(out)
        out = self.layer2(out)
        out = self.relu(out)
        # Here's the skip: we project the original input to the right size and add it
        skip_out = self.skip(identity)
        out = out + skip_out  # The skip connection!
        out = self.relu(out)
        # Final layer
        out = self.layer3(out)
        return out

# Let's instantiate it and see what it does
model = FancyMLP(10, 20, 5)
dummy_input = torch.randn(32, 10)  # A batch of 32 samples, 10 features each
output = model(dummy_input)
print(output.shape)  # torch.Size([32, 5])
```
See? Not so bad. The super().__init__() call is crucial: it sets up the internal bookkeeping (the dictionaries of parameters, buffers, and submodules) that nn.Module uses to register everything you assign to self. Forget it, and your model will be about as useful as a chocolate teapot.
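If you want to watch the teapot melt for yourself, here's a quick sketch. Forgetful is a hypothetical class that skips the super().__init__() call. In the PyTorch versions I'm aware of, the very first submodule assignment raises an AttributeError along the lines of "cannot assign module before Module.__init__() call"; the exact wording may differ between versions, so the code just prints whatever message you get.

```python
from torch import nn

class Forgetful(nn.Module):
    def __init__(self):
        # Oops: no super().__init__() call, so nn.Module's bookkeeping never exists
        self.layer = nn.Linear(4, 2)

caught = None
try:
    Forgetful()
except AttributeError as err:
    caught = err
    print(f"AttributeError: {err}")
```

Note that this fails loudly at construction time, which is the good outcome. The silent failures come from the pitfalls further down.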
Why nn.Module is a Control Freak’s Dream
You might be wondering, “Why can’t I just define these layers as global variables and write a function?” Because nn.Module does a ton of black magic for you under the hood. When you assign an nn.Linear or another nn.Module as an attribute (like self.layer1), PyTorch automatically registers it. This allows the base nn.Module class to:
- Track all parameters: call model.parameters() and you get a generator of every single weight and bias in every layer, even layers nested five classes deep. Try doing that manually without pulling your hair out.
- Move everything to GPU/CPU seamlessly: model.to('cuda') recursively sends every parameter and buffer to the GPU. It's glorious.
- Handle training/eval modes: model.train() and model.eval() recursively set the mode for all submodules, which is vital for layers like nn.Dropout and nn.BatchNorm1d that behave differently during training and inference.
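Here's a minimal sketch of that bookkeeping in action. Inner and Outer are hypothetical modules invented for this example. The device line falls back to the CPU when no GPU is available, so the snippet runs anywhere.

```python
import torch
from torch import nn

class Inner(nn.Module):
    """Hypothetical nested block with a parameterized layer and a dropout layer."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.drop(self.fc(x))

class Outer(nn.Module):
    """Hypothetical top-level model that contains an Inner block."""
    def __init__(self):
        super().__init__()
        self.block = Inner()
        self.head = nn.Linear(4, 1)

    def forward(self, x):
        return self.head(self.block(x))

model = Outer()

# 1) Parameter tracking reaches into nested modules
names = [name for name, _ in model.named_parameters()]
print(names)  # ['block.fc.weight', 'block.fc.bias', 'head.weight', 'head.bias']
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 25: (4*4 + 4) + (4*1 + 1)

# 2) eval()/train() propagate down to the nested Dropout
model.eval()
eval_flag = model.block.drop.training   # False
model.train()
train_flag = model.block.drop.training  # True

# 3) .to() moves every registered parameter and buffer
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print({p.device.type for p in model.parameters()})
```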
The forward vs. __call__ Misdirection
Here's a common point of confusion. You define a forward method, but you call the model with model(x). What gives? If you look at the source code (which you should, it's brilliant), nn.Module defines a __call__ method. This __call__ method runs any registered hooks (forward pre-hooks, forward hooks, and backward hooks) around a call to your forward method. This is why you should never call model.forward(x) directly: you'll bypass that machinery, and every hook will be skipped silently. Autograd still records the operations, so nothing crashes and nothing warns you; the hooks just never run. It's a classic "use the public interface" situation. Always call the object: model(x).
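A quick way to see the difference is to attach a forward hook with register_forward_hook and call a module both ways. This sketch uses a plain nn.Linear to stay self-contained; the hook simply appends a marker to a list each time it fires.

```python
import torch
from torch import nn

model = nn.Linear(3, 2)
calls = []

# A forward hook receives (module, inputs, output) after forward runs
handle = model.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

x = torch.randn(1, 3)
model(x)                        # goes through __call__, so the hook fires
out_direct = model.forward(x)   # bypasses __call__, so the hook is skipped

print(calls)                    # ['hook']: only one call was observed
print(out_direct.requires_grad) # True: autograd still tracked the direct call
handle.remove()                 # detach the hook when you're done with it
```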
The Register: Parameters, Buffers, and Submodules
Not everything in your model is a parameter. Sometimes you have a learned parameter (a weight), and sometimes you have a persistent state that isn’t learned (like a running mean in Batch Norm). nn.Module helps you manage both.
```python
import torch
from torch import nn

class ModelWithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(10, 5))  # A learned parameter
        self.register_buffer('running_mean', torch.zeros(5))  # A non-learned state
        # You can also register modules or parameters later
        self.linear = None

    def setup_linear(self, input_size):
        # This is a bit of a niche trick, but it shows how registration works:
        # assigning an nn.Module to an attribute registers it, even outside __init__
        self.linear = nn.Linear(input_size, 10)

model = ModelWithBuffer()
# Only 'weight' so far, because 'linear' is still None
print([name for name, _ in model.named_parameters()])  # ['weight']
model.setup_linear(8)
# Now the parameters inside 'linear' are registered too
print([name for name, _ in model.named_parameters()])  # ['weight', 'linear.weight', 'linear.bias']
# Buffers are tracked separately from parameters
print([name for name, _ in model.named_buffers()])  # ['running_mean']
```
Use nn.Parameter for tensors you want to be learned and optimized. Use register_buffer for tensors that need to be saved with the model’s state dict but shouldn’t receive gradients.
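To see how the two are treated differently, here's a sketch of a hypothetical RunningNorm module (not a built-in PyTorch layer) with a learned scale and a running-mean buffer. The buffer shows up in state_dict() but not in parameters(), so an optimizer built from model.parameters() never touches it; we update it by hand under torch.no_grad() instead.

```python
import torch
from torch import nn

class RunningNorm(nn.Module):
    """Hypothetical layer: subtracts a running mean (buffer), then applies a learned scale (parameter)."""
    def __init__(self, features, momentum=0.1):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(features))
        self.register_buffer("running_mean", torch.zeros(features))
        self.momentum = momentum  # a plain Python float, not tracked by nn.Module

    def forward(self, x):
        if self.training:
            # Update the running statistic in place, outside the autograd graph
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * x.mean(dim=0))
        return (x - self.running_mean) * self.scale

m = RunningNorm(3)
print(list(m.state_dict().keys()))          # ['scale', 'running_mean']
print([n for n, _ in m.named_parameters()]) # ['scale']
print(m.running_mean.requires_grad)         # False

out = m(torch.ones(4, 3))  # training mode: running_mean moves from 0.0 to 0.1
```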
Common Pitfalls: How to Shoot Yourself in the Foot
- Forgetting super().__init__(): I said it was non-negotiable. I meant it. PyTorch will typically raise an AttributeError the moment you assign your first layer or nn.Parameter, because the registration machinery was never set up.
- Using Python lists instead of nn.ModuleList: If you put modules in a standard Python list, model.parameters() won't see them. PyTorch only finds modules assigned directly as attributes or held in its container classes such as nn.ModuleList and nn.ModuleDict. Use nn.ModuleList for lists of modules.
```python
# BAD: Parameters in this list will be invisible.
self.layers = [nn.Linear(10, 10) for _ in range(5)]
# GOOD: Parameters are properly registered.
self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])
```
- Trying to be too clever in __init__: The __init__ method should just define your components. Don't start doing heavy data processing or making network calls in there. Keep it simple: define layers, parameters, and buffers.
- Ignoring device placement: If you store a plain tensor as an attribute (say, self.mask = torch.zeros(5)), it isn't registered. When you call model.to('cuda'), your parameters and buffers move to the GPU, but that tensor stubbornly stays on the CPU, causing device mismatch errors. The fix is to register it with self.register_buffer('mask', torch.zeros(5)) so it travels with the model. Likewise, tensors you create inside forward with torch.zeros land on the CPU by default, so create them on the right device explicitly, e.g. torch.zeros(..., device=x.device) or device=self.weight.device.
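Here's a sketch that trips over two of these pitfalls side by side. Pitfalls is a hypothetical module. To keep it runnable without a GPU, it uses .to(torch.float64) as a stand-in for .to('cuda'): both go through the same recursive conversion of registered parameters and buffers, so an unregistered tensor gets left behind in exactly the same way.

```python
import torch
from torch import nn

class Pitfalls(nn.Module):
    """Hypothetical module contrasting registered and unregistered state."""
    def __init__(self):
        super().__init__()
        # BAD: a plain Python list hides these layers from nn.Module
        self.hidden_layers = [nn.Linear(4, 4) for _ in range(2)]
        # GOOD: nn.ModuleList registers each layer
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
        # BAD: a plain tensor attribute is not registered, so .to() ignores it
        self.plain_mask = torch.ones(4)
        # GOOD: a registered buffer follows the module through .to()
        self.register_buffer("mask", torch.ones(4))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        # Tensors created inside forward should match the input's device and dtype
        bias = torch.zeros(x.shape[-1], device=x.device, dtype=x.dtype)
        return x * self.mask + bias

m = Pitfalls()
param_count = sum(p.numel() for p in m.parameters())
print(param_count)  # 40: only the ModuleList's two layers, 2 * (16 + 4)

m.to(torch.float64)        # converts registered floating-point params and buffers
print(m.mask.dtype)        # torch.float64
print(m.plain_mask.dtype)  # torch.float32 (left behind)
```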
The power of nn.Module is that it composes. You can create a FancyMLP module and then use it as a building block inside a larger UltimateVisionModel module. This composability is what makes PyTorch so incredibly flexible for research. You’re not just stacking layers; you’re building a hierarchy of concepts. Now go build something weird.
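As a sketch of that composability, here's a hypothetical Block reused inside a hypothetical TinyNet. Both are invented for illustration, standing in for something like FancyMLP inside UltimateVisionModel. Notice how the hierarchy shows up in the dotted parameter names.

```python
import torch
from torch import nn

class Block(nn.Module):
    """Hypothetical reusable building block: Linear -> ReLU with a residual add."""
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.act = nn.ReLU()

    def forward(self, x):
        return x + self.act(self.fc(x))

class TinyNet(nn.Module):
    """Hypothetical larger model assembled from Block instances plus a head."""
    def __init__(self, dim, num_blocks, num_classes):
        super().__init__()
        self.blocks = nn.Sequential(*[Block(dim) for _ in range(num_blocks)])
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        return self.head(self.blocks(x))

net = TinyNet(dim=8, num_blocks=3, num_classes=2)
out = net(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 2])
# The module hierarchy is reflected in the parameter names
print([n for n, _ in net.named_parameters()][:2])  # ['blocks.0.fc.weight', 'blocks.0.fc.bias']
```

Here nn.Sequential and a custom nn.Module happily nest inside each other, since both are just modules.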