Summary of "Building makemore Part 3: Activations & Gradients, BatchNorm"
This video lecture continues the implementation of a character-level language model using multilayer perceptrons (MLPs), focusing on understanding neural network activations, gradients, initialization, and Batch Normalization. The goal is to build intuition about how activations and gradients behave during training, which is critical for optimizing deeper and more complex architectures like recurrent neural networks (RNNs).
Key Technical Concepts:
- MLP Initialization and Loss Behavior:
- Initial losses can be misleadingly high due to improper initialization.
- For a uniform prediction over 27 characters, the expected initial loss is -ln(1/27) ≈ 3.30 (negative log probability), but the improper initialization produced a first-step loss of about 27.
- The problem was traced to logits with extreme values, which make the network confidently wrong on many examples.
- Solution: initialize the output-layer biases to zero and scale down the output-layer weights (e.g., multiply by 0.1) so the logits start near zero. The softmax is then nearly uniform, the initial loss matches the expected value, and training is more stable.
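The loss values above can be reproduced with a small pure-Python sketch (no PyTorch; `cross_entropy` and `average_loss` are hypothetical helpers, and the logit scales 10.0 and 0.01 are arbitrary illustrative choices). All-zero logits give exactly ln(27), while random logits with a large spread give a much higher average loss.

```python
import math
import random

VOCAB_SIZE = 27

def cross_entropy(logits, target):
    """Negative log-probability of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# All-zero logits give a uniform softmax, so the loss is exactly ln(27).
uniform_loss = cross_entropy([0.0] * VOCAB_SIZE, target=5)
print(f"uniform logits: {uniform_loss:.4f}")  # 3.2958

def average_loss(logit_scale, trials=1000, seed=0):
    """Average loss when logits are random Gaussians multiplied by `logit_scale`."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        logits = [rng.gauss(0.0, 1.0) * logit_scale for _ in range(VOCAB_SIZE)]
        total += cross_entropy(logits, rng.randrange(VOCAB_SIZE))
    return total / trials

print(f"scale 10.0: {average_loss(10.0):.2f}")  # far above ln(27): confidently wrong
print(f"scale 0.01: {average_loss(0.01):.2f}")  # close to ln(27)
```

Scaling the output layer down at initialization moves the network from the first regime into the second.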
- Activation Saturation and Dead Neurons:
- The tanh activation function squashes inputs into [-1, 1].
- If many activations saturate near ±1, the local gradient 1 - tanh(x)^2 is close to zero, so gradients vanish during backpropagation and learning stalls.
- A neuron is dead if its output lies in the flat region of the nonlinearity for every input, so it never receives a gradient and never learns.
- Sigmoid has the same saturation problem; a ReLU neuron is dead if its pre-activation is negative for every input, so it always outputs zero.
- Proper initialization to keep pre-activations near zero reduces saturation and dead neurons.
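The tanh gradient and saturation effects described above can be checked numerically with a stdlib-only sketch. The helper names are hypothetical, and the 0.99 saturation threshold, the pre-activation spreads, and the ReLU bias of -10 are arbitrary illustrative choices.

```python
import math
import random

def tanh_local_grad(t):
    """Derivative of tanh expressed through its output t = tanh(x): 1 - t**2."""
    return 1.0 - t * t

def saturated_fraction(preact_std, n=10_000, threshold=0.99, seed=0):
    """Fraction of tanh outputs with |t| > threshold for pre-activations ~ N(0, preact_std**2)."""
    rng = random.Random(seed)
    outs = [math.tanh(rng.gauss(0.0, preact_std)) for _ in range(n)]
    return sum(abs(t) > threshold for t in outs) / n

print(tanh_local_grad(0.0))    # 1.0: gradient passes through unchanged
print(tanh_local_grad(0.999))  # about 0.002: gradient almost entirely blocked

# Pre-activations kept near zero rarely saturate; widely spread ones saturate often.
print(saturated_fraction(0.5))
print(saturated_fraction(5.0))

# A ReLU unit whose pre-activation is negative for every input never fires, so its
# local gradient (1 if x > 0 else 0) is zero everywhere: a dead neuron.
rng = random.Random(1)
relu_preacts = [rng.gauss(0.0, 1.0) - 10.0 for _ in range(1000)]  # hypothetical bias of -10
relu_grads = [1.0 if x > 0 else 0.0 for x in relu_preacts]
print(sum(relu_grads))  # 0.0
```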
- Scaling Weights Using Fan-in and Gain:
- Multiplying inputs by randomly initialized weights can make the spread of activations grow or shrink from layer to layer.
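- Proper scaling divides weights drawn from a unit Gaussian by the square root of the fan-in (the number of inputs feeding each unit), which keeps the output variance roughly equal to the input variance.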
- The paper "Delving Deep into Rectifiers" (He et al., 2015) adds a gain factor (√2 for ReLU) to compensate for the nonlinearity discarding half of the distribution.
- For tanh, PyTorch's recommended gain is 5/3 (the value returned by torch.nn.init.calculate_gain("tanh")).
- PyTorch's torch.nn.init.kaiming_normal_ implements these initialization schemes.
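A pure-Python sketch of the fan-in rule (standing in for torch.nn.init.kaiming_normal_, not calling it): weights drawn from a unit Gaussian inflate the output scale by roughly √fan_in, dividing by √fan_in restores unit scale, and multiplying by a gain such as 5/3 gives weights with standard deviation gain / √fan_in. The layer size of 200 and the helper names are arbitrary assumptions.

```python
import math
import random

def linear_forward(x, W):
    """Compute y[j] = sum_i x[i] * W[i][j] for one input vector (no bias)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def std(values):
    """Population standard deviation of a list of floats."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))

rng = random.Random(0)
fan_in, fan_out = 200, 200
x = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]  # unit-variance input

W_raw = [[rng.gauss(0.0, 1.0) for _ in range(fan_out)] for _ in range(fan_in)]
W_scaled = [[w / math.sqrt(fan_in) for w in row] for row in W_raw]

print(std(linear_forward(x, W_raw)))     # roughly sqrt(200), about 14: scale explodes
print(std(linear_forward(x, W_scaled)))  # roughly 1: scale preserved

# Kaiming-style init for tanh: weight std = gain / sqrt(fan_in), with gain 5/3.
gain = 5 / 3
W_tanh = [[gain * w for w in row] for row in W_scaled]
print(std([w for row in W_tanh for w in row]), gain / math.sqrt(fan_in))
```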
- Batch Normalization (BatchNorm):
- Introduced in 2015 by Ioffe and Szegedy, BatchNorm normalizes each feature of a layer's activations to zero mean and unit variance across the examples in a batch.
- It stabilizes training of very deep networks by controlling activation statistics.
- BatchNorm standardizes activations and then applies learned scale (gamma) and shift (beta) parameters.
- Running averages of the mean and variance are maintained during training so that fixed statistics can be used at inference.
- BatchNorm couples the examples in a batch, so the output for one example depends on the others in the same batch. The resulting jitter acts as a form of regularization but can also cause subtle bugs.
- Biases in layers before BatchNorm become redundant because BatchNorm subtracts the per-feature batch mean, which cancels any constant bias.
- Alternatives to BatchNorm (LayerNorm, GroupNorm) avoid coupling batch examples but are not covered in detail here.
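The normalize-then-scale-and-shift step can be sketched in plain Python over a batch stored as a list of rows. `batchnorm_forward` is a hypothetical helper, not a PyTorch API, and it divides by n when computing the variance. The example also checks two claims above: a constant bias added before BatchNorm is cancelled, and changing one example changes the output for another.

```python
import math

def batchnorm_forward(h, gamma, beta, eps=1e-5):
    """Normalize each feature (column) of h[example][feature] across the batch,
    then apply the learned per-feature scale gamma and shift beta."""
    n, f = len(h), len(h[0])
    out = [[0.0] * f for _ in range(n)]
    for j in range(f):
        col = [h[i][j] for i in range(n)]
        mean = sum(col) / n
        var = sum((v - mean) ** 2 for v in col) / n
        for i in range(n):
            out[i][j] = gamma[j] * (h[i][j] - mean) / math.sqrt(var + eps) + beta[j]
    return out

h = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]  # batch of 4 examples, 2 features
gamma, beta = [1.0, 1.0], [0.0, 0.0]
out = batchnorm_forward(h, gamma, beta)

# A constant bias per feature shifts that feature's batch mean by the same amount,
# so the mean subtraction removes it and the output is unchanged.
h_biased = [[a + 7.0, b - 3.0] for a, b in h]
out_biased = batchnorm_forward(h_biased, gamma, beta)

# Coupling: editing example 3 changes the normalized output of example 0.
h_edited = [row[:] for row in h]
h_edited[3][0] = 100.0
out_edited = batchnorm_forward(h_edited, gamma, beta)
print(out[0][0], out_edited[0][0])
```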
- Practical Implementation Details:
- PyTorch modules for Linear layers and BatchNorm layers are explained, including parameters, buffers, and training vs. inference modes.
- Biases are typically disabled in layers preceding BatchNorm.
- BatchNorm's momentum parameter sets how much weight the current batch gets in the exponential moving average of the running statistics (PyTorch's default is 0.1).
- BatchNorm layers are typically placed after linear or convolutional layers and before nonlinearities.
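The running statistics and the training/inference switch can be sketched with a small class. `BatchNorm1dSketch` is hypothetical and only loosely follows PyTorch's BatchNorm1d conventions: running_mean starts at 0, running_var at 1, and each update is new = (1 - momentum) * old + momentum * batch_stat. Unlike PyTorch, which updates running_var with the unbiased batch variance, this sketch uses the divide-by-n variance throughout, and gamma and beta are never trained.

```python
import math
import random

class BatchNorm1dSketch:
    """Hypothetical minimal BatchNorm over rows of features, with running statistics."""

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps, self.momentum = eps, momentum
        self.gamma = [1.0] * dim         # learned scale (not trained in this sketch)
        self.beta = [0.0] * dim          # learned shift (not trained in this sketch)
        self.running_mean = [0.0] * dim  # buffers: updated by the moving average, not by gradients
        self.running_var = [1.0] * dim
        self.training = True

    def __call__(self, batch):
        n, dim = len(batch), len(batch[0])
        if self.training:
            mean = [sum(row[j] for row in batch) / n for j in range(dim)]
            var = [sum((row[j] - mean[j]) ** 2 for row in batch) / n for j in range(dim)]
            m = self.momentum
            for j in range(dim):
                self.running_mean[j] = (1 - m) * self.running_mean[j] + m * mean[j]
                self.running_var[j] = (1 - m) * self.running_var[j] + m * var[j]
        else:
            # Inference: use the fixed running statistics, so a single example can be
            # processed on its own and its output does not depend on other examples.
            mean, var = self.running_mean, self.running_var
        return [[self.gamma[j] * (row[j] - mean[j]) / math.sqrt(var[j] + self.eps) + self.beta[j]
                 for j in range(dim)] for row in batch]

rng = random.Random(0)
bn = BatchNorm1dSketch(dim=1, momentum=0.1)
for _ in range(500):  # training batches of 32 values drawn from N(5, 2**2)
    bn([[rng.gauss(5.0, 2.0)] for _ in range(32)])
print(bn.running_mean, bn.running_var)  # close to 5 and 4 (divide-by-n variance runs slightly low on average)

bn.training = False
print(bn([[5.0]]))  # one example, normalized with the running statistics
```

A smaller momentum averages over more batches, giving smoother but slower-adapting running statistics.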
- Analysis of Activations, Gradients, and Parameter Updates:
- Visualizing histograms of activations and gradients helps diagnose issues like saturation or vanishing gradients.
- Gain tuning is critical: too small a gain makes activations and gradients shrink toward zero, while too large a gain pushes tanh units into saturation.
- Without nonlinearities, stacking linear layers collapses to a single linear transformation, limiting model expressiveness.
- Tracking the ratio of gradient scale to parameter scale, and especially the ratio of update size to parameter size, over the course of training helps tune the learning rate and detect problems.
- BatchNorm reduces sensitivity to gain choice and initialization, making training more robust.
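The update-to-parameter diagnostic reduces to one number per parameter tensor: the standard deviation of the SGD update (learning rate times gradient) divided by the standard deviation of the parameter values, usually plotted on a log10 scale over training. The sketch below uses made-up Gaussian weights and gradients rather than a trained model, and `update_to_data_log10` is a hypothetical helper. The lecture's rough rule of thumb is a ratio around 1e-3 (about -3 on the log10 scale); much smaller values suggest the learning rate is too low, and much larger values suggest updates are too big.

```python
import math
import random

def std(values):
    """Population standard deviation of a list of floats."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))

def update_to_data_log10(params, grads, lr):
    """log10 of std(lr * grad) / std(param) for one flattened parameter tensor."""
    updates = [lr * g for g in grads]
    return math.log10(std(updates) / std(params))

rng = random.Random(0)
params = [rng.gauss(0.0, 0.1) for _ in range(1000)]  # hypothetical weights, std 0.1
grads = [rng.gauss(0.0, 0.01) for _ in range(1000)]  # hypothetical gradients, std 0.01

for lr in (0.001, 0.01, 1.0):
    # Roughly -4, -3 and -1: the ratio scales linearly with the learning rate.
    print(lr, round(update_to_data_log10(params, grads, lr), 2))
```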
- Limitations and Future Directions:
- The current simple MLP model is limited by context length and architecture; more powerful models like RNNs and Transformers are needed for better performance.
- Initialization and backpropagation remain active research areas; no definitive solutions exist.
- BatchNorm, despite its popularity, has drawbacks (notably the coupling of examples within a batch), and alternatives are gaining traction.
Guides and Tutorials Provided:
- Step-by-step debugging of initialization problems via loss behavior and logits distribution.
- Visualizing and interpreting histograms of activations and gradients.
- Implementing BatchNorm from scratch, including forward pass, backward pass considerations, and training/inference differences.
- Using PyTorch’s initialization utilities and modules for Linear and BatchNorm layers.
- Diagnostic tools: plotting activation saturation, gradient distributions, parameter histograms, and update-to-parameter ratios.
- Practical advice on disabling biases in layers preceding BatchNorm.