2.4 Semi-Supervised and Self-Supervised Learning
Right, so you’ve got your supervised learning (labeled data, the gold standard) and your unsupervised learning (no labels, just a messy pile of stuff). But what if I told you there’s a middle ground? A place where you can leverage a mountain of cheap, unlabeled data with just a handful of precious labeled examples? Welcome to the world of semi-supervised and self-supervised learning, where we’re not above cheating a little to get the job done.
Think of it this way: getting a human to label a million images of cats is expensive and soul-crushing. But downloading a million random images from the internet? That’s the easy part. These techniques are all about designing clever algorithms that can learn the underlying structure of the data from the unlabeled mountain and then use the labeled hill to attach meaningful names to that structure.
The Semi-Supervised Playbook: Learning from the Labeled and Unlabeled Crowd
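Semi-supervised learning (SSL) isn’t one algorithm; it’s a strategy, and it rests on assumptions we have to make about the data. The big one is the manifold assumption. Fancy term, simple idea: high-dimensional data (say, images) actually lives on a much lower-dimensional surface inside that space, and points that sit close together along that surface, in the same dense region, tend to share a label. Its close cousin, the cluster assumption, says the same thing from the other side: class boundaries should run through sparse, low-density regions, not through the middle of a crowd. If these don’t hold, the unlabeled data can steer SSL confidently in the wrong direction, and you can end up worse off than if you had used the labeled data alone. For things like images, audio, and text, they usually hold up reasonably well.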
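You can check the assumption directly on a toy dataset. This sketch uses scikit-learn's `make_moons` (two interleaving half-circles, one per class, that no straight line can separate), connects every point to its 10 nearest neighbors, and asks SciPy for the connected components of that graph. The dataset, noise level, and `n_neighbors=10` are illustrative choices only, not recommendations. If the assumption holds, each component contains a single class, so one labeled point per component would be enough to label everything.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components
from sklearn.datasets import make_moons
from sklearn.neighbors import kneighbors_graph

# Two curved, non-convex classes: a toy stand-in for "data on a manifold"
X, y = make_moons(n_samples=300, noise=0.03, random_state=0)

# Sparse adjacency matrix: each point is linked to its 10 nearest neighbors
graph = kneighbors_graph(X, n_neighbors=10, mode="connectivity")

# Weakly connected components = the pieces of structure the data "lives on"
n_components, component = connected_components(graph, directed=True, connection="weak")

# Does every component contain only one class?
purity = all(len(np.unique(y[component == c])) == 1 for c in range(n_components))
print(f"components: {n_components}, each component single-class: {purity}")
```

Crank the noise up until the moons overlap, and components start mixing classes. That is exactly the situation where SSL has nothing trustworthy to propagate.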
One classic SSL technique is label propagation. It’s like starting a rumor in a crowd. You tell a few people a secret (the labels), and they whisper it to their closest neighbors (the similar unlabeled data points). Eventually, the rumor spreads through the entire crowd of unlabeled data. Here’s a taste with scikit-learn:
import numpy as np
from sklearn import datasets
from sklearn.semi_supervised import LabelPropagation
from sklearn.metrics import accuracy_score
# Let's use the Iris dataset, but pretend we only labeled 10% of it
iris = datasets.load_iris()
labels = np.copy(iris.target)
rng = np.random.RandomState(42)  # fixed seed so the run is reproducible
random_unlabeled_points = rng.rand(len(labels)) < 0.9  # Mask: roughly 90% of points become unlabeled
labels[random_unlabeled_points] = -1 # scikit-learn uses -1 for 'unlabeled'
print(f"Number of labeled points: {np.sum(labels != -1)}")
# Create and train the label propagator
label_prop_model = LabelPropagation(kernel='knn', n_neighbors=10)
label_prop_model.fit(iris.data, labels)
# See how well it predicted the *true* labels for the unlabeled data
true_labels = iris.target[random_unlabeled_points]
predicted_labels = label_prop_model.transduction_[random_unlabeled_points] # transduction_ holds the learned labels for all data
print(f"SSL Accuracy on 'unlabeled' data: {accuracy_score(true_labels, predicted_labels):.3f}")
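The magic is in the transduction_ attribute. It holds the label the model settled on for every training point, including the ones you marked -1. At heart, label propagation is transductive learning: it infers labels for the specific points you handed it rather than learning a general rule. For new, unseen data you have two options. scikit-learn’s LabelPropagation does provide predict(), which labels a new point by its kernel similarity to the training points, so you can use it directly (the new point just doesn’t get to reshape the graph). Alternatively, you can take the now-fully-labeled set and train a separate, ordinary classifier on it.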
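Here is a minimal sketch of both routes for new data, assuming the same Iris setup as above with a held-out split playing the role of "unseen" points. The split sizes, seeds, and the choice of KNeighborsClassifier for the second route are arbitrary illustrations. Any ordinary classifier would do.

```python
import numpy as np
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.semi_supervised import LabelPropagation

iris = datasets.load_iris()
X_train, X_new, y_train, y_new = train_test_split(
    iris.data, iris.target, test_size=0.2, stratify=iris.target, random_state=0
)

# Hide roughly 90% of the training labels (-1 means "unlabeled")
rng = np.random.RandomState(0)
y_partial = np.copy(y_train)
y_partial[rng.rand(len(y_partial)) < 0.9] = -1

model = LabelPropagation(kernel="knn", n_neighbors=10).fit(X_train, y_partial)

# Route 1: predict() scores new points by kernel similarity to the training points
pred_direct = model.predict(X_new)

# Route 2: train an ordinary (inductive) classifier on the propagated labels
clf = KNeighborsClassifier(n_neighbors=5).fit(X_train, model.transduction_)
pred_clf = clf.predict(X_new)

print(f"predict() accuracy on new data:          {accuracy_score(y_new, pred_direct):.3f}")
print(f"separate classifier accuracy on new data: {accuracy_score(y_new, pred_clf):.3f}")
```

Route 2 costs an extra model, but it decouples inference from the training set. That matters when the training set is large and you don't want to keep it around just to compute kernel similarities.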
Self-Supervised Learning: The Art of Making Your Own Homework
If semi-supervised learning gets by on a few real labels, self-supervised learning is the ultimate power move: it invents its own labels from the data itself. It’s the student who writes their own practice exam, works through it, and comes out well prepared for the real thing. Some of the most striking successes in modern AI, like Large Language Models (GPT, etc.) and advanced image models, are built on this principle: they are pre-trained this way before ever seeing task-specific labels.
The recipe is genius in its simplicity:
- Take your unlabeled data.
- Invent a pretext task—a simple puzzle that forces the model to learn meaningful representations of the data just to solve it.
- Train a model on this fake task.
- Take the internal representations (embeddings) the model learned and use them for your actual downstream task (e.g., image classification) with a few real labels.
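To make step 2 concrete, here is a minimal sketch of one classic pretext task for text: predicting the next word. The helper name `next_word_pairs` and the fixed three-word context are illustrative assumptions. Real language models work on subword tokens and condition on the whole preceding context, not a fixed window. The point is that the "labels" are carved straight out of raw text, with no human involved.

```python
def next_word_pairs(text, context_size=3):
    """Turn raw, unlabeled text into (context, target) training pairs.
    The 'label' for each example is simply the word that follows its context."""
    words = text.split()
    pairs = []
    for i in range(context_size, len(words)):
        context = tuple(words[i - context_size:i])  # the preceding words
        target = words[i]                           # the free "label"
        pairs.append((context, target))
    return pairs

corpus = "the cat sat on the mat and the dog sat on the rug"
for context, target in next_word_pairs(corpus)[:3]:
    print(context, "->", target)
```

Thirteen words of unlabeled text yield ten labeled examples. Scale that to a web-sized corpus and you have an effectively unlimited supply of supervision.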
For text, the pretext task is often “predict the next word” (language modeling). For images, a common pretext task is contrastive learning. Here’s the idea: you take an image, apply two different random transformations to it (e.g., cropping, color jitter, rotation). These two altered versions are a “positive pair.” You then feed both through a neural network and train it to output similar representations for this pair, while outputting dissimilar representations for all other images in the batch (the “negatives”). The network learns that a dog is a dog, no matter how you crop it or mess with its colors.
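# Simplified but runnable PyTorch sketch of contrastive learning (in the spirit of the SimCLR framework)
import torch
import torch.nn as nn
import torch.nn.functional as F
# Your neural network 'f': a tiny stand-in here so the sketch runs; in practice, e.g., a ResNet
encoder_output_dim = 256
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, encoder_output_dim), nn.ReLU())
projection_head = nn.Sequential(nn.Linear(encoder_output_dim, 512), nn.ReLU(), nn.Linear(512, 128))
def random_transform(x):
    # Hypothetical stand-in augmentation: random horizontal flip (per image) plus Gaussian noise.
    # Real pipelines use random resized crops, color jitter, blur, etc.
    flip = torch.rand(x.size(0), 1, 1, 1) < 0.5
    x = torch.where(flip, x.flip(dims=[3]), x)
    return x + 0.1 * torch.randn_like(x)
# For a batch of images 'x' (random tensors here, shape [batch_size, 3, 32, 32])
x = torch.randn(8, 3, 32, 32)
x_i = random_transform(x)  # Augmented version 1
x_j = random_transform(x)  # Augmented version 2
# Get their projections
h_i = projection_head(encoder(x_i))  # shape: [batch_size, 128]
h_j = projection_head(encoder(x_j))
# Normalize so dot products become cosine similarities
h_i = F.normalize(h_i, dim=1)
h_j = F.normalize(h_j, dim=1)
# Calculate contrastive loss: push positive pairs together, negative pairs apart
temperature = 0.1
similarity_matrix = torch.mm(h_i, h_j.T) / temperature  # shape: [batch_size, batch_size]
# Row k's positive is column k (the diagonal); every other column is one of its negatives,
# so the loss is cross-entropy with targets 0..batch_size-1, averaged over both directions.
# (SimCLR's full NT-Xent loss also treats same-view pairs as negatives; this is a simplification.)
targets = torch.arange(x.size(0))
loss = (F.cross_entropy(similarity_matrix, targets) + F.cross_entropy(similarity_matrix.T, targets)) / 2
print(f"Contrastive loss: {loss.item():.3f}")
# In training you would call loss.backward() and step an optimizer over encoder + projection_head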
After pre-training on this invented task, you can throw away the projection head, keep the trained encoder, freeze it, and stick a simple linear classifier on top (a "linear probe"), or lightly fine-tune the whole thing. This often reaches strong accuracy with far fewer labeled examples than training the same network from scratch, because the encoder has already learned what a “good representation” of an image looks like, all on its own.
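A linear probe takes only a few lines of PyTorch. In this minimal sketch, the encoder is a hypothetical, untrained stand-in for your pre-trained one, and the "labeled set" is random tensors, so it shows the mechanics rather than any real accuracy. The projection head from pre-training simply isn't used. The encoder is frozen (`requires_grad = False`, `eval()`), so only the linear layer's weights move.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for a pre-trained encoder (load your real weights here)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 256), nn.ReLU())

# Freeze the encoder: it becomes a fixed feature extractor
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False

num_classes = 10
classifier = nn.Linear(256, num_classes)  # the only trainable part
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.05)

# A small labeled set (random tensors, purely for illustration)
x_labeled = torch.randn(20, 3, 32, 32)
y_labeled = torch.randint(0, num_classes, (20,))

# The encoder is frozen, so its features can be computed once
with torch.no_grad():
    features = encoder(x_labeled)

losses = []
for step in range(200):
    loss = F.cross_entropy(classifier(features), y_labeled)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Freezing keeps the representation fixed and makes the probe cheap: it is an honest test of what the encoder learned. Fine-tuning the encoder too usually squeezes out more accuracy, at the cost of more compute and a higher risk of overfitting a tiny labeled set.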
The Pitfalls: Where the Magic Fizzles
This isn’t a free lunch. Get these things wrong and your model will learn nothing but bad habits.
- Violating the Manifold Assumption: If your unlabeled data is a complete, random mess with no coherent structure, SSL can’t create structure from nothing. It will confidently assign wrong labels.
- Distribution Mismatch: Your unlabeled mountain and your labeled hill need to come from the same range. If the unlabeled (or pre-training) data looks different from what you will actually classify, say general web photos versus medical X-rays, the learned structure may transfer poorly. In semi-supervised setups, out-of-distribution unlabeled data can actively drag accuracy down. Garbage in, garbage out.
- Confirmation Bias: This is the big one in self-training (pseudo-labeling), where the model labels the unlabeled data with its own predictions and then retrains on them. If it’s confident but wrong early on, it will reinforce its own mistakes in the next round, and the errors can snowball into a feedback loop. Common defenses are accepting only very high-confidence pseudo-labels and injecting noise (dropout, data augmentation) into the model being trained on them.
- The Complexity Tax: These methods add significant algorithmic and computational complexity. It’s often simpler and faster to just spend your engineering budget on collecting more labeled data or building a better supervised model. Always ask: is the performance gain worth the extra hassle?
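The confirmation-bias trap is easiest to see in a bare-bones self-training loop. This sketch is hand-rolled for clarity. The function name `self_train`, the logistic-regression base model, and the 0.9 threshold are illustrative choices. (scikit-learn also ships a ready-made `SelfTrainingClassifier` in `sklearn.semi_supervised`.) The threshold is the conservative guardrail: predictions below it are never turned into labels, and if nothing clears it the loop stops instead of guessing.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

def self_train(X, y, threshold=0.9, max_rounds=10):
    """Pseudo-labeling: fit on labeled points, then adopt only predictions whose
    top class probability clears `threshold`. y uses -1 for unlabeled."""
    y = y.copy()
    for _ in range(max_rounds):
        labeled = y != -1
        unlabeled_idx = np.flatnonzero(~labeled)
        if unlabeled_idx.size == 0:
            break
        clf = LogisticRegression(max_iter=1000).fit(X[labeled], y[labeled])
        proba = clf.predict_proba(X[unlabeled_idx])
        confident = proba.max(axis=1) >= threshold
        if not confident.any():
            break  # nothing clears the bar: stop rather than guess
        y[unlabeled_idx[confident]] = clf.classes_[proba[confident].argmax(axis=1)]
    # Final model trained on real labels plus accepted pseudo-labels
    labeled = y != -1
    return LogisticRegression(max_iter=1000).fit(X[labeled], y[labeled]), y

iris = load_iris()
y_partial = np.full_like(iris.target, -1)
seed_idx = np.r_[0:5, 50:55, 100:105]  # 5 labeled examples per class (Iris is sorted by class)
y_partial[seed_idx] = iris.target[seed_idx]

model, y_pseudo = self_train(iris.data, y_partial, threshold=0.9)
new = (y_pseudo != -1) & (y_partial == -1)
print(f"pseudo-labeled points: {new.sum()}")
if new.any():
    print(f"pseudo-label accuracy: {np.mean(y_pseudo[new] == iris.target[new]):.3f}")
```

Lower the threshold and more points get labeled per round, but any early mistakes are baked into every later round. That is the feedback loop described above.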
The takeaway? Don’t treat your unlabeled data as worthless. With the right tricks, it can be one of the most valuable assets you have. Use it wisely.