19.6 Multi-Task Learning: Sharing Representations Across Tasks
Right, so you’ve mastered the art of fine-tuning a pre-trained model on a single new task. It’s a fantastic trick, but let’s be honest: it feels a little… single-minded. What if you don’t just want your model to be good at one thing? What if you want it to be a multi-talented savant, capable of looking at an image and simultaneously telling you what’s in it (classification), where the objects are (bounding box detection), and perhaps even tracing their outlines (segmentation)?
Welcome to Multi-Task Learning (MTL). The core idea is gloriously simple yet powerful: instead of training a separate model for each of those tasks, you train one single model to do all of them at once. It’s the Swiss Army knife of deep learning. By sharing a common representation (that big, fancy feature extractor backbone like a ResNet or a Vision Transformer) across multiple tasks, you force it to learn a more general, robust, and fundamentally useful understanding of your data. It’s like learning to be a chef, a plumber, and an electrician all at once—you develop a much deeper understanding of “a house” than if you’d just specialized in hanging cabinets.
The Architectural Blueprint: Hard vs. Soft Parameter Sharing
The first design choice you face is how, exactly, you’ll share the model’s “brain.” There are two main schools of thought here.
Hard parameter sharing is the most common approach, and it’s probably what you intuitively imagine. We use one single backbone network to extract features from the input. Then, for each task, we attach a separate “head” network (usually just a few layers) that takes those shared features and produces a task-specific output.
import torch
import torch.nn as nn
import torchvision.models as models


class MultiTaskModel(nn.Module):
    def __init__(self, num_classes, num_boxes=4):
        # num_boxes is the number of box regression outputs:
        # (x_center, y_center, width, height) for a single box
        super().__init__()
        # Load a pre-trained backbone with its weights
        backbone = models.resnet50(weights='IMAGENET1K_V2')
        # Chop off the original average-pool and fc layers, keeping the conv feature maps
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-2])

        # Freeze early layers if you want, a classic fine-tuning trick.
        # requires_grad lives on parameters, not modules, so walk each
        # child's parameters (here: conv1, bn1, relu, maxpool, layer1).
        for child in list(self.feature_extractor.children())[:5]:
            for param in child.parameters():
                param.requires_grad = False

        # Task-specific heads
        # Head 1: Classification (uses global average pooling and a linear layer)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(2048, num_classes)  # ResNet50 final features are 2048-dim
        )
        # Head 2: Bounding Box Regression (also uses GAP and a linear layer)
        self.bbox_regressor = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(2048, num_boxes)  # Predicts (x_center, y_center, width, height)
        )

    def forward(self, x):
        # Get shared features
        shared_features = self.feature_extractor(x)
        # Pass through each head
        class_logits = self.classifier(shared_features)
        bbox_coords = self.bbox_regressor(shared_features)
        return class_logits, bbox_coords
This is brilliant because it can substantially reduce the risk of overfitting. The shared backbone has to find patterns that are useful for all tasks, which acts as a regularizer: a feature that only helps one task memorize its training set gets little support from the others. That’s a big part of why MTL often outperforms training separate models, especially when you have limited data for each individual task.
Soft parameter sharing is the more academic, slightly-weirder cousin. Here, each task has its own backbone, but we add constraints or loss terms (for example, a penalty on the distance between corresponding weights) to encourage the parameters of these backbones to stay similar. It’s more flexible but also more parameter-heavy and finicky. You see it less in practice because, frankly, hard sharing is far simpler to implement and usually works well enough.
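As a rough illustration of soft sharing, the sketch below gives each task its own small backbone and adds a squared-distance penalty between corresponding parameters. Everything here is an assumption made for the example: the toy fully connected backbones, the `soft_sharing_penalty` helper, and the `sharing_strength` value. Other formulations of the constraint exist.

```python
import torch
import torch.nn as nn


def make_backbone():
    """Toy backbone; both tasks use the same architecture so parameters line up."""
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())


def soft_sharing_penalty(net_a, net_b):
    """Sum of squared differences between corresponding parameters."""
    return sum(
        ((pa - pb) ** 2).sum()
        for pa, pb in zip(net_a.parameters(), net_b.parameters())
    )


torch.manual_seed(0)
backbone_cls, backbone_reg = make_backbone(), make_backbone()
head_cls, head_reg = nn.Linear(32, 3), nn.Linear(32, 4)

# Dummy batch standing in for real data.
x = torch.randn(8, 16)
cls_labels = torch.randint(0, 3, (8,))
reg_targets = torch.randn(8, 4)

loss_cls = nn.functional.cross_entropy(head_cls(backbone_cls(x)), cls_labels)
loss_reg = nn.functional.smooth_l1_loss(head_reg(backbone_reg(x)), reg_targets)
penalty = soft_sharing_penalty(backbone_cls, backbone_reg)

sharing_strength = 1e-2  # hypothetical value; a hyperparameter to tune
total_loss = loss_cls + loss_reg + sharing_strength * penalty
total_loss.backward()

print("penalty between the two backbones:", penalty.item())
```

A larger `sharing_strength` pulls the two backbones toward each other, and a value of zero recovers two fully independent models.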
The Loss Function: Juggling Priorities
Here’s where the rubber meets the road, and where most people mess it up. Your model now has multiple objectives. How do you combine the classification loss (e.g., CrossEntropy) and the regression loss (e.g., SmoothL1 or MSE for bounding boxes) into one single number to minimize?
The naive approach is to just sum them: Total Loss = Loss_task1 + Loss_task2. This is a disaster waiting to happen. Why? Because the scales of these losses are almost certainly different. If your classification loss is around 1.0 and your regression loss is around 10.0, the regression term will tend to dominate the gradients flowing into the shared backbone, so the model will focus on improving the regression task and all but ignore the classification task. You’ve just built a brilliant bounding box predictor that has no idea what’s in the boxes.
You need to balance them. Here are your main weapons:
Loss Weighting: This is the most common and effective practice. You manually set scalars (λ1, λ2) to balance the contribution of each loss:
Total Loss = λ1 * Loss_classification + λ2 * Loss_regression
Finding good weights is a bit of a dark art. You can treat them as hyperparameters and grid search, which is expensive but solid. A good rule of thumb is to set them so the losses are roughly on the same scale at the start of training.
Gradient Normalization: Fancy, auto-balancing methods like GradNorm dynamically adjust these weights during training, based on each task’s gradient magnitude and how fast it is learning relative to the others. It’s cool, but it adds complexity and can be unstable. Start with manual weighting. Get it working, then get clever.
# Example training loop snippet with manual loss weighting.
# Assumes `model`, `optimizer`, and `dataloader` are already defined, with the
# dataloader yielding (images, class labels, box targets) batches.
import torch.nn as nn

criterion_cls = nn.CrossEntropyLoss()
criterion_reg = nn.SmoothL1Loss()

# These are the magic numbers you get to tune!
lambda_cls = 1.0
lambda_reg = 0.5  # If the regression loss runs larger, scale it down

for images, cls_labels, bbox_labels in dataloader:
    optimizer.zero_grad()
    cls_logits, bbox_preds = model(images)

    loss_cls = criterion_cls(cls_logits, cls_labels)
    loss_reg = criterion_reg(bbox_preds, bbox_labels)
    total_loss = (lambda_cls * loss_cls) + (lambda_reg * loss_reg)

    total_loss.backward()
    optimizer.step()
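The “same scale at the start of training” rule of thumb can be turned into a few lines of arithmetic. The sketch below is one possible way to do it, not a standard recipe: the `balance_weights` helper is hypothetical, and the initial loss values are the illustrative 1.0 and 10.0 from earlier, which in practice you would measure on a few batches before training.

```python
def balance_weights(initial_losses, reference=None):
    """Return one weight per task so that weight * initial_loss matches across tasks.

    `initial_losses` maps a task name to its loss measured before training.
    `reference` names the task whose weight stays at 1.0 (defaults to the first task).
    """
    if reference is None:
        reference = next(iter(initial_losses))
    target = initial_losses[reference]
    return {task: target / loss for task, loss in initial_losses.items()}


# Illustrative numbers only: classification starts near 1.0, regression near 10.0.
initial_losses = {"cls": 1.0, "reg": 10.0}
weights = balance_weights(initial_losses, reference="cls")
weighted = {task: weights[task] * initial_losses[task] for task in initial_losses}

print("weights:", weights)
print("weighted initial losses:", weighted)
```

Weights found this way are only a starting point. The losses shrink at different rates as training proceeds, so they’re still worth tuning.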
The Pitfalls and The Payoff
Let’s be direct about the downsides. MTL is not a free lunch.
- Task Conflict: This is the big one. Sometimes, what’s good for one task is bad for another. The gradients can literally point in opposite directions. If your tasks are too dissimilar (e.g., training a model to classify images and generate poetry), the shared representation might become a confused mess that helps neither. The solution is to choose related tasks. Classification, detection, and segmentation are a classic combo because they all build on a hierarchical understanding of visual features.
- The Architecture Saddle: You are now responsible for designing not one, but multiple output heads. You need to think about the granularity of the shared features. Do you branch off after an early layer for a low-level task (like edge detection) and from a later layer for a high-level task (like classification)? This requires actual thought about your problem domain.
- Data Logistics: Ideally, your training data is annotated for every single task you care about. If you have images with class labels but no bounding boxes, they can’t supervise the box head; at best you can mask out the missing task’s loss for those samples, which complicates the training loop. This “multi-task dataset” is often the hardest part to acquire.
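Task conflict can be measured rather than guessed at. One simple diagnostic, sketched here on a toy model, is to compute each task’s gradient with respect to the shared parameters and compare their directions with cosine similarity. The tiny network, the random data, and the `flat_grad` helper are all assumptions made for the example. On a real model you would run this on real batches and watch the value over time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins for a shared backbone and two task heads.
shared = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
head_cls = nn.Linear(32, 3)
head_reg = nn.Linear(32, 4)

# Dummy batch standing in for real data.
x = torch.randn(8, 16)
cls_labels = torch.randint(0, 3, (8,))
reg_targets = torch.randn(8, 4)

features = shared(x)
loss_cls = nn.functional.cross_entropy(head_cls(features), cls_labels)
loss_reg = nn.functional.smooth_l1_loss(head_reg(features), reg_targets)

shared_params = list(shared.parameters())


def flat_grad(loss):
    """Gradient of `loss` w.r.t. the shared parameters, flattened into one vector."""
    grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads])


g_cls = flat_grad(loss_cls)
g_reg = flat_grad(loss_reg)
cosine = nn.functional.cosine_similarity(g_cls, g_reg, dim=0).item()

print("gradient norm (cls):", g_cls.norm().item())
print("gradient norm (reg):", g_reg.norm().item())
print("cosine similarity:", cosine)
```

A cosine near -1 means the two tasks are pulling the shared weights in opposite directions, while a value near 0 means they are largely independent. The two norms also show at a glance whether one task’s gradient is swamping the other’s.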
But when it works? It’s pure magic. You get a more data-efficient, robust model that often generalizes better to unseen data because its internal representation is so grounded. You also get a massive efficiency win at inference time: one forward pass through the shared backbone gives you predictions for all tasks. It’s the difference between hiring a whole team of specialists and finding one brilliant polymath who can do it all. And honestly, who wouldn’t want that?