80.3 Training Loops: compile(), fit(), callbacks
Right, let’s talk about the part where your model actually learns something. You’ve built this beautiful, intricate architecture—a digital Rube Goldberg machine of tensors and activations. Now we have to feed it data and hope it doesn’t embarrass us. This is where we move from architecture to action, and Keras gives you two main paths: the quick and civilized compile() & fit() autobahn, or the gritty, manual GradientTape backroads. We’ll save the backroads for another day and focus on the highway, because frankly, it’s a marvel of engineering that you should use until you have a very good reason not to.
Think of model.compile() as the pre-flight checklist. You’re telling the model three absolutely critical pieces of information before it even thinks about looking at your data:
- How to optimize (The optimizer): This is the algorithm that will navigate the loss landscape. Throwing 'adam' in here is the default for a reason: it's the all-wheel-drive, adaptive-cruise-control of optimizers, and a fantastic starting point for almost everything.
- What to optimize for (The loss function): This is the score, the number it will desperately try to minimize. Choosing the right one is non-negotiable. Using a regression loss ('mse') on a classification task is a great way to get a model that's confidently wrong.
- How to brag about it (The metrics): Loss is great for the model's internal mechanics, but it's often not in a human-readable unit. Metrics like 'accuracy' are there so you can see how well it's actually doing.
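To make the loss-versus-metric distinction concrete, here is a small NumPy sketch of roughly what sparse categorical crossentropy and accuracy compute for a batch of softmax outputs. The helper names and the toy batch are made up for illustration, and Keras's real implementation also handles logits, sample weights, and reduction options, so treat this as a sketch rather than the library's code.

```python
import numpy as np

def sparse_cce(labels, probs, eps=1e-7):
    """Hypothetical helper: mean negative log-probability assigned to the true class."""
    true_class_probs = probs[np.arange(len(labels)), labels]
    # Clip so a predicted probability of exactly 0 doesn't produce log(0) = -inf.
    return float(-np.mean(np.log(np.clip(true_class_probs, eps, 1.0))))

def accuracy(labels, probs):
    """Hypothetical helper: fraction of samples whose top-probability class matches the label."""
    return float(np.mean(np.argmax(probs, axis=1) == labels))

# Three samples, three classes; each row is a softmax output summing to 1.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.4, 0.3]])
labels = np.array([0, 1, 2])  # integer class ids, which is what the "sparse" variant expects

loss = sparse_cce(labels, probs)
acc = accuracy(labels, probs)
print(f"loss={loss:.4f}  accuracy={acc:.4f}")  # loss=0.5946  accuracy=0.6667
```

The loss (about 0.59) is hard to interpret on its own; the accuracy says plainly that two of three predictions were right. The third sample shows the two working together: the model put only 0.3 on the true class, so it contributes a large loss term and also counts as a miss.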
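# This is the most common incantation you'll use.
# We're building a simple classifier here.
from tensorflow.keras import Input, Sequential
from tensorflow.keras.layers import Dense, Dropout

model = Sequential([
    Input(shape=(784,)),             # 784 input features, e.g. a flattened 28x28 image
    Dense(128, activation='relu'),
    Dropout(0.2),  # See? Being a good friend and adding dropout early.
    Dense(10, activation='softmax')  # one probability per class
])
# The compile step: checklist complete.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # labels are integer class ids 0..9
              metrics=['accuracy'])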
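The strings in that compile() call are shortcuts for configurable objects. When you need to change something, such as the learning rate, you can pass the objects directly. This sketch assumes TensorFlow 2.x with tf.keras; the learning rate of 1e-3 simply matches Adam's default and is only there to show where that knob lives.

```python
import numpy as np
from tensorflow import keras

model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation="softmax"),
])

# Same configuration as the string version, spelled out as objects.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    # Naming the metric "accuracy" keeps the history keys the same as the string shortcut.
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)

# Sanity check on random input: one softmax row per sample, 10 classes each.
probs = model.predict(np.random.rand(4, 784).astype("float32"), verbose=0)
print(probs.shape)  # (4, 10)
```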
The fit() function and its many knobs
Now for the main event: model.fit(). This is where the magic (read: linear algebra) happens. You give it your data and tell it how long to work. The simplest call is fit(x_train, y_train, epochs=10). But the real power, and the devilish details, are in its arguments.
- validation_data: This is your single most important argument. Always use it. It's your reality check, showing you how the model performs on data it isn't training on. Without it, you're flying blind: the model can overfit and you'd have no way of seeing it happen.
- batch_size: This controls how many samples the model looks at before each weight update. Smaller batches give noisier updates but can generalize better; larger batches are more computationally efficient and more stable. 32, which is also Keras's default, is a good place to start.
- shuffle: Should the training data be reshuffled before each epoch? The answer is almost always True (also the default; if you pass a tf.data.Dataset, this argument is ignored and you shuffle the dataset yourself). You don't want the model learning the order of your data instead of its content.
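Batch size also decides how many weight updates happen per epoch: one per batch, with a smaller final batch if the data doesn't divide evenly. That's the number Keras's per-epoch progress bar counts up to. A back-of-the-envelope helper (the name steps_per_epoch is just for illustration here):

```python
import math

def steps_per_epoch(num_samples, batch_size):
    """Number of weight updates in one pass over the data (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)

# 60,000 training samples, an MNIST-sized dataset.
print(steps_per_epoch(60_000, 32))   # 1875
print(steps_per_epoch(60_000, 256))  # 235 (234 full batches plus one batch of 96)
```

Going from 32 to 256 cuts the number of updates per epoch by 8x, which is one reason large-batch runs can need more epochs to reach the same place.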
# A more realistic, responsible fit() call.
history = model.fit(
x_train, y_train,
batch_size=32,
epochs=50,
validation_data=(x_val, y_val), # Non-negotiable.
shuffle=True
)
Why you can’t live without callbacks
If fit() is the engine, callbacks are the dashboard gauges, the automatic transmission, and the emergency brake. They are utilities that get called at various points during the training process, letting you intervene, monitor, and save your work without breaking stride. They are, in my professional opinion, the best part of Keras.
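Under the hood, a callback is just an object with hook methods that fit() calls at fixed points (on_train_begin, on_epoch_end, on_train_batch_end, and so on). Here is a minimal custom callback that records what Keras passes to on_epoch_end. The class name EpochLogger and the random toy data are made up for the demo, and it assumes TensorFlow 2.x.

```python
import numpy as np
from tensorflow import keras

class EpochLogger(keras.callbacks.Callback):
    """Hypothetical example callback: remember the metrics Keras reports after each epoch."""
    def __init__(self):
        super().__init__()
        self.records = []

    def on_epoch_end(self, epoch, logs=None):
        # `logs` holds this epoch's loss/metrics, including val_* entries when validating.
        self.records.append((epoch, dict(logs or {})))

# Tiny random problem: 64 samples, 8 features, 3 classes; last 16 samples for validation.
rng = np.random.default_rng(0)
x = rng.random((64, 8)).astype("float32")
y = rng.integers(0, 3, size=64)

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(3, activation="softmax")])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

logger = EpochLogger()
model.fit(x[:48], y[:48], epochs=3, validation_data=(x[48:], y[48:]),
          callbacks=[logger], verbose=0)
for epoch, logs in logger.records:
    print(epoch, sorted(logs))
```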
The most important one by far is ModelCheckpoint. Imagine training for 50 epochs and your laptop battery dies at epoch 49. Without a checkpoint, you've just donated electricity to the grid for nothing. With one, you can reload the last saved model and carry on from there instead of starting over.
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# Save the best model seen during training based on validation accuracy.
# Recent Keras versions expect the native '.keras' extension for full-model checkpoints.
checkpoint_cb = ModelCheckpoint('my_best_model.keras',
                                monitor='val_accuracy',
                                save_best_only=True,
                                mode='max')
# The patron saint of impatient developers: stop training when validation loss stops improving.
early_stopping_cb = EarlyStopping(monitor='val_loss',
                                  patience=5,  # Give it 5 epochs to improve before giving up.
                                  restore_best_weights=True)  # This part is crucial.
# Now fit with callbacks. Pass them as a list.
history = model.fit(
x_train, y_train,
epochs=100, # Set it high and let early stopping do its thing.
validation_data=(x_val, y_val),
callbacks=[checkpoint_cb, early_stopping_cb] # This right here is the pro move.
)
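If a run does get interrupted, the checkpoint file can be loaded back and training continued. This self-contained sketch assumes a recent TensorFlow 2.x release with the native .keras format; the toy data, the temporary path, and the "interrupted after 3 epochs" scenario are all invented for the demo. Passing initial_epoch only makes the epoch counter resume at the right number; the weights and optimizer state come from the saved file.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

rng = np.random.default_rng(0)
x = rng.random((64, 8)).astype("float32")
y = rng.integers(0, 3, size=64)
x_train, y_train, x_val, y_val = x[:48], y[:48], x[48:], y[48:]

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(3, activation="softmax")])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.keras")

# First run: pretend it gets cut short after 3 epochs (a checkpoint is saved every epoch).
model.fit(x_train, y_train, epochs=3, validation_data=(x_val, y_val),
          callbacks=[keras.callbacks.ModelCheckpoint(ckpt_path)], verbose=0)

# Later: reload the full model (architecture, weights, optimizer state) and keep going.
restored = keras.models.load_model(ckpt_path)
history = restored.fit(x_train, y_train, initial_epoch=3, epochs=6,
                       validation_data=(x_val, y_val), verbose=0)
print(len(history.history["loss"]))  # 3 (the remaining epochs 4, 5 and 6)
```

With save_best_only=True, as in checkpoint_cb, the file holds the best epoch rather than the last one, so "resuming" means resuming from that best point.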
EarlyStopping with restore_best_weights=True is a power combo. It lets you set a generous number of epochs, stops training once the monitored metric (val_loss here) has gone patience epochs without improving, and rolls the model's weights back to the epoch with the best value of that metric. You save time, and you don't end up keeping the weights from the overfit epochs at the end of the run.
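The patience rule is simple enough to trace by hand. Below is a simplified pure-Python re-implementation for a lower-is-better metric like val_loss, assuming min_delta=0 and no baseline. It's meant to show the logic, not to replace Keras's EarlyStopping, and the loss values are invented.

```python
def early_stopping_trace(val_losses, patience):
    """Return (stopped_epoch, best_epoch); stopped_epoch is None if training ran to the end."""
    best, best_epoch, wait = float("inf"), None, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:           # a strict improvement resets the patience counter
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:  # `patience` epochs in a row without improvement
                return epoch, best_epoch
    return None, best_epoch

val_losses = [0.90, 0.70, 0.60, 0.62, 0.61, 0.65, 0.66, 0.70, 0.72]
print(early_stopping_trace(val_losses, patience=5))  # (7, 2)
```

Training halts after epoch 7 (0-indexed), and with restore_best_weights=True the model ends up with the weights from epoch 2, the lowest validation loss, rather than those from the five worse epochs that followed.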
The history object and what to do with it
After fit() finishes, it returns a History object. This is a goldmine. Its history attribute is a dictionary mapping each loss and metric name (plus val_ versions for the validation set) to a list of per-epoch values. Plotting it is your first and most important post-mortem.
import matplotlib.pyplot as plt
# Plot training & validation accuracy values
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
# The same for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
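If the two curves stay close together but both level off at a disappointing value, you're probably underfitting: the model isn't capturing the pattern even on the training data. If they start together and then the validation curve dramatically diverges (validation loss climbing while training loss keeps falling), you're overfitting. This chart tells the story of your training run. Always look at it.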
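Eyeballing the plot comes first, but the same dictionary is easy to query directly. Here is a small helper, hypothetical and not part of Keras, that pulls out the epoch with the lowest validation loss and the train/validation gap at the end of the run. The numbers in the stand-in dictionary are invented to mimic an overfitting run; with a real run you would pass history.history instead.

```python
def summarize_history(hist):
    """Summarize a Keras-style history dict (metric name -> list of per-epoch values)."""
    val_loss = hist["val_loss"]
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
    final_gap = val_loss[-1] - hist["loss"][-1]  # how far validation trails training at the end
    return {"best_epoch": best_epoch,
            "best_val_loss": val_loss[best_epoch],
            "final_gap": round(final_gap, 4)}

# Stand-in for history.history: training loss keeps falling, validation loss turns around.
fake_history = {
    "loss":     [0.90, 0.60, 0.45, 0.35, 0.28, 0.22],
    "val_loss": [0.95, 0.70, 0.58, 0.60, 0.66, 0.74],
}
print(summarize_history(fake_history))
# {'best_epoch': 2, 'best_val_loss': 0.58, 'final_gap': 0.52}
```

A best epoch well before the last one, plus a widening gap, is the numeric version of the diverging curves described above.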