14.6 Weight Initialization: Xavier, He, and Orthogonal
Right, let’s talk about the very first thing your network does before it even gets a chance to be smart: guessing. That’s all weight initialization is. You’re setting the starting values for the millions of parameters your model will spend the next however-long tweaking. Get this wrong, and you’re not just starting on the back foot; you’re starting in a different stadium, facing the wrong way.
Think of it like this: if you initialize all your weights to zero (or to any single constant), every neuron in a layer will compute exactly the same thing on the first forward pass. On the backward pass, they’ll all get exactly the same gradient, and they’ll all update in exactly the same way. You don’t have a hundred neurons; you have one neuron with a hundred clones. It’s a spectacular waste of compute, and gradient descent will never break that symmetry on its own. So, we need to start with random values. But “random” is a big, scary universe. Do we use a uniform distribution between -1 and 1? A normal distribution? This is where the math nerds (bless them) come in to save us from ourselves.
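Here is a minimal NumPy sketch of the symmetry problem, assuming a tiny two-layer network (4 inputs, 3 tanh hidden units, 1 linear output, mean-squared-error loss; all sizes are arbitrary choices for illustration) whose weights all start at the same constant. Backpropagating by hand shows every hidden unit receiving the same gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))   # batch of 8 inputs, 4 features each
y = rng.normal(size=(8, 1))   # regression targets

# Every weight starts at the same constant (all zeros is the special case c = 0).
c = 0.5
W1 = np.full((4, 3), c)       # input -> 3 hidden units
W2 = np.full((3, 1), c)       # hidden -> 1 output

# Forward pass: tanh hidden layer, linear output.
h = np.tanh(x @ W1)
pred = h @ W2

# Backward pass for loss = mean((pred - y) ** 2), done by hand.
d_pred = 2 * (pred - y) / len(x)
grad_W2 = h.T @ d_pred
d_h = d_pred @ W2.T
grad_W1 = x.T @ (d_h * (1 - h ** 2))

# Each hidden unit has the same incoming weights, so each gets the same gradient.
print(np.allclose(grad_W1, grad_W1[:, :1]))  # True: the three columns match

# One gradient-descent step keeps them identical, so the clones stay clones.
W1 -= 0.1 * grad_W1
print(np.allclose(W1, W1[:, :1]))            # True
```

Every later step repeats the same story, which is why the starting values have to differ.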
The core problem we’re solving is the vanishing/exploding gradient problem. If your initial weights are too small, the signals (and gradients) shrink toward zero as they pass through each layer. If they’re too large, they balloon exponentially with depth. Neither is great for learning. We need to scale our random guesses so that the variance of a layer’s outputs roughly matches the variance of its inputs, and so that, on the backward pass, the variance of the gradients stays roughly constant from layer to layer. This idea, usually called variance preservation, is the golden rule behind the good schemes.
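To see why the scale matters, here is a small NumPy experiment (the width of 256, depth of 10, and the helper name `signal_rms` are arbitrary illustration choices). It uses purely linear layers so the weight scale is the only thing in play, and compares the naive uniform(-1, 1), a tiny uniform(-0.01, 0.01), and a bound chosen so that n·Var(W) = 1:

```python
import numpy as np

def signal_rms(bound, n=256, depth=10, seed=0):
    """Push a random batch through `depth` linear layers whose weights are drawn
    from uniform(-bound, bound); return the RMS of the final output."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(64, n))  # 64 samples, unit-variance inputs
    for _ in range(depth):
        W = rng.uniform(-bound, bound, size=(n, n))
        x = x @ W  # no activation, so only the weight scale matters
    return float(np.sqrt(np.mean(x ** 2)))

n = 256
print("uniform(-1, 1):      ", signal_rms(1.0, n))             # explodes
print("uniform(-0.01, 0.01):", signal_rms(0.01, n))            # vanishes
print("variance-preserving: ", signal_rms(np.sqrt(3 / n), n))  # stays near 1
```

Uniform(-a, a) has variance a²/3, so each layer multiplies the signal’s second moment by n·a²/3. With a = sqrt(3/n) that factor is exactly 1; anything else compounds geometrically with depth.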
The Xavier (Glorot) Initialization
This is the old guard, the reliable workhorse for its time. Proposed by Xavier Glorot and Yoshua Bengio in 2010, it was designed for layers that use tanh or sigmoid activations—the smooth, saturating kind.
The intuition is brilliantly simple: scale your random values based on both the number of connections feeding into a layer (fan_in) and the number of connections coming out of it (fan_out). Fan-in governs how the variance of the activations changes on the forward pass; fan-out governs how the variance of the gradients changes on the backward pass, and Xavier balances the two. The uniform version draws weights from $\mathcal{U}(-a, a)$ with $a = \sqrt{\frac{6}{\text{fan}_\text{in} + \text{fan}_\text{out}}}$.
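Where does the 6 come from? A derivation sketch, assuming zero-mean, independent weights and activations that stay in tanh’s roughly linear region near zero: the forward pass wants $\operatorname{Var}(W) = 1/\text{fan}_\text{in}$, the backward pass wants $1/\text{fan}_\text{out}$, and Xavier settles on a compromise between them, then picks the uniform bound that hits it:

$$
\operatorname{Var}(z) = \text{fan}_\text{in}\,\operatorname{Var}(W)\,\operatorname{Var}(x),
\qquad
\operatorname{Var}(W) = \frac{2}{\text{fan}_\text{in} + \text{fan}_\text{out}}
$$

$$
W \sim \mathcal{U}(-a, a) \;\Rightarrow\; \operatorname{Var}(W) = \frac{a^2}{3},
\qquad
\frac{a^2}{3} = \frac{2}{\text{fan}_\text{in} + \text{fan}_\text{out}}
\;\Rightarrow\;
a = \sqrt{\frac{6}{\text{fan}_\text{in} + \text{fan}_\text{out}}}
$$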
Here’s the idea written out from scratch in NumPy:
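```python
import numpy as np

def xavier_uniform_init(fan_in, fan_out):
    # Uniform(-limit, limit) has variance limit**2 / 3 = 2 / (fan_in + fan_out).
    limit = np.sqrt(6 / (fan_in + fan_out))
    # Shape (fan_in, fan_out) suits the `x @ W` convention.
    return np.random.uniform(-limit, limit, size=(fan_in, fan_out))

# For a layer with 100 inputs and 50 outputs
weights = xavier_uniform_init(100, 50)
```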
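The same variance target, 2 / (fan_in + fan_out), can also be hit with a zero-mean normal distribution, which is the scheme’s normal variant. A quick sketch (the helper `xavier_normal_init` and the 1000×500 size are illustrative choices) checking that both variants land on the same variance:

```python
import numpy as np

def xavier_normal_init(fan_in, fan_out, rng):
    # Same variance target as the uniform version, 2 / (fan_in + fan_out),
    # but drawn from a zero-mean normal distribution.
    std = np.sqrt(2 / (fan_in + fan_out))
    return rng.normal(0.0, std, size=(fan_in, fan_out))

rng = np.random.default_rng(0)
fan_in, fan_out = 1000, 500
target = 2 / (fan_in + fan_out)

limit = np.sqrt(6 / (fan_in + fan_out))  # uniform variant, for comparison
w_uniform = rng.uniform(-limit, limit, size=(fan_in, fan_out))
w_normal = xavier_normal_init(fan_in, fan_out, rng)

print(target, w_uniform.var(), w_normal.var())  # all three should be close
```

The choice between them is mostly taste; what matters is the variance, not the shape of the distribution.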
In practice, you just tell your framework to do it. Note that PyTorch doesn’t apply Xavier by default, whatever activation follows the layer (nn.Linear’s built-in reset uses a Kaiming-uniform variant), so you have to ask for it explicitly:
```python
import torch.nn as nn

layer = nn.Linear(in_features=100, out_features=50)
# Manually apply Xavier uniform initialization
nn.init.xavier_uniform_(layer.weight)
```
The He (Kaiming) Initialization
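Then along came Rectified Linear Units (ReLU). ReLU is a fantastic activation function, but it has a quirk: it zeroes out every negative input, so for zero-mean pre-activations it silences roughly half your units at any given time. That halves the second moment of the signal handed to the next layer. Xavier’s scaling doesn’t account for this, so in a ReLU network the signal’s second moment shrinks by about a factor of two per layer, and in a deep network that compounds fast.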
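So, in 2015, Kaiming He and his colleagues derived a new formula specifically for ReLU and its leaky cousins. It’s gloriously straightforward: scale based on fan_in alone, and double the weight variance to compensate for ReLU’s “variance-destroying” halving. That gives $\operatorname{Var}(W) = \frac{2}{\text{fan}_\text{in}}$, so the uniform version draws from $\mathcal{U}(-a, a)$ with $a = \sqrt{\frac{6}{\text{fan}_\text{in}}}$, and the normal version uses a standard deviation of $\sqrt{\frac{2}{\text{fan}_\text{in}}}$.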
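Where the factor of 2 comes from, sketched under the usual assumptions (zero-mean, symmetric pre-activations $z$ and independent, zero-mean weights): ReLU keeps only the positive half, so the second moment of its output is half the variance of its input, and holding the variance constant then takes twice the weight variance:

$$
\mathbb{E}\big[\operatorname{ReLU}(z)^2\big] = \tfrac{1}{2}\operatorname{Var}(z),
\qquad
\operatorname{Var}(z_l) = \text{fan}_\text{in}\,\operatorname{Var}(W)\,\tfrac{1}{2}\operatorname{Var}(z_{l-1})
$$

$$
\operatorname{Var}(z_l) = \operatorname{Var}(z_{l-1})
\;\Rightarrow\;
\operatorname{Var}(W) = \frac{2}{\text{fan}_\text{in}},
\qquad
\frac{a^2}{3} = \frac{2}{\text{fan}_\text{in}}
\;\Rightarrow\;
a = \sqrt{\frac{6}{\text{fan}_\text{in}}}
$$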
```python
import numpy as np

def kaiming_uniform_init(fan_in, fan_out):
    limit = np.sqrt(6 / fan_in)  # Var(W) = limit**2 / 3 = 2 / fan_in; only fan_in matters
    return np.random.uniform(-limit, limit, size=(fan_in, fan_out))

# For a ReLU layer with 100 inputs and 50 outputs
weights = kaiming_uniform_init(100, 50)
```
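This is now the standard choice for most modern ReLU networks. One catch: PyTorch’s Conv2d and Linear layers do call kaiming_uniform_ by default, but with a=math.sqrt(5), which works out to a bound of $1/\sqrt{\text{fan}_\text{in}}$ rather than He’s $\sqrt{6/\text{fan}_\text{in}}$. If you want true He scaling, ask for it explicitly: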
```python
# This is the standard practice for a ReLU-based network
nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
```
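The gap between PyTorch’s default and true He scaling is plain arithmetic. The sketch below re-derives it in plain Python from the formulas in PyTorch’s nn.init documentation (bound = gain × sqrt(3 / fan_in), with gain = sqrt(2) for 'relu' and sqrt(2 / (1 + a²)) for 'leaky_relu'). The helper `kaiming_uniform_bound` is ours, not a torch function, and the sketch assumes nn.Linear’s default reset calls kaiming_uniform_ with a=math.sqrt(5):

```python
import math

def kaiming_uniform_bound(fan_in, a=0.0, nonlinearity="leaky_relu"):
    """Re-derive the uniform bound used by kaiming_uniform_ (hypothetical helper)."""
    if nonlinearity == "relu":
        gain = math.sqrt(2.0)
    else:  # leaky_relu with negative slope `a`
        gain = math.sqrt(2.0 / (1 + a ** 2))
    return gain * math.sqrt(3.0 / fan_in)

fan_in = 100
default_bound = kaiming_uniform_bound(fan_in, a=math.sqrt(5))  # nn.Linear's default reset
he_bound = kaiming_uniform_bound(fan_in, nonlinearity="relu")  # true He scaling
print(default_bound)  # ≈ 0.1, i.e. 1 / sqrt(fan_in)
print(he_bound)       # ≈ 0.245, i.e. sqrt(6 / fan_in)
```

Since variance scales with the square of the bound, the default gives a weight variance one sixth of He’s.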
Orthogonal Initialization
Now for the fancy one. Sometimes you don’t just want your weights to be well-scaled; you want them to be well-behaved. Orthogonal initialization generates a (semi-)orthogonal matrix: for a square matrix, the rows (and the columns) are unit-length and mutually orthogonal; for a non-square one, whichever of the rows or columns there are fewer of has that property. No row points in the same direction as another, so none of them starts out redundant.
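Why is this useful? An orthogonal matrix $Q$ satisfies $Q^\top Q = I$, so $\lVert Qx \rVert^2 = x^\top Q^\top Q x = \lVert x \rVert^2$ for every vector $x$: all of its singular values are exactly 1. Variance-based schemes only preserve the signal’s magnitude on average; an orthogonal matrix preserves it exactly for the linear part of the layer, in both directions (the backward pass multiplies by $Q^\top$, which is also orthogonal). That is a more direct way of fighting the vanishing/exploding gradient problem, and it matters most where the same matrix is applied over and over, as in deep networks and the hidden-to-hidden weights of Recurrent Neural Networks (RNNs). You can think of it as the “premium, artisanal” initialization scheme.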
```python
import numpy as np

def orthogonal_init(shape):
    # SVD-based sketch; PyTorch's nn.init.orthogonal_ uses a QR decomposition instead.
    flat_shape = (shape[0], int(np.prod(shape[1:])))
    a = np.random.normal(0.0, 1.0, flat_shape)
    u, _, vh = np.linalg.svd(a, full_matrices=False)
    # u is (rows, k) and vh is (k, cols) with k = min(rows, cols);
    # at least one of them has exactly flat_shape (both do when square).
    q = u if u.shape == flat_shape else vh
    return q.reshape(shape)

# Initializing a square hidden-to-hidden weight matrix for an RNN cell
weights = orthogonal_init((256, 256))
```
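For comparison, here is a QR-based sketch (`orthogonal_qr` is a hypothetical helper name; the sign correction is the standard recipe for drawing a uniformly random orthogonal matrix), along with a check of the norm-preservation claim: apply the same matrix 100 times, the way an RNN reuses its hidden-to-hidden weights at every time step, and watch the vector’s length. A variance-scaled Gaussian matrix is included for contrast:

```python
import numpy as np

def orthogonal_qr(n, rng):
    """Square orthogonal matrix via QR, with the sign fix that makes the
    result uniformly distributed over orthogonal matrices."""
    a = rng.normal(size=(n, n))
    q, r = np.linalg.qr(a)
    q *= np.sign(np.diag(r))  # flip column signs so diag(r) is effectively positive
    return q

rng = np.random.default_rng(0)
n, steps = 256, 100
q = orthogonal_qr(n, rng)
g = rng.normal(0.0, 1.0 / np.sqrt(n), size=(n, n))  # Var(W) = 1/n Gaussian, for contrast

x = rng.normal(size=n)
x /= np.linalg.norm(x)  # start from a unit vector
xq, xg = x.copy(), x.copy()
for _ in range(steps):  # reuse the same matrix at every step, like an RNN's W_hh
    xq = q @ xq
    xg = g @ xg

print(np.linalg.norm(xq))  # 1.0 up to floating-point error
print(np.linalg.norm(xg))  # generally does not stay at 1
```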
In PyTorch, it’s a one-liner, and it’s a godsend for RNNs:
```python
rnn_cell = nn.RNNCell(input_size=128, hidden_size=256)  # sizes are arbitrary
nn.init.orthogonal_(rnn_cell.weight_hh)  # For the hidden-to-hidden weights
```
Best Practices and Pitfalls
Here’s the crucial part: stop guessing. The frameworks have sensible defaults for a reason. PyTorch’s nn.Linear and nn.Conv2d use a Kaiming-uniform variant out of the box (with a bound of $1/\sqrt{\text{fan}_\text{in}}$), and for a standard feedforward or convolutional net that default is usually a reasonable starting point.
The biggest pitfall is mismatching your initialization to your activation function. Using Xavier with ReLU is a common rookie mistake: in a deep network, the signal’s second moment halves at every layer, which can leave the network struggling to learn from the start. Conversely, He initialization with tanh doubles the weight variance relative to Xavier (when fan-in equals fan-out), which pushes more pre-activations into tanh’s flat, saturating tails, where gradients are small.
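A quick NumPy experiment makes the Xavier-with-ReLU failure concrete (width 512, depth 30, no biases, and the helper name `relu_stack_rms` are arbitrary illustration choices): push unit-variance inputs through a stack of Linear+ReLU layers and compare the activation scale at the end under each scheme.

```python
import numpy as np

def relu_stack_rms(init, n=512, depth=30, seed=0):
    """RMS of activations after `depth` Linear+ReLU layers (no biases)."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(128, n))  # unit-variance inputs
    for _ in range(depth):
        if init == "xavier":
            limit = np.sqrt(6 / (n + n))  # fan_in = fan_out = n
        else:  # "he"
            limit = np.sqrt(6 / n)
        W = rng.uniform(-limit, limit, size=(n, n))
        x = np.maximum(x @ W, 0.0)  # ReLU
    return float(np.sqrt(np.mean(x ** 2)))

xavier_rms = relu_stack_rms("xavier")
he_rms = relu_stack_rms("he")
print(xavier_rms)  # roughly 2 ** -15: the signal has all but vanished
print(he_rms)      # stays on the order of 1
```

Each Xavier layer halves the second moment, so thirty layers shrink it by about $2^{30}$; He’s factor of 2 cancels the halving.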
The edge case is when you have a weird, custom activation function. In that case, you might need to derive your own scaling factor or, more realistically, just experiment with a few from the standard toolkit (Xavier, He) and see which one works best. It’s less science and more alchemy at that point. But for the standard stuff, just follow the recipe: ReLU -> He, Tanh/Sigmoid -> Xavier, RNN hidden-to-hidden weights -> Orthogonal. It’ll save you a week of pointless debugging. Trust me, I’ve wasted that week so you don’t have to.
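One way to “derive your own scaling factor”, sketched under the same assumptions as the He derivation (zero-mean independent weights, roughly Gaussian pre-activations): choose the gain that keeps a unit-variance pre-activation at unit variance one layer later, and estimate it numerically. This is one heuristic among several, not a library routine; `estimate_gain` is a hypothetical helper, and SiLU simply stands in for “some activation without a textbook gain”:

```python
import numpy as np

def estimate_gain(activation, samples=1_000_000, seed=0):
    """Monte Carlo estimate of gain = 1 / sqrt(E[f(z)**2]) for z ~ N(0, 1).
    With Var(W) = gain**2 / fan_in, a unit-variance pre-activation
    stays unit-variance at the next layer."""
    z = np.random.default_rng(seed).standard_normal(samples)
    return float(1.0 / np.sqrt(np.mean(activation(z) ** 2)))

def relu(z):
    return np.maximum(z, 0.0)

def silu(z):  # stand-in "custom" activation: z * sigmoid(z)
    return z / (1.0 + np.exp(-z))

print(estimate_gain(relu))  # close to sqrt(2) ≈ 1.414, the He factor
print(estimate_gain(silu))
```

As a sanity check, the ReLU case recovers He’s factor of 2 in the variance; the estimated gain then slots into a He-style draw with standard deviation gain / sqrt(fan_in).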