Summary of "Let's reproduce GPT-2 (124M)"
Goal
Reproduce OpenAI’s GPT-2 small (124M) in PyTorch from scratch, load the official weights, then train a fresh model to match or surpass it. Emphasis is on understanding the architecture, exact weight/layout conventions, the training recipe, and practical performance tuning for modern GPUs.
What was covered (technical concepts & implementation)
1. Model architecture and weights
- GPT-2 is a decoder-only Transformer (no encoder, no cross-attention).
- 124M configuration: 12 layers, hidden dim 768, 12 heads, block size 1024, vocab size 50257.
- Differences vs original Transformer:
- LayerNorm moved to the input of each sub-block (pre-norm),
- an extra final LayerNorm,
- GELU nonlinearity instead of ReLU (GPT-2 historically used the tanh-approximate GELU).
- Embeddings:
- Token embeddings (Wte) and learned positional embeddings (1024 × 768). Positional embeddings tend to recover sinusoidal-like structure during training.
- Weight tying:
- Token embedding and output (LM head) matrices are tied — large parameter and memory saving (~30% of params).
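The configuration and the weight-tying trick can be sketched in a few lines. This is a minimal illustration (the class and field names follow common nanoGPT-style conventions, not necessarily the exact code from the video):

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class GPTConfig:
    block_size: int = 1024   # max context length
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    n_layer: int = 12        # transformer blocks
    n_head: int = 12         # attention heads
    n_embd: int = 768        # hidden dimension

cfg = GPTConfig()

# Weight tying: token embedding and LM head share one (V, C) matrix,
# so 50257 * 768 ~= 38.6M parameters are counted only once (~30% of 124M).
wte = nn.Embedding(cfg.vocab_size, cfg.n_embd)
lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
lm_head.weight = wte.weight
```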
2. From TensorFlow weights to PyTorch
- Hugging Face Transformers provides a PyTorch implementation and converters.
- Building a small, readable PyTorch GPT class with the same naming/schema as HF simplifies weight loading and debugging.
- Some TensorFlow-sourced weights need transposing when loading into PyTorch.
3. Forward, loss, and batching
- Forward returns logits shape (B, T, V). Cross-entropy implemented by flattening to (BT, V) and (BT).
- Batching: convert a long token sequence to B×T by reshaping (contiguous slices). Create labels by offsetting inputs by 1.
- Tiny Shakespeare used as a debugging toy dataset.
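The batching and loss construction above can be sketched as follows, using a toy token sequence in place of real tokenized text:

```python
import torch
import torch.nn.functional as F

# toy "long token sequence" standing in for tokenized documents
tokens = torch.arange(0, 25)
B, T = 4, 6                      # batch size, context length

# take B*T + 1 tokens so labels can be inputs shifted by one
buf = tokens[: B * T + 1]
x = buf[:-1].view(B, T)          # inputs  (B, T)
y = buf[1:].view(B, T)           # labels  (B, T)

# pretend logits from a model with vocab size V
V = 50257
logits = torch.randn(B, T, V)

# cross-entropy over flattened (B*T, V) logits vs (B*T,) labels
loss = F.cross_entropy(logits.view(-1, V), y.view(-1))
```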
4. Initialization details
- Follow OpenAI initialization: weight std ≈ 0.02, positional embeddings ~0.01; layer norm scale = 1, bias = 0.
- Residual-scale trick: scale the init std of the residual-projection weights by 1 / sqrt(2 * n_layers) (two residual additions per layer) to keep activation variance from growing along the residual stream.
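A sketch of this init scheme (the `RESIDUAL_SCALE` flag name here is illustrative; it just marks the output projections of the attention and MLP blocks):

```python
import torch.nn as nn

n_layer = 12

def init_weights(module):
    if isinstance(module, nn.Linear):
        std = 0.02
        # residual projections get scaled init: 0.02 / sqrt(2 * n_layer)
        if getattr(module, "RESIDUAL_SCALE", False):
            std *= (2 * n_layer) ** -0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)

proj = nn.Linear(768, 768)
proj.RESIDUAL_SCALE = True  # mark as a residual-stream projection
init_weights(proj)
```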
5. MLP & nonlinearities
- Standard two-layer MLP with GELU nonlinearity.
- Note on GELU: exact (erf-based) vs tanh-approximate implementations; GPT-2 historically used the tanh approximation, so use it when reproducing GPT-2's behavior exactly.
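PyTorch exposes both variants directly; the difference is tiny but matters for exact reproduction:

```python
import torch
import torch.nn as nn

gelu_exact = nn.GELU()                   # erf-based "exact" GELU
gelu_tanh = nn.GELU(approximate="tanh")  # matches GPT-2's historical kernel

x = torch.linspace(-3, 3, 7)
# the two variants agree to within roughly 1e-3
diff = (gelu_exact(x) - gelu_tanh(x)).abs().max()
```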
6. Attention implementation details
- Multi-head attention implemented via qkv projections, split/transpose to create heads, causal mask, softmax, weighted sum.
- Sampling: top-k sampling (k=50 by default, to match the HF pipeline) implemented with torch.multinomial over the renormalized top-k probabilities.
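The top-k sampling step can be sketched as a small standalone function:

```python
import torch
import torch.nn.functional as F

def sample_top_k(logits, k=50):
    """Sample one token id from the top-k of a (V,) logit vector."""
    probs = F.softmax(logits, dim=-1)
    topk_probs, topk_idx = torch.topk(probs, k)
    # multinomial renormalizes over the k candidates and samples one
    choice = torch.multinomial(topk_probs, num_samples=1)
    return topk_idx[choice]

logits = torch.randn(50257)
next_token = sample_top_k(logits, k=50)
```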
7. Optimizer & training recipe
- Optimizer: AdamW (prefer fused implementation if available).
- betas = (0.9, 0.95), eps = 1e-8, weight_decay (GPT-3 used 0.1).
- Gradient clipping (global norm) at 1.0.
- LR schedule: cosine decay with linear warmup (the GPT-3 paper warms up over a fixed token budget and decays to 10% of the max LR over the training horizon).
- Use gradient accumulation to simulate very large batch sizes when hardware-limited.
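The warmup + cosine schedule can be written as a plain function; the step counts below are illustrative values in the spirit of a ~10B-token 124M run, not prescriptive:

```python
import math

max_lr = 6e-4
min_lr = max_lr * 0.1   # decay to 10% of max, per the GPT-3 recipe
warmup_steps = 715
max_steps = 19073

def get_lr(step):
    # 1) linear warmup
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # 2) past the decay horizon, hold at the floor
    if step > max_steps:
        return min_lr
    # 3) cosine decay from max_lr down to min_lr
    ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return min_lr + coeff * (max_lr - min_lr)
```

Each training step looks up `get_lr(step)` and writes it into every optimizer param group before `optimizer.step()`.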
8. Large-scale training infra & performance engineering
- Mixed precision:
- TF32 (transparent and faster on Ampere) and bfloat16 via torch.autocast.
- bfloat16 avoids many gradient-scaling headaches compared to float16.
- torch.compile: can fuse kernels and remove Python overhead for significant speedups; may interact poorly with some custom sampling/eval code.
- FlashAttention: fused attention kernel that avoids materializing the full (T×T) attention matrix using an online-softmax trick — large speedups and memory reductions.
- Kernel/block tuning:
- Pad “ugly” sizes (e.g., vocab 50257 → 50304) so CUDA kernels get dimensions divisible by large powers of two (50304 = 128 × 393) for better performance.
- Memory hierarchy explained: registers / L1 / L2 / HBM, and why memory bandwidth often dominates over raw FLOPs.
- Use fused AdamW and careful parameter grouping (no weight decay on layer-norm params and biases).
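The mixed-precision and fused-attention toggles above are a few lines each in PyTorch. A minimal sketch (shown on CPU so it runs anywhere; on a GPU box `device_type` would be `"cuda"`, where SDPA dispatches to a fused FlashAttention-style kernel):

```python
import torch
import torch.nn.functional as F

# TF32 for float32 matmuls (Ampere+ tensor cores; no other code changes)
torch.set_float32_matmul_precision("high")

model = torch.nn.Linear(768, 768)
x = torch.randn(8, 1024, 768)

# bfloat16 autocast: same exponent range as float32, so no GradScaler needed
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

# fused attention through the built-in SDPA entry point; on CUDA this
# avoids materializing the full (T, T) attention matrix
q = k = v = torch.randn(1, 12, 64, 64)  # (B, heads, T, head_dim)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# model = torch.compile(model)  # kernel fusion / Python-overhead removal
```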
9. Distributed training (multi-GPU)
- Use PyTorch DistributedDataParallel (DDP) with spawned processes (one per GPU).
- Data sharding per process to avoid duplicate work.
- Gradient accumulation with DDP:
- Avoid synchronizing gradients during intermediate micro-steps (use no_sync or similar) and only all-reduce once per optimizer step.
- Checkpoint model and optimizer state regularly.
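The accumulate-then-sync pattern can be sketched with a single-process "cluster" (gloo backend, world size 1) so it runs without GPUs. Toggling `require_backward_grad_sync` is one way to defer the all-reduce; `model.no_sync()` is the documented alternative:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(8, 8))
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
grad_accum_steps = 4

opt.zero_grad()
for micro_step in range(grad_accum_steps):
    x = torch.randn(2, 8)
    # scale the loss so accumulated grads equal the full-batch mean
    loss = model(x).pow(2).mean() / grad_accum_steps
    # only all-reduce gradients on the final micro-step
    model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    loss.backward()
opt.step()
dist.destroy_process_group()
```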
10. Datasets and evaluation
- Dataset: FineWeb / FineWeb-edu (Hugging Face) — filtered, higher-quality subset of CommonCrawl; 10B token sample used.
- Sharded dataset storage (e.g., 100M-token shards) for efficient IO.
- Evaluation:
- Validation loss and HellaSwag multiple-choice via likelihood ranking (convert choices to candidate continuations and pick highest token likelihood / lowest avg loss).
- Results:
- With ~10B tokens and the tuned recipe, reproduced GPT-2 124M matched/surpassed some OpenAI GPT-2 124M metrics on these evals. Longer runs (40B tokens) improved HellaSwag further.
- Caveats: dataset differences, potential leakage, and data-ordering artifacts.
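The likelihood-ranking evaluation can be sketched model-agnostically; `logits_fn` stands in for a forward pass and is a hypothetical interface for illustration:

```python
import torch
import torch.nn.functional as F

def rank_completions(logits_fn, context, completions):
    """Pick the completion with the lowest average per-token loss.

    logits_fn(tokens) -> (1, T, V) logits, for a (1, T) LongTensor of ids.
    """
    scores = []
    for completion in completions:
        tokens = torch.tensor([context + completion])
        logits = logits_fn(tokens)
        T_ctx = len(context)
        # each completion token is predicted from the previous position
        targets = tokens[:, T_ctx:]
        preds = logits[:, T_ctx - 1 : -1, :]
        loss = F.cross_entropy(
            preds.reshape(-1, preds.size(-1)), targets.reshape(-1)
        )
        scores.append(loss.item())
    return min(range(len(scores)), key=scores.__getitem__)
```

For HellaSwag, each of the four choices becomes a candidate continuation of the shared context, and the lowest-average-loss choice is the model's answer.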
11. Tooling & reproducibility
- Use Jupyter inside VSCode; set seeds and deterministic device handling.
- Provided a compact, readable GPT implementation (~100 lines for the model class) to replace larger HF files for learning and debugging.
- Codebase planned for release with commit history (Zero→Hero / NanoGPT-like).
- Also demonstrated a pure C/CUDA implementation (llm.c) that reproduces the behavior and can run faster per step than the PyTorch reference.
Practical step-by-step guide (high level)
- Inspect GPT-2 paper and official code; use HF Transformers for pre-converted PyTorch weights.
- Re-implement a small, readable GPT class with HF-matching key names to load the state_dict easily.
- Load weights and verify parameter shapes (token & positional embeddings, etc.).
- Implement data loader: tokenize documents, shard, form B×T batches; set labels = inputs shifted by 1.
- Implement loss (flatten logits & labels), optimizer (AdamW/fused), LR schedule with warmup, and gradient clipping.
- Debug on a tiny dataset (Tiny Shakespeare); overfit a small batch to verify optimizer and gradients.
- Move to mixed precision (autocast bfloat16 on Ampere); enable TF32 where appropriate; use torch.compile and FlashAttention; pad vocab for kernel-friendly sizes.
- Use gradient accumulation and DDP to scale effective batch size; checkpoint and evaluate periodically (val loss + HellaSwag).
- Save and log metrics; sample text occasionally to inspect generations.
Performance & cost notes
- Reproducing GPT-2 124M can be fast on modern GPUs: a minimal run can take about an hour and roughly $10 of cloud compute on contemporary hardware.
- Major speedups come from:
- TF32, bfloat16 mixed precision,
- torch.compile,
- FlashAttention,
- fused optimizers,
- padding to kernel-friendly sizes,
- multi-GPU DDP,
- weight tying (embeddings).
Common issues & TODOs
- torch.compile sometimes broke sampling/eval — needs careful debugging.
- Data ordering / insufficient shuffling across shards can cause periodic loss artifacts; shuffle/permute documents across epochs to fix.
- Ensure dataset deduplication and avoid leakage for fair evaluation (leakage can inflate eval results).
Tools, libraries & references used
- PyTorch (DDP, torch.autocast, torch.compile)
- Hugging Face Transformers (pretrained weights), datasets
- FlashAttention (Tri Dao et al.)
- Fused AdamW optimizer support
- Tokenizers (GPT-2 tokenizer / tiktoken)
- CUDA / NVIDIA A100 hardware; tensor cores, bfloat16 / TF32
- NanoGPT-style references, and the llm.c implementation
- Datasets: Tiny Shakespeare (debug), FineWeb / FineWeb-edu (Hugging Face 10B sample)
- Evaluations: HellaSwag
Practical takeaways / recommendations
- For learning: reimplement a minimal, readable GPT in PyTorch (match HF key names for easier loading and debugging).
- For efficient training: use bfloat16 (or TF32 selectively), torch.compile, FlashAttention, fused AdamW, pad tensors to kernel-friendly sizes, and scale with DDP + gradient accumulation.
- Weight tying (embeddings ↔ LM head) is important for parameter efficiency and matches the original GPT practice.
- Validate with a held-out validation set plus targeted task evals (e.g., HellaSwag). Be cautious about dataset leakage and ordering artifacts.
Main speakers / sources
- Speaker / tutorial author: Andrej Karpathy (Zero→Hero series)
- Primary sources referenced:
- OpenAI GPT-2 paper & code
- OpenAI GPT-3 paper (training hyperparameters)
- Hugging Face Transformers
- FlashAttention (Tri Dao et al.)
- “Attention Is All You Need” (Vaswani et al.)
- GELU discussions
- Hugging Face FineWeb dataset
- HellaSwag paper
- NanoGPT and llm.c projects
Available extracts / auxiliary materials
- Minimal PyTorch GPT class skeleton and weight-loading steps
- Exact training loop with gradient accumulation + DDP
- FlashAttention usage example
- Recommended device/mixed-precision toggles and a hyperparameter table for 124M reproduction