LLM Architecture & Internals
This comprehensive technical reference covers the architecture, training, and inference of Large Language Models. We'll explore tokenization, embeddings, transformer blocks, attention mechanisms, and the full pipeline from raw text to generated outputs.
Overview: The LLM Pipeline
Key Dimensions of a Modern LLM
| Parameter | Symbol | GPT-4 | LLaMA 3 405B | Mistral 7B | GPT-3.5 |
|---|---|---|---|---|---|
| Layers | L | ~120 | 126 | 32 | 96 |
| Hidden dimension | d_model | 12,288 | 16,384 | 4,096 | 12,288 |
| Attention heads | h | 96 | 128 | 32 | 96 |
| Head dimension | d_k | 128 | 128 | 128 | 128 |
| FFN dimension | d_ff | 49,152 | 53,248 | 14,336 | 49,152 |
| Vocabulary size | V | 100,256 | 128,256 | 32,000 | 50,257 |
| Context length | T | 128,000 | 128,000 | 32,768 | 4,096 |
| Total Parameters | N | ~1.7T | 405B | 7B | 175B |
Complete Forward Pass Overview
# Complete end-to-end forward pass for LLaMA-3 405B

❶ RAW TEXT INPUT
   Input: "The capital of France is"

❷ TOKENIZATION (byte-level BPE, tiktoken-based)
   Splits text into subword units and maps them to integer IDs
   Output: [128000, 976, 6864, 315, 9822, 374]

❸ TOKEN EMBEDDING
   Lookup in W_E ∈ ℝ^(128256 × 16384)
   Shape: [6] → [6, 16384]

❹ POSITIONAL ENCODING (RoPE)
   Applied as rotations to the Q and K vectors inside attention

❺ TRANSFORMER BLOCKS (× 126 layers)
   Per layer: RMSNorm → Attention → Residual → RMSNorm → FFN → Residual

❻ OUTPUT PROCESSING
   Final RMSNorm → Linear projection (LM head) → temperature-scaled softmax → probabilities

❼ AUTOREGRESSIVE LOOP
   Sampled token → append to input → repeat
2. Tokenization
Tokenization is the very first step: converting raw text into a sequence of integer IDs. The quality of the tokenizer directly affects model efficiency, multilinguality, and downstream task performance.
Byte-Pair Encoding (BPE) Algorithm
def train_bpe(corpus, target_vocab_size):
    # Step 1: Initialize vocabulary with all 256 individual bytes
    vocab = {bytes([i]): i for i in range(256)}
    while len(vocab) < target_vocab_size:
        # Step 2: Encode corpus with the current vocab and count adjacent token pairs
        pair_freqs = count_pairs(encode_with_vocab(corpus, vocab))  # helpers not shown
        # Step 3: Find the most frequent pair
        most_frequent = max(pair_freqs, key=pair_freqs.get)
        # Step 4: Add it as a new token and merge all occurrences in the corpus
        new_token_id = len(vocab)
        vocab[most_frequent] = new_token_id
        corpus = merge_pairs(corpus, most_frequent, new_token_id)
    return vocab

# Example progression: "lower" → "l o w e r" → "lo w e r" → "low er" → "lower"
Tokenizer Comparison
| Tokenizer | Algorithm | Vocab Size | Used By | Compression (English, approx.) |
|---|---|---|---|---|
| tiktoken (cl100k) | Byte-level BPE | 100,256 | GPT-4, GPT-3.5 Turbo | ~1.3 tokens/word |
| tiktoken (o200k) | Byte-level BPE | 200,019 | GPT-4o, o1 | ~1.2 tokens/word |
| LLaMA 3 tokenizer | Byte-level BPE (tiktoken-based) | 128,256 | LLaMA 3 | ~1.3 tokens/word |
| SentencePiece Unigram | Unigram LM | 32,100 | T5 | ~1.5 tokens/word |
| WordPiece | BPE variant (likelihood-based merges) | 30,522 | BERT | ~1.5 tokens/word |
Tokenization Implementation
import sentencepiece as spm # Load pre-trained tokenizer sp = spm.SentencePieceProcessor() sp.Load("tokenizer.model") # Encode text to token IDs text = "The quick brown fox jumps" token_ids = sp.Encode(text, out_type=int) # → [256, 4910, 3892, 6841, 5910] # Decode token IDs back to text decoded = sp.Decode(token_ids) # → "The quick brown fox jumps" # Get tokens as strings tokens = sp.Encode(text, out_type=str) # → ['▁The', '▁quick', '▁brown', '▁fox', '▁jumps']
Encoding Example - "unhappiness"
"unhappiness" ↓ (character split) → ["u", "n", "h", "a", "p", "p", "i", "n", "e", "ss"] ↓ (apply merge rules from training) → ["un", "h", "app", "iness"] ↓ → ["unhapp", "iness"] ↓ Final token IDs: [48291, 7274]
3. Embeddings & Positional Encoding
After tokenization, each token ID is mapped to a dense vector using an embedding lookup table. Modern LLMs use Rotary Position Embeddings (RoPE) which encode position information via rotations in the attention mechanism.
Token Embedding Implementation
import torch import torch.nn as nn class TokenEmbedding(nn.Module): def __init__(self, vocab_size, d_model): super().__init__() self.embed = nn.Embedding(vocab_size, d_model) # Shape: (128256, 16384) for LLaMA-3 def forward(self, token_ids): # token_ids: (batch, seq_len) # output: (batch, seq_len, d_model) return self.embed(token_ids) # Usage: embedding_layer = TokenEmbedding(128256, 16384) token_ids = torch.tensor([[256, 4910, 3892]]) embeddings = embedding_layer(token_ids) # (1, 3, 16384)
Rotary Position Embedding (RoPE)
RoPE encodes absolute position information by rotating query and key vectors. For a 2D subspace with indices (2i, 2i+1), the rotation matrix is applied with frequency θ_i = 10000^(-2i/d).
import torch

def apply_rope(x, freqs_cis):
    # x: (batch, seq_len, n_heads, head_dim)
    # freqs_cis: (seq_len, head_dim // 2) precomputed complex frequencies
    # View pairs of channels as complex numbers
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Reshape freqs to (1, seq_len, 1, head_dim // 2) so it broadcasts over batch and heads
    freqs_cis = freqs_cis.view(1, x.shape[1], 1, -1)
    # Multiply by complex exponentials (a rotation in each 2D subspace)
    x_rotated = x_complex * freqs_cis
    return torch.view_as_real(x_rotated).reshape_as(x)

def precompute_freqs(dim, max_seq_len, theta=10000.0):
    # Compute inverse frequencies: θ_i = 10000^(-2i/d)
    inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    # Position indices
    t = torch.arange(max_seq_len)
    # Outer product: m · θ_i for every position m and frequency θ_i
    freqs = torch.outer(t, inv_freqs)
    # Convert to complex exponentials: e^(i·m·θ_i)
    return torch.polar(torch.ones_like(freqs), freqs)

# Usage (slice the table to the current sequence length):
# freqs_cis = precompute_freqs(128, 4096)
# q_rotated = apply_rope(q, freqs_cis[:seq_len])
# k_rotated = apply_rope(k, freqs_cis[:seq_len])
RoPE Frequency Scaling for Extended Context
| Method | Approach | Trade-off | Best For |
|---|---|---|---|
| Position Interpolation (PI) | Scale positions by L_orig/L_target | Slight accuracy drop | Moderate context extension |
| NTK-aware | Increase base frequency θ | May degrade short sequences | Aggressive extrapolation |
| YaRN | Combine NTK + PI + temperature | Complex tuning | Balanced extension |
| ALiBi | Attention with Linear Biases | Different architecture | Ultra-long context |
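As a rough illustration of the first two rows, the sketch below modifies the precompute_freqs function from the RoPE example above; the pi_scale and ntk_alpha parameters are illustrative assumptions rather than values from any particular model.

import torch

def precompute_freqs_scaled(dim, max_seq_len, theta=10000.0,
                            pi_scale=1.0, ntk_alpha=1.0):
    # NTK-aware scaling: enlarge the RoPE base instead of squeezing positions
    # (ntk_alpha > 1 stretches the low-frequency bands the most)
    theta = theta * ntk_alpha ** (dim / (dim - 2))
    inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    # Position Interpolation: squeeze positions back into the trained range
    # (pi_scale = L_orig / L_target, e.g. 0.25 to stretch a 4K model to 16K)
    t = torch.arange(max_seq_len) * pi_scale
    freqs = torch.outer(t, inv_freqs)
    return torch.polar(torch.ones_like(freqs), freqs)

# Hypothetical usage: extend a model trained on 4K context to 16K via PI
# freqs_cis = precompute_freqs_scaled(128, 16384, pi_scale=4096 / 16384)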
4. The Transformer Block
The Transformer block is the core computational unit. A typical LLM stacks 32-126 identical blocks (32 in Mistral 7B, 126 in LLaMA 3 405B), each containing a self-attention sub-layer and a feed-forward sub-layer wrapped in normalization and residual connections.
Pre-Norm Transformer Block Architecture
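In a pre-norm block, normalization is applied before each sub-layer rather than after it, which stabilizes very deep stacks. A minimal sketch of one block, assuming the MultiHeadAttention, SwiGLU_FFN, and RMSNorm modules defined later in this reference:

import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm block: normalize before each sub-layer, add the residual after."""
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn_norm = RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn_norm = RMSNorm(d_model)
        self.ffn = SwiGLU_FFN(d_model, d_ff)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.attn(self.attn_norm(x))   # attention sub-layer + residual
        x = x + self.ffn(self.ffn_norm(x))     # feed-forward sub-layer + residual
        return x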
Self-Attention Mechanism
Self-attention allows every token to attend to every other token. The computation follows: Attention(Q, K, V) = softmax(QK^T / √d_k) V
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (..., seq_len, d_k)
    d_k = Q.size(-1)
    # 1. Compute attention scores, scaled by √d_k
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)  # (..., seq_len, seq_len)
    # 2. Apply causal mask (block attention to future positions)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # 3. Softmax over the key dimension to get attention weights
    weights = torch.softmax(scores, dim=-1)
    # 4. Weighted sum of values
    output = weights @ V
    return output

# Causal mask for seq_len=4:
mask = torch.tril(torch.ones(4, 4))
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]
Multi-Head Attention Implementation
class MultiHeadAttention(nn.Module): def __init__(self, d_model, n_heads): super().__init__() self.n_heads = n_heads self.d_k = d_model // n_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) def forward(self, x): # x: (batch, seq_len, d_model) batch_size, seq_len, _ = x.shape # Project to Q, K, V Q = self.W_q(x) # (batch, seq, d_model) K = self.W_k(x) V = self.W_v(x) # Reshape to multiple heads Q = Q.reshape(batch_size, seq_len, self.n_heads, self.d_k) K = K.reshape(batch_size, seq_len, self.n_heads, self.d_k) V = V.reshape(batch_size, seq_len, self.n_heads, self.d_k) # Transpose for attention: (batch, n_heads, seq, d_k) Q = Q.transpose(1, 2) K = K.transpose(1, 2) V = V.transpose(1, 2) # Apply attention attn_out = scaled_dot_product_attention(Q, K, V) # attn_out: (batch, n_heads, seq, d_k) # Concatenate heads attn_out = attn_out.transpose(1, 2) attn_out = attn_out.reshape(batch_size, seq_len, -1) # Final projection output = self.W_o(attn_out) return output
MHA vs GQA vs MQA Comparison
| Mechanism | Q Heads | KV Heads | KV Cache Size | Used By |
|---|---|---|---|---|
| Multi-Head Attention (MHA) | h | h | 1× | Original Transformer, GPT-3 |
| Grouped-Query (GQA) | h | h/g (shared) | 1/g × | LLaMA 3, Mistral |
| Multi-Query (MQA) | h | 1 | 1/h × | PaLM, Falcon-7B |
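A minimal sketch of grouped-query attention, assuming n_kv_heads evenly divides n_heads; production kernels (e.g., in LLaMA 3) avoid materializing the repeated KV heads, but the repeat_interleave form shows the idea.

import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.d_k)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k)   # fewer KV heads
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        # Repeat each KV head so every group of query heads shares one KV head
        group = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.W_o(out.transpose(1, 2).reshape(B, T, -1))

# With h = 32 query heads and 8 KV heads, the KV cache shrinks by 4×.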
5. Attention Mechanisms & Variants
Beyond standard self-attention, several variants optimize for speed, memory, or context length. This section covers modern approaches used in production LLMs.
Flash Attention Algorithm
Flash Attention reduces memory I/O by computing attention in tiles, avoiding expensive HBM reads/writes of intermediate matrices.
import math
import torch

# Simplified tiled attention loop (illustrates the idea; real Flash Attention also
# keeps running softmax statistics so K/V tiles can be streamed without re-reading HBM)
def flash_attention_block(Q, K, V, block_size=128):
    seq_len, d_k = Q.shape
    output = torch.zeros_like(Q)
    for i in range(0, seq_len, block_size):
        # Load one query tile
        Q_tile = Q[i:i + block_size]
        K_ctx, V_ctx = K[:i + block_size], V[:i + block_size]
        scores = Q_tile @ K_ctx.T / math.sqrt(d_k)
        # Causal mask: the query at position i+r may only attend to keys ≤ i+r
        q_pos = torch.arange(i, i + Q_tile.shape[0]).unsqueeze(1)
        k_pos = torch.arange(K_ctx.shape[0]).unsqueeze(0)
        scores = scores.masked_fill(k_pos > q_pos, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        # Accumulate output for this tile
        output[i:i + Q_tile.shape[0]] = attn_weights @ V_ctx
    return output

# Flash Attention benefits:
# - 2-4× faster than standard attention on modern hardware
# - Same numerical result (exact attention, no approximation)
# - Far fewer reads/writes of the (seq_len × seq_len) score matrix to HBM
Sparse Attention Patterns
| Pattern | Attention Scope | Compute | Best For |
|---|---|---|---|
| Dense | All-to-all | O(n²) | Short sequences (<4K) |
| Local | Window of size w | O(n·w) | Long sequences with local structure |
| Stride | Every k-th position | O(n²/k) | Approximate long-range |
| Log-sparse | Local + logarithmic blocks | O(n·log n) | Ultra-long sequences (>1M) |
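To make the local pattern concrete, here is a small sketch of a causal sliding-window mask of width w (the "window of size w" row above; Mistral 7B uses this pattern with w = 4096).

import torch

def local_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # 1 where attention is allowed: key position j with i - window < j <= i
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return ((j <= i) & (j > i - window)).int()

print(local_causal_mask(6, 3))
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])
# Each row attends to at most `window` positions → O(n·w) compute instead of O(n²).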
Attention Key-Value Cache (KV Cache)
class CachedAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.cache_k = None
        self.cache_v = None

    def forward(self, q, k, v, is_prefill=False):
        # Prefill: process the entire prompt at once
        if is_prefill:
            # Save all K, V computed from the prompt
            self.cache_k = k
            self.cache_v = v
            attention = scaled_dot_product_attention(q, k, v)
        # Decode: one new token at a time
        else:
            # Append the new token's K, V to the cache (dim=1 is the sequence axis)
            self.cache_k = torch.cat([self.cache_k, k], dim=1)
            self.cache_v = torch.cat([self.cache_v, v], dim=1)
            attention = scaled_dot_product_attention(q, self.cache_k, self.cache_v)
        return attention

# Why the KV cache matters:
# - Without it, every decode step re-runs the K and V projections for the whole
#   sequence so far — O(T) redundant work per step, O(T²) per generated sequence.
# - With it, each step computes K, V only for the single new token.
# - Cache size per layer: 2 · batch · T · h · d_k values, so memory grows linearly
#   with context length — the dominant memory cost of long-context inference.
6. Activation Functions
Activation functions introduce non-linearity into the feed-forward network. Modern LLMs typically use SwiGLU or GELU rather than ReLU.
Common Activation Functions
| Function | Formula | Used By | Properties |
|---|---|---|---|
| ReLU | max(0, x) | Original Transformer | Simple, sparse activations |
| GELU (approx) | 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))) | BERT, GPT-3 | Smooth approximation of Gaussian CDF |
| SiLU (Swish) | x·sigmoid(x) | LLaMA, Mistral (inside SwiGLU) | Smooth, self-gated |
| SwiGLU | SiLU(x·W_gate) ⊙ (x·W_up) | LLaMA 2/3, Mistral, PaLM | Gated linear unit with SiLU |
| Mish | x·tanh(softplus(x)) | Some research models | Unbounded, smooth |
SwiGLU Implementation
class SwiGLU_FFN(nn.Module): def __init__(self, d_model, d_ff): super().__init__() # Dimension expansion: d_model → d_ff self.gate = nn.Linear(d_model, d_ff) self.up = nn.Linear(d_model, d_ff) # Dimension reduction: d_ff → d_model self.down = nn.Linear(d_ff, d_model) def forward(self, x): # x: (batch, seq_len, d_model) # Gate: d_model → d_ff, apply SiLU activation gate = torch.nn.functional.silu(self.gate(x)) # Up projection: d_model → d_ff up = self.up(x) # Element-wise multiplication (gating) gated = gate * up # Down projection: d_ff → d_model output = self.down(gated) return output # Compute parameter count: # gate: d_model × d_ff = 16384 × 53248 = 872M params # up: d_model × d_ff = 16384 × 53248 = 872M params # down: d_ff × d_model = 53248 × 16384 = 872M params # Total: 2.6B per layer × 126 layers = 327B params (81% of 405B model)
GELU Smooth Approximation
import math
import torch

def gelu_approx(x):
    # Tanh approximation of GELU (the Gaussian CDF)
    return 0.5 * x * (
        1 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))
    )

def gelu_exact(x):
    # PyTorch's exact implementation (uses the error function)
    return torch.nn.functional.gelu(x, approximate='none')

# Usage in an FFN:
def gelu_ffn(x, W1, W2):
    return W2 @ gelu_approx(W1 @ x)
7. Normalization Techniques
Normalization stabilizes training by controlling activation distributions. Modern LLMs use RMSNorm instead of LayerNorm due to lower computation and similar effectiveness.
RMSNorm Implementation
class RMSNorm(nn.Module): def __init__(self, d_model, eps=1e-6): super().__init__() # Learnable scaling parameter self.weight = nn.Parameter(torch.ones(d_model)) self.eps = eps def forward(self, x): # x: (batch, seq_len, d_model) # Compute RMS (root mean square) rms = torch.sqrt( torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps ) # Normalize and scale normalized = x / rms output = normalized * self.weight return output # RMSNorm is simpler and faster than LayerNorm: # - No mean subtraction (just RMS) # - One less statistical moment to compute # - ~8% faster on modern GPUs # - Used in LLaMA, Mistral, Qwen, etc.
LayerNorm for Comparison
class LayerNorm(nn.Module): def __init__(self, d_model, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(d_model)) self.bias = nn.Parameter(torch.zeros(d_model)) self.eps = eps def forward(self, x): # Compute mean and variance mean = torch.mean(x, dim=-1, keepdim=True) var = torch.var(x, dim=-1, keepdim=True, unbiased=False) # Normalize normalized = (x - mean) / torch.sqrt(var + self.eps) # Scale and shift output = normalized * self.weight + self.bias return output # LayerNorm properties: # - Computes both mean and variance # - Includes bias term (RMSNorm doesn't) # - Slightly more expressive but slower
Normalization Comparison
| Technique | Mean | Variance | Parameters | Compute | Used By |
|---|---|---|---|---|---|
| BatchNorm | Across batch | Across batch | weight, bias | Medium | CNNs, ResNets |
| LayerNorm | Across hidden dim | Across hidden dim | weight, bias | Medium | BERT, GPT-3 |
| RMSNorm | None (RMS only) | Implicit | weight only | Fast | LLaMA, Mistral, Qwen |
| GroupNorm | Across group | Across group | weight, bias | Medium | ConvNets (detection, diffusion U-Nets) |
Loss Functions & Training Objectives
The loss function defines what the model learns. Nearly all modern LLMs use autoregressive next-token prediction with cross-entropy loss, but understanding the nuances — label smoothing, auxiliary losses, and alternative objectives — is critical for training and fine-tuning.
Cross-Entropy Loss (Next-Token Prediction)
def autoregressive_loss(model, input_ids): # input_ids: (batch, seq_len) — token IDs # Model predicts next token at each position # Forward pass: get logits for all positions logits = model(input_ids[:, :-1]) # (batch, seq_len-1, vocab_size) # Targets: shifted by 1 position targets = input_ids[:, 1:] # (batch, seq_len-1) # Cross-entropy loss: -log P(target_token | context) # CE = -Σ y_i · log(softmax(logit_i)) loss = torch.nn.functional.cross_entropy( logits.view(-1, logits.size(-1)), # flatten to (N, vocab) targets.view(-1), # flatten to (N,) ignore_index=PAD_TOKEN_ID # skip padding ) # Perplexity = exp(loss) — interpretable metric perplexity = torch.exp(loss) return loss, perplexity # Key insight: loss is computed per-token, averaged over all positions # A loss of 2.3 ≈ perplexity of 10 (model considers ~10 tokens equally likely) # A loss of 1.0 ≈ perplexity of 2.7 (model is quite confident)
Label Smoothing
def cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1): # Instead of hard targets [0, 0, 1, 0, ...], use soft targets # [ε/V, ε/V, 1-ε, ε/V, ...] where ε = smoothing, V = vocab_size vocab_size = logits.size(-1) log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # Smooth targets smooth_targets = torch.full_like(log_probs, smoothing / vocab_size) smooth_targets.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing) # KL divergence as loss loss = -(smooth_targets * log_probs).sum(dim=-1).mean() return loss # Benefits: prevents overconfident predictions, improves generalization # Typical values: ε = 0.1 for pre-training, ε = 0.05 for fine-tuning # Trade-off: slightly worse perplexity but better downstream performance
Training Objective Comparison
| Objective | Formula | Used By | Notes |
|---|---|---|---|
| Causal LM (Next Token) | -Σ log P(x_t \| x_<t) | GPT, LLaMA, Claude | Standard for decoder-only models |
| Masked LM (MLM) | -Σ log P(x_mask \| x_context) | BERT, RoBERTa | Bidirectional; encoder-only models |
| Seq2Seq (Enc-Dec) | -Σ log P(y_t \| y_<t, x) | T5, BART | Input → Output mapping |
| Prefix LM | Bidirectional prefix + causal continuation | PaLM (partially) | Hybrid attention pattern |
| Contrastive | sim(pos) / Σ sim(neg) | CLIP, SimCLR | Learn aligned representations |
| DPO (Preference) | log σ(β(r_w - r_l)) | RLHF replacement | Direct optimization from preference pairs |
Auxiliary Losses in Modern LLMs
Load Balancing Loss (MoE)
Ensures all experts receive roughly equal token counts. Without it, some experts become "dead" while others get overloaded.
# Load balance = α · N · Σ(f_i · P_i) # f_i = fraction of tokens routed to expert i # P_i = average router probability for expert i # α = 0.01 (small weight, added to main loss)
Z-Loss (Numerical Stability)
Penalizes large logit values to prevent floating-point overflow in softmax. Used by PaLM and Gemini.
# z_loss = β · log²(Σ exp(z_i)) # β = 1e-4 (very small regularizer) # Prevents logit magnitude explosion # Critical for BF16 training stability
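A minimal PyTorch sketch of both auxiliary terms as they might be added to the main loss; the tensor shapes and coefficient values follow the comments above and are otherwise assumptions.

import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, expert_indices, num_experts, alpha=0.01):
    # router_logits: (num_tokens, num_experts); expert_indices: (num_tokens,) top-1 choice
    probs = F.softmax(router_logits, dim=-1)
    P = probs.mean(dim=0)                                            # avg router prob per expert
    f = F.one_hot(expert_indices, num_experts).float().mean(dim=0)   # fraction of tokens per expert
    return alpha * num_experts * torch.sum(f * P)

def z_loss(logits, beta=1e-4):
    # Penalize the log-partition function so logits stay in a numerically safe range
    log_z = torch.logsumexp(logits, dim=-1)
    return beta * (log_z ** 2).mean()

# total_loss = cross_entropy + load_balancing_loss(...) + z_loss(lm_head_logits)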
Cross-entropy is the information-theoretic optimal objective for density estimation. Minimizing CE is equivalent to minimizing the KL divergence between the model distribution and the true data distribution. The model implicitly learns grammar, facts, reasoning patterns, and style — all from predicting the next token. This is why scaling compute on this simple objective yields emergent capabilities.
Optimizers & Learning Rate Schedules
The optimizer and learning rate schedule are the two most impactful hyperparameters after model architecture. Getting them wrong can mean training failure, instability, or 2-3× slower convergence.
AdamW: The Standard LLM Optimizer
class AdamW:
    """AdamW = Adam + decoupled weight decay (Loshchilov & Hutter, 2019)"""
    def __init__(self, params, lr=3e-4, betas=(0.9, 0.95),
                 eps=1e-8, weight_decay=0.1):
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.wd = weight_decay
        # Optimizer state: first moment (m), second moment (v), timestep (t)
        self.m = 0.0
        self.v = 0.0
        self.t = 0

    def step(self, param, grad):
        self.t += 1
        # 1. Update biased first moment estimate (momentum)
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        # 2. Update biased second moment estimate (adaptive LR)
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad**2
        # 3. Bias correction
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        # 4. Parameter update, with weight decay decoupled from the gradient
        param -= self.lr * (m_hat / (v_hat.sqrt() + self.eps))
        param -= self.lr * self.wd * param
        return param

# Key hyperparameters for LLM pre-training:
# lr:           1e-4 to 6e-4 peak (depends on model size)
# betas:        (0.9, 0.95) — β2 = 0.95 is more stable than 0.999 for LLMs
# weight_decay: 0.1 (standard for transformer pre-training)
# eps:          1e-8 (FP32) or 1e-5 (for BF16 numerical stability)
Optimizer Comparison
| Optimizer | Memory per Param | Key Innovation | Best For | Used By |
|---|---|---|---|---|
| AdamW | 8 bytes (m + v) | Decoupled weight decay | General pre-training | LLaMA, GPT, Claude |
| Adam (vanilla) | 8 bytes | Adaptive learning rates | Small models, fine-tuning | Legacy |
| Adafactor | 4 bytes (factored) | Factored second moments | Memory-constrained | T5, PaLM |
| LION | 4 bytes (m only) | Sign-based update (sign(m)) | Large batch training | Google research |
| Sophia | 8 bytes | Second-order (Hessian estimate) | Faster convergence | Research (2× faster claim) |
| SGD + Momentum | 4 bytes | Simple momentum | Fine-tuning final layers | Vision, some SFT |
| 8-bit Adam | 2 bytes (quantized states) | Quantized optimizer states | Memory savings | bitsandbytes library |
Learning Rate Schedules
def cosine_schedule_with_warmup( step, warmup_steps=2000, total_steps=100000, peak_lr=3e-4, min_lr=3e-5 ): """Standard LLM schedule: linear warmup → cosine decay""" if step < warmup_steps: # Phase 1: Linear warmup (0 → peak_lr) return peak_lr * (step / warmup_steps) else: # Phase 2: Cosine decay (peak_lr → min_lr) progress = (step - warmup_steps) / (total_steps - warmup_steps) cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) return min_lr + (peak_lr - min_lr) * cosine_decay # Warmup is critical: prevents gradient explosion in early training # Typical warmup: 0.5-2% of total steps (2000 steps for 100K total) # min_lr: usually 10× smaller than peak_lr # Some models use WSD schedule: Warmup → Stable → Decay
Schedule Comparison
| Schedule | Shape | Peak LR | Min LR | Used By |
|---|---|---|---|---|
| Cosine + Warmup | Linear up → cosine down | 3e-4 | 3e-5 | LLaMA 3, most models |
| WSD (Warmup-Stable-Decay) | Warmup → flat → rapid decay | 3e-4 | 0 | MiniCPM, some research |
| Linear Decay | Linear warmup → linear down | 1e-4 | 0 | BERT, older models |
| Inverse Square Root | 1/√t decay after warmup | 5e-4 | varies | Original Transformer |
| Constant + Decay | Flat → sharp drop at 80% | 2e-5 | 0 | SFT, fine-tuning |
For pre-training: use AdamW with cosine schedule + linear warmup. peak_lr ≈ 3e-4 × (batch_size/256)^0.5. For fine-tuning: use AdamW with constant LR + linear decay, lr ≈ 1e-5 to 5e-5 (10-100× smaller than pre-training). For LoRA: lr ≈ 1e-4 to 3e-4 (can be more aggressive since only adapters update).
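A sketch of how these recommendations wire together with standard PyTorch APIs; the concrete hyperparameter values are the ones suggested above, not universal constants.

import math
import torch

model = torch.nn.Linear(4096, 4096)   # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4,
                              betas=(0.9, 0.95), weight_decay=0.1)

warmup_steps, total_steps, min_ratio = 2000, 100_000, 0.1

def lr_lambda(step):
    # Linear warmup, then cosine decay down to min_ratio × peak_lr
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# In the training loop: optimizer.step(); scheduler.step() once per optimizer update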
Feed-Forward Networks & Gated Architectures
The FFN (feed-forward network) in each transformer block accounts for ~67% of model parameters. Modern LLMs use gated variants (SwiGLU, GeGLU) instead of the original two-layer MLP, achieving better performance at the same parameter count.
FFN Architecture Evolution
Original FFN (2017)
Two linear projections with ReLU activation. Simple but suboptimal.
import torch.nn as nn
import torch.nn.functional as F

# Vaswani et al. FFN
# FFN(x) = W2 · ReLU(W1 · x + b1) + b2
# Expansion: d_model → 4·d_model → d_model
class VanillaFFN(nn.Module):
    def __init__(self, d=4096):
        super().__init__()
        self.w1 = nn.Linear(d, 4 * d)
        self.w2 = nn.Linear(4 * d, d)

    def forward(self, x):
        return self.w2(F.relu(self.w1(x)))
SwiGLU FFN (2020+)
Gated Linear Unit with SiLU (Swish) activation. Standard in all modern LLMs. Uses 3 weight matrices instead of 2.
# SwiGLU(x) = (W1·x ⊙ SiLU(W_gate·x)) · W2
# ⊙ = element-wise multiply (gating)
class SwiGLU_FFN(nn.Module):
    def __init__(self, d=4096, d_ff=14336):
        super().__init__()
        self.w1 = nn.Linear(d, d_ff)
        self.w_gate = nn.Linear(d, d_ff)
        self.w2 = nn.Linear(d_ff, d)

    def forward(self, x):
        return self.w2(self.w1(x) * F.silu(self.w_gate(x)))
GeGLU FFN (Variant)
Same gated structure but uses GELU instead of SiLU. Used by Gemma and T5 v1.1.
# GeGLU(x) = (W1·x ⊙ GELU(W_gate·x)) · W2
class GeGLU_FFN(nn.Module):
    def __init__(self, d=4096, d_ff=16384):
        super().__init__()
        self.w1 = nn.Linear(d, d_ff)
        self.w_gate = nn.Linear(d, d_ff)
        self.w2 = nn.Linear(d_ff, d)

    def forward(self, x):
        return self.w2(self.w1(x) * F.gelu(self.w_gate(x)))
Dimension Ratios in Real Models
| Model | d_model | d_ff (FFN hidden) | Ratio | FFN Type | FFN Params % |
|---|---|---|---|---|---|
| GPT-3 175B | 12288 | 49152 | 4.0× | ReLU (vanilla) | 67% |
| LLaMA 2 7B | 4096 | 11008 | 2.69× | SwiGLU | 67% |
| LLaMA 3 8B | 4096 | 14336 | 3.5× | SwiGLU | 68% |
| LLaMA 3 70B | 8192 | 28672 | 3.5× | SwiGLU | 68% |
| Mistral 7B | 4096 | 14336 | 3.5× | SwiGLU | 68% |
| Gemma 7B | 3072 | 24576 | 8.0× | GeGLU | 76% |
SwiGLU has 3 weight matrices (W1, W_gate, W2) instead of 2. To keep total parameters constant, the hidden dimension is reduced: d_ff = (2/3) × 4 × d_model, giving the ~2.67× ratio. Despite fewer hidden units, the gating mechanism allows more expressive feature selection, yielding ~1-2% better loss than vanilla FFN at identical FLOPs.
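A quick check of that ratio for LLaMA 2 7B (the rounding to a hardware-friendly multiple of 256 is an assumption about the implementation):

d_model = 4096
d_ff_raw = int(2 / 3 * 4 * d_model)      # ≈ 10922
d_ff = ((d_ff_raw + 255) // 256) * 256   # round up to a multiple of 256 → 11008
print(d_ff_raw, d_ff, d_ff / d_model)    # 10922 11008 2.6875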
Mixture of Experts (MoE) Architecture
MoE models replace the dense FFN with N parallel expert FFNs, routing each token to only the top-K experts. This enables massive parameter counts (e.g., 1.6T for Switch Transformer) while keeping FLOPs manageable — in Mixtral, for example, only 2 of 8 experts activate per token.
MoE Architecture Diagram
MoE Implementation
class MoELayer(nn.Module): def __init__(self, d_model=4096, d_ff=14336, num_experts=16, top_k=2): super().__init__() self.experts = nn.ModuleList([ SwiGLU_FFN(d_model, d_ff) for _ in range(num_experts) ]) self.router = nn.Linear(d_model, num_experts, bias=False) self.top_k = top_k self.num_experts = num_experts def forward(self, x): # x: (batch, seq_len, d_model) batch, seq_len, d = x.shape # Router: compute expert scores per token router_logits = self.router(x) # (B, S, num_experts) # Select top-K experts per token weights, indices = torch.topk( torch.softmax(router_logits, dim=-1), k=self.top_k, dim=-1 ) # Normalize weights to sum to 1 weights = weights / weights.sum(dim=-1, keepdim=True) # Compute weighted expert outputs output = torch.zeros_like(x) for k in range(self.top_k): expert_idx = indices[:, :, k] # which expert for each token weight = weights[:, :, k].unsqueeze(-1) for e in range(self.num_experts): mask = (expert_idx == e) if mask.any(): expert_input = x[mask] expert_output = self.experts[e](expert_input) output[mask] += weight[mask] * expert_output return output
MoE Model Comparison
| Model | Total Params | Active Params | Experts | Top-K | Expert Placement |
|---|---|---|---|---|---|
| Mixtral 8×7B | 46.7B | 12.9B | 8 | 2 | Every layer |
| Mixtral 8×22B | 141B | 39B | 8 | 2 | Every layer |
| Switch Transformer | 1.6T | ~100B | 128 | 1 | Every other layer |
| DeepSeek-V2 | 236B | 21B | 160 | 6 | Shared + routed |
| Grok-1 | 314B | ~86B | 8 | 2 | Every layer |
| DBRX | 132B | 36B | 16 | 4 | Fine-grained |
MoE Training Challenges
Load Balancing
Without explicit balancing, the router often collapses — sending most tokens to 1-2 experts while others "die." Solutions: auxiliary load-balancing loss, expert capacity limits, jitter noise during training.
Expert Parallelism
Distribute experts across GPUs. Each GPU holds a subset of experts. Requires All-to-All communication to route tokens to the correct GPU. Communication overhead can dominate if not carefully optimized.
Capacity Factor
Each expert has a token buffer. If buffer overflows (too many tokens routed), excess tokens are dropped or sent to a shared expert. Typical capacity factor: 1.0-1.5× expected load.
Inference Overhead
MoE inference requires loading all experts into memory (full model size) even though only K activate. Makes deployment harder than dense models. Mixtral 8×7B needs ~90GB VRAM despite 12.9B active params.
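To make the capacity factor above concrete, a small sketch of how an expert's token budget is typically computed (the numbers are illustrative):

import math

def expert_capacity(tokens_in_batch, num_experts, top_k, capacity_factor=1.25):
    # Expected tokens per expert under perfectly uniform routing, times a safety margin
    return math.ceil(tokens_in_batch * top_k / num_experts * capacity_factor)

cap = expert_capacity(tokens_in_batch=16384, num_experts=8, top_k=2)
print(cap)  # 5120 — tokens routed to an expert beyond this budget are dropped
            # (or sent to a shared expert, depending on the implementation)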
An MoE matches the quality of a dense model several times its active-parameter size (Mixtral 8×7B approaches LLaMA 2 70B quality with 12.9B active params), but requires 3-4× more memory for serving. Use MoE when you're compute-bound (training or inference FLOPs) but not memory-bound.
Pre-training Curriculum & Dynamics
Pre-training is the most expensive phase of LLM development ($2-50M+ for frontier models). Understanding training dynamics — loss curves, data scheduling, stability issues — is essential for efficient training runs.
Typical Pre-Training Setup
| Component | LLaMA 3 8B | LLaMA 3 70B | LLaMA 3 405B |
|---|---|---|---|
| Tokens | 15T | 15T | 15T |
| Batch Size (tokens) | 4M | 8M | 16M |
| Peak LR | 3e-4 | 1.5e-4 | 8e-5 |
| Min LR | 3e-5 | 1.5e-5 | 8e-6 |
| Warmup Steps | 2000 | 2000 | 2000 |
| Context Length | 8192 | 8192 | 8192 → 128K |
| Training Steps | ~3.75M | ~1.88M | ~938K |
| Hardware | ~2K H100s | ~6K H100s | ~16K H100s |
| Training Duration | ~1 week | ~3 weeks | ~2 months |
Data Curriculum (Multi-Stage Training)
# Modern pre-training uses phased data mixtures class DataCurriculum: stages = [ { "name": "Phase 1: General Pre-training", "tokens": "0 → 12T", "mix": { "web": 0.50, # CommonCrawl, C4, refined "code": 0.17, # GitHub, StackOverflow "books": 0.15, # Gutenberg, books3 "academic": 0.08, # arXiv, semantic scholar "math": 0.05, # proofs, competition data "multilingual": 0.05 } }, { "name": "Phase 2: Quality Upsampling", "tokens": "12T → 14.5T", "mix": { "high_quality_web": 0.25, "code": 0.25, # upsampled 2× "math": 0.15, # upsampled 3× "academic": 0.15, # upsampled 2× "books": 0.10, "curated_reasoning": 0.10 } }, { "name": "Phase 3: Long Context Extension", "tokens": "14.5T → 15T", "context_length": "8K → 128K", "note": "Gradual context extension with long documents" } ]
Common Training Instabilities
| Symptom | Cause | Detection | Fix |
|---|---|---|---|
| Loss spikes | Bad data batch, gradient explosion | Loss > 3× rolling average | Skip batch, gradient clipping, reduce LR |
| Loss plateau | LR too low, data exhaustion | Loss unchanged for 5K+ steps | Increase LR slightly, add fresh data |
| NaN loss | Numerical overflow in BF16 | Immediate detection | Add z-loss, check init, reduce LR |
| Slow convergence | LR too low, bad warmup ratio | Compare to scaling law prediction | Increase peak LR, extend warmup |
| Quality regression | Data mixture imbalance, overfitting on subset | Eval on diverse held-out set | Adjust data mix, add regularization |
| Expert collapse (MoE) | Router sends all tokens to few experts | Monitor expert load distribution | Increase load balance loss weight |
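The first row of this table suggests a simple runtime guard. A minimal sketch, assuming a compute_loss helper and a batch object with input_ids and labels (both illustrative):

from collections import deque
import torch

recent_losses = deque(maxlen=200)   # rolling window of recent step losses

def guarded_step(model, optimizer, batch, compute_loss, clip_norm=1.0):
    loss = compute_loss(model(batch.input_ids), batch.labels)
    rolling_avg = sum(recent_losses) / len(recent_losses) if recent_losses else float("inf")
    if loss.item() > 3 * rolling_avg or not torch.isfinite(loss):
        optimizer.zero_grad()
        return None                 # skip the suspicious batch entirely
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)  # gradient clipping
    optimizer.step()
    optimizer.zero_grad()
    recent_losses.append(loss.item())
    return loss.item()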
Loss Curve Interpretation
Healthy Training
Loss decreases smoothly following a power law. Small fluctuations (±2%) are normal. After warmup, loss should drop ~0.5 nats in first 10% of training. Final loss for 7B model: ~1.7-1.9 nats on diverse web data.
Scaling Law Prediction
Use Chinchilla scaling: L(N,D) ≈ E + A/N^0.34 + B/D^0.28. If actual loss is >5% above prediction, something is wrong (data quality, hyperparameters, bugs). If below, you may have data contamination.
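A small sketch of that prediction using the coefficients fitted in the Chinchilla paper (E ≈ 1.69, A ≈ 406.4, B ≈ 410.7); treat the exact constants as approximate.

def chinchilla_loss(N, D, E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28):
    # L(N, D) ≈ E + A / N^alpha + B / D^beta   (N = parameters, D = training tokens)
    return E + A / N**alpha + B / D**beta

# Example: an 8B-parameter model trained on 15T tokens
print(round(chinchilla_loss(8e9, 15e12), 3))   # ≈ 1.95 nats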
1. Wrong learning rate (too high → instability, too low → wasted compute). 2. Insufficient data quality filtering (garbage in → garbage out, no amount of scale fixes bad data). 3. Not monitoring eval metrics during training (discovering issues after weeks of compute). 4. Wrong parallelism strategy (communication overhead eating 30%+ of FLOPs). 5. No checkpoint recovery plan (hardware failure at 80% completion with no recent checkpoint = restart).
Tokenizer Chat Templates & Special Tokens
Every instruction-tuned LLM uses a specific chat template to format multi-turn conversations. Using the wrong template degrades performance significantly — the model was trained on a specific format and expects it exactly.
Common Special Tokens
| Token | Purpose | Used By | ID (typical) |
|---|---|---|---|
| <\|begin_of_text\|> | Start of sequence | LLaMA 3 | 128000 |
| <\|end_of_text\|> | End of sequence / stop generation | LLaMA 3, GPT | 128001 |
| <\|start_header_id\|> | Start of role header | LLaMA 3 | 128006 |
| <\|end_header_id\|> | End of role header | LLaMA 3 | 128007 |
| <\|eot_id\|> | End of turn | LLaMA 3 | 128009 |
| [INST] [/INST] | Instruction delimiters | Mistral, LLaMA 2 | varies |
| <\|im_start\|> <\|im_end\|> | Message boundaries | ChatML format | varies |
| <s> </s> | Sequence start/end | SentencePiece models | 1, 2 |
| [PAD] | Padding for batch alignment | Universal | 0 (usually) |
Chat Template Formats
# LLaMA 3 Chat Template
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The capital of France is Paris.<|eot_id|>

# Mistral / LLaMA 2 Template
<s>[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>

What is the capital of France? [/INST] The capital of France is Paris. </s>

# ChatML Template (OpenAI-style, used by many open models)
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>
Using HuggingFace Chat Templates
from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"} ] # apply_chat_template handles the format automatically prompt = tokenizer.apply_chat_template( messages, tokenize=False, # return string (not token IDs) add_generation_prompt=True # add assistant header for generation ) # Tokenize for model input inputs = tokenizer(prompt, return_tensors="pt") # inputs.input_ids contains the correctly formatted token sequence
Tokenizer Algorithm Comparison
| Algorithm | Approach | Vocab Size (typical) | Used By | Strengths |
|---|---|---|---|---|
| BPE | Greedy merge of frequent byte pairs | 32K-128K | GPT, LLaMA 3 | Good compression, fast encoding |
| SentencePiece (Unigram) | Probabilistic subword selection | 32K | T5, ALBERT | Language-agnostic, better multilingual |
| WordPiece | Maximize likelihood of training data | 30K | BERT | Good for classification tasks |
| Tiktoken | BPE (byte-level, regex pre-split) | 100K-200K | GPT-4, GPT-4o | Fast, handles all Unicode |
Using the wrong chat template doesn't cause an error — the model simply performs poorly because it sees token patterns it wasn't trained on. Always use tokenizer.apply_chat_template() from HuggingFace, or verify the exact format from the model card. For vLLM serving, pass --chat-template to override. Common symptom of template mismatch: model generates system prompt text, repeats user messages, or produces gibberish.
8. Training Methods & Distributed Training
Training large language models requires careful parallelism strategies to distribute computation and memory across multiple GPUs/TPUs.
Distributed Data Parallel (DDP)
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def setup_ddp(rank, world_size):
    # Initialize the process group (one process per GPU)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

def train_ddp(rank, world_size, model, train_loader):
    setup_ddp(rank, world_size)

    # Wrap model with DDP
    model = model.cuda(rank)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[rank],
        find_unused_parameters=False,
    )
    optimizer = torch.optim.AdamW(ddp_model.parameters())

    for epoch in range(10):
        for batch in train_loader:
            logits = ddp_model(batch)
            loss = compute_loss(logits, batch.labels)
            optimizer.zero_grad()
            # DDP hooks into backward() and AllReduces gradients across GPUs automatically
            loss.backward()
            optimizer.step()

# DDP benefits:
# - Near-linear scaling: N GPUs → ~N× throughput
# - Every GPU holds a full model copy (no model sharding)
# - Gradients are synchronized via AllReduce during each backward pass
Parallelism Strategy Comparison
| Strategy | Data Split | Model Split | Memory per GPU | Bandwidth | Use Case |
|---|---|---|---|---|---|
| Data Parallel (DP) | Yes | No | Full model | High (AllReduce) | < 8 GPUs |
| Tensor Parallel (TP) | No | Yes (row/col) | Model/k | Very high (intra-node) | Large models |
| Pipeline Parallel (PP) | Yes | Yes (layers) | Model/k + buffer | Medium (inter-node) | Multi-node training |
| Sequence Parallel | Yes | Yes (sequence) | Model/k | High (intra-node) | Long sequences |
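As a sketch of what tensor parallelism means in practice, here is a Megatron-style column-parallel linear layer; the process-group plumbing is simplified and the class name is illustrative.

import torch
import torch.nn as nn
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Each rank holds a slice of the output columns: W = [W_0 | W_1 | ... | W_{k-1}]."""
    def __init__(self, d_in, d_out, world_size):
        super().__init__()
        assert d_out % world_size == 0
        self.local = nn.Linear(d_in, d_out // world_size, bias=False)  # this rank's shard

    def forward(self, x):
        y_local = self.local(x)   # (batch, seq, d_out / world_size)
        # Gather shards from all ranks and concatenate along the feature dimension
        shards = [torch.empty_like(y_local) for _ in range(dist.get_world_size())]
        dist.all_gather(shards, y_local)
        return torch.cat(shards, dim=-1)

# In Megatron-style blocks the all-gather is usually deferred: a column-parallel layer
# is followed by a row-parallel layer, so only one reduction per FFN is needed.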
DeepSpeed ZeRO Optimization
# DeepSpeed ZeRO-2 Configuration (Gradient Partitioning) { "zero_optimization": { "stage": 2, "allgather_partitions": true, "allgather_bucket_size": 2e8, "overlap_comm": true, "reduce_scatter": true, "contiguous_gradients": true }, "fp16": { "enabled": true, "fp16_opt_level": "O2" }, "gradient_clipping": 1.0, "train_micro_batch_size_per_gpu": 4, "gradient_accumulation_steps": 8 }
Gradient Accumulation & Checkpointing
def train_with_gradient_accumulation(model, train_loader, n_accumulation_steps=4):
    # Accumulate gradients across multiple micro-batches before each optimizer step
    optimizer = torch.optim.AdamW(model.parameters())
    for batch_idx, batch in enumerate(train_loader):
        logits = model(batch.input_ids)
        loss = compute_loss(logits, batch.labels)
        # Scale the loss so the accumulated gradient matches one large batch
        scaled_loss = loss / n_accumulation_steps
        scaled_loss.backward()
        # Update weights every n_accumulation_steps micro-batches
        if (batch_idx + 1) % n_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

def train_with_gradient_checkpointing(model, batch):
    # Recompute activations during the backward pass instead of storing them
    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    x = model.embed(batch.input_ids)   # model.embed / model.layers: illustrative attribute names
    for layer in model.layers:
        # Checkpoint this layer: its activations are recomputed in the backward pass
        x = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
    return x

# Gradient accumulation: larger effective batch size without extra GPU memory
# Gradient checkpointing: trades compute for memory (~30% slower, ~50% less activation VRAM)
9. Inference & Sampling Strategies
During inference, we generate tokens autoregressively using various sampling techniques to balance quality and diversity.
Temperature Scaling & Sampling
import torch

def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.9):
    # logits: (batch, vocab_size) from the final layer

    # 1. Temperature scaling
    #    - temp > 1: flatter distribution (more diverse)
    #    - temp = 1: unchanged
    #    - temp < 1: sharper distribution (more deterministic)
    scaled_logits = logits / temperature

    # 2. Top-K filtering: keep only the k highest-scoring tokens
    topk_values, topk_indices = torch.topk(scaled_logits, k=top_k, dim=-1)
    mask = torch.full_like(scaled_logits, float('-inf'))
    mask.scatter_(dim=-1, index=topk_indices, src=topk_values)
    scaled_logits = mask

    # 3. Top-P (nucleus) filtering: cumulative probability threshold
    sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumsum_probs > top_p
    sorted_indices_to_remove[..., 0] = False          # always keep at least 1 token
    sorted_logits[sorted_indices_to_remove] = float('-inf')

    # 4. Sample from the filtered distribution, then map back to vocabulary IDs
    probs = torch.softmax(sorted_logits, dim=-1)
    sampled_pos = torch.multinomial(probs, num_samples=1)       # index into sorted order
    next_token = torch.gather(sorted_indices, -1, sampled_pos)  # (batch, 1) original vocab IDs
    return next_token

# Common strategies:
# - Greedy: argmax (deterministic, can get stuck in loops)
# - Temperature: scaled softmax (simple, controlled randomness)
# - Top-K: drop low-probability tokens (prevents nonsense)
# - Top-P (nucleus): dynamic cutoff (usually the best default)
Sampling Comparison
| Strategy | Method | Output | Speed | Quality |
|---|---|---|---|---|
| Greedy | argmax(logits) | Deterministic | Fast | Can loop/repeat |
| Random | sample from softmax | Very diverse | Fast | Often incoherent |
| Temperature | sample from scaled softmax | Controlled randomness | Fast | Good balance |
| Top-K | sample from top k tokens | Higher quality | Fast | Better |
| Top-P | sample from cumsum ≤ p | Dynamic filtering | Medium | Best |
Batched Inference with KV Caching
def generate_batched(model, prompt_ids, max_length=512):
    # prompt_ids: (batch, prompt_len)
    batch_size, prompt_len = prompt_ids.shape

    # Initialize one KV cache entry per layer
    kv_cache = [{"k": None, "v": None} for _ in model.layers]

    # Prefill: process the whole prompt in parallel and fill the cache
    logits = model(prompt_ids, kv_cache, is_prefill=True)
    generated_ids = prompt_ids.clone()

    for _ in range(max_length - prompt_len):
        # Logits for the last position only
        last_logits = logits[:, -1, :]

        # Sample the next token for every sequence in the batch
        next_tokens = sample_next_token(last_logits, temperature=0.7, top_p=0.9)  # (batch, 1)

        # Decode step: feed only the new tokens; K/V for the past come from the cache
        logits = model(next_tokens, kv_cache, is_prefill=False)
        generated_ids = torch.cat([generated_ids, next_tokens], dim=-1)

    return generated_ids

# Why this is fast:
# - KV cache: previous tokens' K/V are stored, not recomputed each step
# - Prefill: the prompt is processed in one vectorized pass
# - Decode: one token per step (memory-bound rather than compute-bound)
# - Typical speedup vs naive re-encoding of the full sequence: 10-30×
10. Quantization
Quantization reduces model size and inference latency by converting weights from FP32 to lower-precision formats (INT8, INT4, etc.).
Quantization Techniques
| Method | Precision | Size | Speed | Quality | Use Case |
|---|---|---|---|---|---|
| Post-Training Quantization (PTQ) | INT8 | 75% reduction | Fast | Good | Quick deployment |
| GPTQ | INT4 | 87.5% reduction | Very fast | Excellent | Inference-only |
| AWQ | INT4 | 87.5% reduction | Very fast | Excellent | More models supported |
| NF4 (QLoRA) | NF4 (4-bit) | 87.5% reduction | Fast | Good (supports fine-tuning) | Fine-tuning |
| Bfloat16 | BF16 | 50% reduction | Very fast | Excellent | Training & inference |
INT8 Quantization
def quantize_to_int8(tensor, eps=1e-6): # Find per-channel (per-output) scaling factor # Assume tensor shape: (out_features, in_features) # Per-output scaling (most common) max_abs = torch.max(torch.abs(tensor), dim=-1, keepdim=True)[0] # Scale to [-128, 127] scales = max_abs / 127.0 quantized = torch.round(tensor / scales).to(torch.int8) return quantized, scales def dequantize_from_int8(quantized, scales): # Restore to original precision return quantized.float() * scales # Example: weights_fp32 = torch.randn(4096, 11008) # LLaMA FFN layer weights_int8, scales = quantize_to_int8(weights_fp32) # Memory: 4096 × 11008 × 4 bytes = 180MB → 45MB (4× reduction)
INT4 GPTQ Quantization
def gptq_quantize(tensor, bits=4, group_size=128):
    # Group-wise INT4 quantization (simplified; real GPTQ additionally uses second-order
    # (Hessian) information to choose roundings that minimize layer output error)
    # tensor: (out_features, in_features)
    out_features, in_features = tensor.shape
    quantized = torch.zeros_like(tensor, dtype=torch.int8)
    scales = torch.zeros(out_features, (in_features + group_size - 1) // group_size)

    for i in range(out_features):
        for j in range(0, in_features, group_size):
            group = tensor[i, j:j + group_size]
            # Find a scale for this group
            max_val = torch.max(torch.abs(group))
            scale = max_val / ((1 << (bits - 1)) - 1)
            # Quantize and store (values fall in [-8, 7] for 4-bit)
            quantized[i, j:j + group_size] = torch.round(group / scale).to(torch.int8)
            scales[i, j // group_size] = scale

    return quantized, scales

# GPTQ properties:
# - Per-group scaling reduces error vs per-channel scaling
# - ~8× size reduction vs FP32 (two 4-bit values are packed per byte in practice)
# - Typically <1% accuracy loss
QLoRA for Efficient Fine-Tuning
class QLoRA_Linear(nn.Module): def __init__(self, in_features, out_features, r=8): super().__init__() # Original weights quantized to NF4 (frozen) self.W_q = None # Loaded from quantized model # Low-rank adapters (trainable) self.lora_a = nn.Linear(in_features, r, bias=False) self.lora_b = nn.Linear(r, out_features, bias=False) # Initialize: lora_a ~ N(0, 1), lora_b = 0 nn.init.normal_(self.lora_a.weight) nn.init.zeros_(self.lora_b.weight) self.scaling = 1.0 / r def forward(self, x): # x: (batch, seq_len, in_features) # Original (frozen) computation output = dequantize_and_apply(x, self.W_q) # Add low-rank update lora_out = self.lora_b(self.lora_a(x)) output = output + self.scaling * lora_out return output # QLoRA memory efficiency: # - Base model: 13B params in NF4 = ~7GB VRAM # - LoRA adapters: 13B × r/d = ~600MB (r=8) # - Total: ~7.6GB vs 52GB (FP32 13B model) # - Trainable on single 24GB GPU!
11. Fine-tuning Methods
Fine-tuning adapts pretrained models to specific tasks. Methods range from full fine-tuning to efficient parameter-efficient approaches like LoRA.
Fine-tuning Methods Comparison
Supervised Fine-Tuning (SFT)
Train on task-specific labeled examples. Update all parameters to minimize cross-entropy on next-token prediction.
RLHF (Reinforcement Learning from Human Feedback)
1) SFT on examples 2) Train reward model from preferences 3) RL to maximize reward using PPO.
Direct Preference Optimization (DPO)
Skip reward model. Directly optimize for preference pairs using constrained objective.
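A minimal sketch of the DPO objective described above, in its implicit-reward form; the β value and the assumption that each argument is a summed per-sequence log-probability are illustrative.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards: how far the policy has moved from the frozen reference model
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Binary classification loss on preference pairs: -log σ(r_chosen - r_rejected)
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# Each *_logps argument is the summed log-probability of a full response under the
# policy or the reference model, e.g. a tensor of shape (batch,).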
LoRA Implementation
class LoRA_Linear(nn.Module):
    def __init__(self, original_linear, rank=8, alpha=16):
        super().__init__()
        self.original = original_linear
        # Freeze all original weights
        for p in self.original.parameters():
            p.requires_grad = False

        # Low-rank decomposition: ΔW = B @ A with r << d
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        self.lora_a = nn.Linear(in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_features, bias=False)

        # Initialize: A with small random values, B with zeros (so ΔW = 0 at init)
        nn.init.normal_(self.lora_a.weight, std=1.0 / rank)
        nn.init.zeros_(self.lora_b.weight)

        # Scaling factor
        self.scaling = alpha / rank

    def forward(self, x):
        # y = W·x + (α/r)·B·(A·x)
        result = self.original(x)
        lora_update = self.lora_b(self.lora_a(x))
        return result + self.scaling * lora_update

# LoRA parameter count (one LLaMA-style FFN projection):
# - Original layer: 4096 × 11008 ≈ 45M params
# - LoRA adapters: (4096 × 8) + (8 × 11008) ≈ 32K + 88K ≈ 121K params
# - Trainable fraction: 121K / 45M ≈ 0.27%
Training with LoRA
def train_with_lora(base_model, train_dataset, lr=2e-4): # Step 1: Replace linear layers with LoRA versions for name, module in base_model.named_modules(): if isinstance(module, nn.Linear) and "q_proj" in name: new_module = LoRA_Linear(module, rank=8, alpha=16) # Replace in parent # Step 2: Only train LoRA parameters lora_params = [] for name, param in base_model.named_parameters(): if "lora" in name: lora_params.append(param) param.requires_grad = True else: param.requires_grad = False optimizer = torch.optim.AdamW(lora_params, lr=lr) for batch in train_dataset: logits = base_model(batch.input_ids) loss = compute_loss(logits, batch.labels) loss.backward() torch.nn.utils.clip_grad_norm_(lora_params, 1.0) optimizer.step() optimizer.zero_grad() # LoRA advantages: # - 99.7% parameter reduction vs full fine-tuning # - Can store multiple LoRA adapters (one per task) # - Merge LoRA weights into original after training # - Only marginal accuracy loss (often none)
Fine-tuning Methods Comparison
| Method | Parameters | Training Time | VRAM | Quality | Inference Cost |
|---|---|---|---|---|---|
| Full Fine-tuning | 100% | Fast | High | Best | No overhead |
| LoRA | 0.27% | Medium | Low | Very Good | No overhead (merged) |
| QLoRA | 0.27% | Slow | Very Low | Good | No overhead (merged) |
| Prefix Tuning | 0.1% | Fast | Very Low | Good | Sequence length overhead |
| Prompt Tuning | 0.01% | Very Fast | Very Low | Fair | Minimal overhead |
LoRA & QLoRA — Low-Rank Adaptation Deep Dive
LoRA is the most widely adopted parameter-efficient fine-tuning method, enabling training of 70B+ models on consumer hardware. QLoRA extends it with 4-bit quantization for extreme memory efficiency.
LoRA Architecture
The Math Behind LoRA
Core Idea: Instead of updating the full weight matrix W ∈ ℝ^(d×d), decompose the update into two low-rank matrices:
W' = W + (α/r) · B · A
Where B ∈ ℝ^(d×r) and A ∈ ℝ^(r×d), with rank r ≪ d (typically r = 4-64).
- A is initialized from N(0, 1/r) — projects input down to rank r
- B is initialized to zeros — so ΔW starts at zero (identity at init)
- α (alpha) scales the LoRA contribution. Higher α = stronger adaptation
- α/r (scaling) normalizes the update magnitude regardless of rank choice
Which Layers to Apply LoRA
| Target Modules | Typical Config | Effect | Param Overhead |
|---|---|---|---|
| Q + V projections | target_modules=["q_proj", "v_proj"] | Standard LoRA — modifies attention queries and values | ~0.2% |
| Q + K + V + O | target_modules=["q_proj", "k_proj", "v_proj", "o_proj"] | Full attention LoRA — better quality, 2× overhead | ~0.4% |
| All linear layers | target_modules="all-linear" | Maximum quality — adapts attention + FFN + output head | ~1-2% |
| FFN only | target_modules=["gate_proj", "up_proj", "down_proj"] | Targets the largest params (60-80% of model). Good for knowledge injection. | ~0.8% |
Hyperparameter Guide
| Hyperparameter | Typical Range | Effect | Recommendation |
|---|---|---|---|
| r (rank) | 4, 8, 16, 32, 64 | Higher r = more capacity but more params. Diminishing returns above 32. | Start with r=8. Increase to 16-32 for complex tasks. |
| α (alpha) | 8, 16, 32, 64 | Scales LoRA update magnitude. Common: α = 2×r. | α=16 for r=8. α=32 for r=16. |
| dropout | 0.0 - 0.1 | Regularization on LoRA layers. Higher prevents overfitting on small data. | 0.05 default. 0.1 if dataset < 10K examples. |
| learning_rate | 1e-4 to 3e-4 | 2-10× higher than full fine-tuning since fewer params are updated. | 2e-4 is a good default for most LoRA setups. |
| target_modules | q,v / q,k,v,o / all-linear | Which weight matrices get LoRA adapters. | ["q_proj", "v_proj"] for speed; "all-linear" for max quality. |
HuggingFace PEFT Implementation
import torch
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# 1. Load base model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
tokenizer.pad_token = tokenizer.eos_token

# 2. Configure LoRA
lora_config = LoraConfig(
    r=16,                      # Rank — higher = more capacity
    lora_alpha=32,             # Scaling factor (usually 2×r)
    lora_dropout=0.05,         # Regularization
    target_modules=[           # Which layers to adapt
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    task_type=TaskType.CAUSAL_LM,
    bias="none",               # Don't train biases
)

# 3. Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output: "trainable params: 83,886,080 || all params: 8,114,212,864 || 1.03%"

# 4. Train with SFTTrainer
training_args = TrainingArguments(
    output_dir="./lora-llama",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,   # effective batch = 16
    learning_rate=2e-4,
    bf16=True,
    warmup_ratio=0.03,
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    max_seq_length=2048,
)
trainer.train()
Merging LoRA Weights for Production
# After training: merge LoRA weights into the base model for zero-overhead inference
from peft import PeftModel

# Load base + LoRA
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
lora_model = PeftModel.from_pretrained(base_model, "./lora-llama")

# Merge: W' = W + (α/r) · B · A → a single matrix, no LoRA overhead
merged_model = lora_model.merge_and_unload()

# Save merged model — identical interface to the original, zero inference cost
merged_model.save_pretrained("./llama-merged")
tokenizer.save_pretrained("./llama-merged")

# Now deploy with vLLM, TensorRT-LLM, or any framework — no PEFT dependency needed
QLoRA: Quantized Low-Rank Adaptation
QLoRA combines 4-bit NF4 quantization of the base model with LoRA adapters, enabling fine-tuning of 70B models on a single 48GB GPU.
QLoRA Implementation (HuggingFace)
from transformers import AutoModelForCausalLM, BitsAndBytesConfig from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training # 1. 4-bit quantization config (NF4 + double quantization) bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", # NormalFloat4 — optimal for normal distributions bnb_4bit_use_double_quant=True, # Quantize the quantization constants too bnb_4bit_compute_dtype=torch.bfloat16, # Compute in BF16 ) # 2. Load model in 4-bit model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-3.1-70B", quantization_config=bnb_config, device_map="auto", ) # 3. Prepare for k-bit training (handle gradient checkpointing + casting) model = prepare_model_for_kbit_training(model) # 4. Apply LoRA on top of quantized model lora_config = LoraConfig( r=16, lora_alpha=32, target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], lora_dropout=0.05, task_type="CAUSAL_LM", ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() # "trainable: 167M | all: 70B | 0.24%" # VRAM usage breakdown: # - 70B model in NF4: ~35 GB # - LoRA adapters (BF16): ~1.5 GB # - Optimizer states: ~3 GB (AdamW on LoRA params only) # - Activations/gradients: ~5 GB (with gradient checkpointing) # - Total: ~44 GB → fits on 1× A100 48GB or 2× A40 48GB
LoRA vs QLoRA vs Full Fine-Tuning
| Aspect | Full Fine-Tuning | LoRA | QLoRA |
|---|---|---|---|
| Trainable params | 100% (all) | 0.2-2% (A + B matrices) | 0.2-2% (same as LoRA) |
| Base model precision | BF16/FP32 | BF16 (frozen) | NF4 4-bit (frozen) |
| VRAM (7B model) | ~28 GB | ~16 GB | ~6 GB |
| VRAM (70B model) | ~280 GB (8× A100) | ~150 GB (2× A100) | ~44 GB (1× A100) |
| Training speed | 1× (baseline) | 0.8-1× (similar) | 0.5-0.7× (dequant overhead) |
| Quality | 100% (baseline) | 97-99% | 95-98% |
| Inference overhead | None | None (merge into base) | None (merge + requantize) |
| Multi-task | Separate model per task | Swap/stack adapters | Swap/stack adapters |
| Best for | Maximum quality, large budget | Production fine-tuning, moderate GPU | Large models on limited GPU |
Advanced LoRA Variants
DoRA (Weight-Decomposed LoRA)
Decomposes weights into magnitude and direction components, applying LoRA only to the direction. Consistently outperforms standard LoRA by 1-3% on benchmarks with minimal overhead. Published at ICML 2024.
# DoRA: W' = m · (W + ΔW) / ||W + ΔW|| # where m = learnable magnitude, ΔW = B·A (standard LoRA) # In HuggingFace PEFT: lora_config = LoraConfig( r=16, use_dora=True, # Enable DoRA target_modules=["q_proj", "v_proj"], )
rsLoRA (Rank-Stabilized LoRA)
Changes the scaling factor from α/r to α/√r, stabilizing training at higher ranks. Allows scaling to r=256+ without re-tuning the learning rate. Available in HuggingFace PEFT via use_rslora=True.
# rsLoRA: scaling = α/√r instead of α/r # Prevents gradient explosion at high ranks lora_config = LoraConfig( r=64, lora_alpha=64, use_rslora=True, # Rank-stabilized scaling )
LoRA+ (Differential Learning Rates)
Uses different learning rates for A and B matrices (typically lr_B = 2-8× lr_A). Improves convergence speed by 2× with no additional memory cost.
# LoRA+: separate LR for A and B optimizer = torch.optim.AdamW([ {"params": lora_a_params, "lr": 1e-4}, {"params": lora_b_params, "lr": 5e-4}, ])
QA-LoRA (Quantization-Aware LoRA)
Co-designs LoRA rank with quantization group size so the merged model can be directly quantized without quality loss. Produces INT4-ready models from training.
# QA-LoRA: r aligned to group_size # group_size=32, r must be multiple of 32 # After merge: direct GPTQ/AWQ quantization # Result: quantized + adapted in one shot
Adapters: Lightweight Modular Fine-Tuning
Adapters are small trainable bottleneck modules inserted between frozen Transformer layers. Unlike LoRA (which modifies existing weights), adapters add new parameters in a down-project → nonlinearity → up-project pattern.
Adapter Implementation
class Adapter(nn.Module): """Bottleneck adapter inserted after attention or FFN sub-layer.""" def __init__(self, d_model: int, bottleneck: int = 64): super().__init__() self.down = nn.Linear(d_model, bottleneck, bias=False) self.act = nn.GELU() self.up = nn.Linear(bottleneck, d_model, bias=False) nn.init.zeros_(self.up.weight) # init to identity residual def forward(self, x): return x + self.up(self.act(self.down(x))) # residual + adapter # Adapter param count (LLaMA-3 8B, bottleneck=64): # - Per adapter: 4096 × 64 + 64 × 4096 = 524K params # - 2 adapters per layer (after attn + after FFN) × 32 layers = 33.5M # - Total: 33.5M / 8B = 0.4% of model params class AdapterTransformerLayer(nn.Module): """Transformer layer with adapters after attention and FFN.""" def __init__(self, original_layer, bottleneck=64): super().__init__() self.layer = original_layer # frozen d = original_layer.self_attn.q_proj.in_features self.attn_adapter = Adapter(d, bottleneck) self.ffn_adapter = Adapter(d, bottleneck) def forward(self, x, **kwargs): # Attention sub-layer + adapter h = self.layer.self_attn(x, **kwargs) h = self.attn_adapter(h) # ← adapter after attention x = x + h # residual # FFN sub-layer + adapter h = self.layer.mlp(x) h = self.ffn_adapter(h) # ← adapter after FFN return x + h # residual
Adapter vs LoRA vs Full Fine-Tuning
| Aspect | Full Fine-Tuning | LoRA | Adapter | Prefix Tuning |
|---|---|---|---|---|
| Trainable params | 100% | 0.1-1% | 0.4-2% | 0.01% |
| Where modified | All layers | Existing Q/V/K/O weights | New bottleneck layers | Prepended to K/V |
| Inference overhead | None | None (merged) | ~2-5% latency | Adds prefix tokens to the sequence |
| Multi-task | Separate model per task | Merge or swap adapters | Swap adapter modules | Swap prefix vectors |
| Quality vs full FT | 100% (baseline) | 97-99% | 95-98% | 90-95% |
| Memory (7B model) | ~28GB | ~16GB | ~18GB | ~14GB |
| Key advantage | Maximum quality | Zero overhead, mergeable | Fully modular, stackable | Fewest parameters |
12. Safety & Alignment
Ensuring LLMs are safe, helpful, and honest requires multi-layered defenses spanning data curation, alignment training, runtime guardrails, and continuous adversarial testing.
The Alignment Pipeline
Alignment Techniques in Detail
Constitutional AI (Anthropic)
Model self-critiques outputs against a written "constitution" of principles. A revision model rewrites harmful outputs. Then RLAIF (RL from AI feedback) replaces human preference labels with AI-generated ones. Reduces annotation cost 10×+ while maintaining safety quality.
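The critique-revise loop itself is only a few calls. Below is a minimal sketch, not Anthropic's actual prompts or pipeline: generate stands in for any completion callable, and the single principle string is a placeholder for a full constitution.

# Minimal sketch of one constitutional critique-revise step (illustrative only)
def constitutional_revision(generate, user_prompt, draft_response,
                            principle="Avoid content that is harmful or unethical."):
    critique = generate(
        f"Principle: {principle}\n"
        f"Request: {user_prompt}\nResponse: {draft_response}\n"
        "Critique the response for any violation of the principle:"
    )
    revised = generate(
        f"Request: {user_prompt}\nResponse: {draft_response}\n"
        f"Critique: {critique}\n"
        "Rewrite the response so it fully complies with the principle:"
    )
    # (prompt, revised) pairs later feed SFT and RLAIF preference data
    return revised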
RLHF (PPO-based)
Pipeline: (1) SFT on demonstrations → (2) Train reward model on human preferences → (3) PPO optimizes LLM against reward model with KL penalty to prevent drift from base. Effective but expensive (~10× SFT cost) and prone to reward hacking.
DPO (Direct Preference Optimization)
Eliminates the reward model entirely. Directly optimizes on preference pairs with a logistic loss on the implicit reward margin: L = −log σ(β[(log π_θ(y_w|x) − log π_ref(y_w|x)) − (log π_θ(y_l|x) − log π_ref(y_l|x))]), where π_ref is the frozen SFT model. 40-75% cheaper than PPO with comparable alignment quality.
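For concreteness, a minimal sketch of the DPO objective in PyTorch, assuming the per-sequence log-probabilities (summed over response tokens) have already been computed for both the policy and the frozen reference model:

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards: log-ratio of policy vs frozen reference
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Logistic loss on the margin between chosen and rejected completions
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()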
Red-Teaming & Adversarial Testing
Systematic adversarial probing for jailbreaks (GCG, AutoDAN), prompt injection (direct + indirect), harmful content generation, PII extraction, and bias elicitation. Should cover 1000+ attack vectors before production deployment.
Runtime Safety Stack
| Layer | Technique | What It Catches | Latency |
|---|---|---|---|
| Input filter | Classifier on prompt | Harmful requests, jailbreak attempts | ~5ms |
| PII detector | Presidio / regex + NER | Names, emails, SSNs in context | ~10ms |
| Topic guardrail | Zero-shot classifier | Off-topic, policy-violating queries | ~15ms |
| Output classifier | Toxicity / harm model | Harmful, biased, or toxic responses | ~10ms |
| Fact grounding | NLI entailment check | Hallucinated claims not in context | ~50ms |
| Citation verification | Source matching | Incorrect or fabricated citations | ~20ms |
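A layered stack like the one above can be wired together as a simple chain of checks. The sketch below is illustrative: each checker is a placeholder callable returning (ok, reason), and a real deployment would slot in a jailbreak classifier, Presidio, a toxicity model, and an NLI checker roughly in order of increasing latency.

# Minimal sketch of a layered runtime guardrail pipeline (placeholder checkers)
def guarded_generate(llm, prompt, input_checks, output_checks):
    for check in input_checks:              # cheap filters run first
        ok, reason = check(prompt)
        if not ok:
            return f"Request blocked: {reason}"
    response = llm(prompt)
    for check in output_checks:
        ok, reason = check(response)
        if not ok:
            return f"Response withheld: {reason}"
    return response

# Example placeholder checker (hypothetical pattern filter)
def no_ssn(text):
    import re
    return (not re.search(r"\b\d{3}-\d{2}-\d{4}\b", text), "possible SSN detected")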
Known Safety Challenges
| Challenge | Description | Root Cause | Severity | Mitigation |
|---|---|---|---|---|
| Jailbreaking | Adversarial prompts bypass safety training | Emergent techniques exploit attention patterns | Critical | Continuous red-teaming, constitutional AI, input classifiers |
| Hallucination | Generating plausible but false information | Next-token prediction lacks truth grounding | High | RAG, NLI verification, uncertainty estimation, CoT |
| Sycophancy | Agreeing with user even when user is wrong | RLHF optimizes for approval over truth | Medium | Diverse preference data, robustness training |
| Reward hacking | Gaming the reward model's weaknesses | Reward model is an imperfect proxy | High | Reward model ensembles, KL constraints, output monitoring |
| Bias amplification | Reinforcing societal biases from training data | Biased pre-training data + RLHF feedback | High | Bias benchmarks (BBQ, WinoBias), debiasing fine-tuning |
| Data extraction | Extracting memorized training data | Overfitting on repeated patterns | Medium | Deduplication, differential privacy, output monitoring |
13. Evaluation & Benchmarks
LLM evaluation requires multiple complementary approaches — no single benchmark captures all capabilities. Production systems need both offline benchmarks and online quality monitoring.
Evaluation Framework
Capability Benchmarks
Measure what the model can do — knowledge recall (MMLU), reasoning (GSM8K, MATH), code (HumanEval), commonsense (HellaSwag). Scored automatically.
Safety & Truthfulness
Measure what the model shouldn't do — TruthfulQA (resisting common misconceptions), AdvBench (adversarial attacks), BBQ (bias detection). Often requires human eval.
Human Preference
Measure subjective quality — Chatbot Arena (ELO rankings), MT-Bench (multi-turn conversation quality), AlpacaEval (instruction following). Gold standard but expensive.
Major Benchmarks
| Benchmark | Focus | Format | Difficulty | Size | Year | Key Insight |
|---|---|---|---|---|---|---|
| MMLU | World knowledge (57 domains) | 4-choice MC | Medium | 15,908 | 2021 | Most cited LLM benchmark; near-saturated by frontier models |
| MMLU-Pro | Harder knowledge (10 choices) | 10-choice MC | Hard | 12,141 | 2024 | Resists random guessing; better differentiation at the frontier |
| GSM8K | Grade-school math reasoning | Word problems | Medium | 8,792 | 2021 | Tests multi-step arithmetic; CoT dramatically improves scores |
| MATH | Competition-level math | Proofs + computation | Very Hard | 12,500 | 2021 | Still challenging even for GPT-4; test-time compute helps |
| HumanEval | Python code generation | Function from docstring | Medium | 164 | 2021 | Standard code benchmark; frontier models exceed 90% pass@1 |
| HellaSwag | Commonsense reasoning | Sentence completion | Hard | 10,042 | 2019 | Tests physical intuition; adversarially constructed distractors |
| TruthfulQA | Resist misconceptions | Open-ended QA | Medium | 817 | 2021 | Larger models were initially WORSE; alignment helps |
| MT-Bench | Multi-turn chat quality | LLM-as-judge scoring | Medium | 80 tasks | 2023 | Tests conversation coherence across turns |
| Chatbot Arena | Overall quality | Blind pairwise prefs | Varies | 1M+ votes | 2023 | ELO ranking from real users; most trusted overall ranking |
| BigBench | Diverse capabilities | 200+ task types | Varies | 200+ tasks | 2022 | Tests emergent abilities at scale |
| IFEval | Instruction following | Constrained instructions | Medium | 500+ | 2024 | Tests format compliance (word count, structure, etc.) |
Performance Trends (2020 → 2025)
| Model | MMLU | GSM8K | MATH | HumanEval | Year |
|---|---|---|---|---|---|
| GPT-3 (175B) | 41.3% | 11.4% | 2.4% | 10.2% | 2020 |
| LLaMA 7B | 35.1% | 11.0% | 2.9% | 10.5% | 2023 |
| Mistral 7B | 64.2% | 47.4% | 11.2% | 73.2% | 2023 |
| LLaMA 3 8B | 66.7% | 79.4% | 30.0% | 81.7% | 2024 |
| GPT-4 | 88.7% | 92.0% | 52.9% | 92.0% | 2023 |
| Claude 3 Opus | 88.7% | 95.0% | 60.1% | 95.1% | 2024 |
| Claude 3.5 Sonnet | 89.0% | 96.4% | 71.1% | 92.0% | 2024 |
| GPT-4o | 88.7% | 95.8% | 76.6% | 90.2% | 2024 |
| o1 (with CoT) | 91.8% | 97.8% | 96.4% | 92.4% | 2024 |
Evaluation Pitfalls
Contamination
Models may have seen benchmark questions during pre-training, inflating scores. Mitigations: held-out test sets, contamination detection, dynamic benchmarks (LiveBench).
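One common contamination check is n-gram overlap against the training corpus; the sketch below assumes a prebuilt set of training 8-grams and flags benchmark items whose n-grams are mostly contained in it.

# Minimal n-gram-overlap contamination check (train_ngrams: prebuilt set of 8-gram tuples)
def contamination_score(benchmark_text, train_ngrams, n=8):
    tokens = benchmark_text.lower().split()
    ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    hits = sum(1 for g in ngrams if g in train_ngrams)
    return hits / len(ngrams)   # e.g. flag items with score > 0.5 for removal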
Metric Gaming
Models can be specifically tuned for benchmarks without genuine capability. MMLU-Pro and harder variants attempt to address this. Always combine automated + human evaluation.
Context Window & Long-Range Techniques
Modern LLMs need to process long documents, multi-turn conversations, and large codebases. These techniques extend effective context beyond the training window.
Context Extension Methods
| Technique | How It Works | Extension Factor | Quality Impact | Used By |
|---|---|---|---|---|
| Sliding Window Attention | Each token attends only to W nearest neighbors. Reduces O(n²) to O(n·W). | Unlimited (with degradation) | Loses long-range dependencies beyond W | Mistral (W=4096) |
| RoPE Scaling (Linear) | Interpolate RoPE frequencies by factor s to extend context by s×. | 2-4× training length | Slight degradation at extreme lengths | Code Llama |
| NTK-Aware Scaling | Modify RoPE base frequency (θ) instead of interpolating. Better preserves local attention. | 2-8× | Better than linear for extreme extension | Community method |
| YaRN | Combines NTK scaling with temperature-based attention correction. | 4-16× (e.g., 128K from 8K) | Best quality for large extension ratios | LLaMA 3.1 (128K) |
| ALiBi | Linear attention biases based on distance — no positional embeddings. | 8× without fine-tuning | Good extrapolation, no retraining needed | BLOOM, MPT |
| Ring Attention | Distributes KV computation across GPUs in a ring, overlapping compute and communication. | Near-infinite (GPU count × window) | Exact attention, no approximation | Research (Berkeley) |
| Landmark Attention | Inserts landmark tokens that summarize segments. Attention retrieves relevant segments via landmarks. | 32-100× with random access | Slight loss for fine-grained retrieval | Research |
| RAG (Retrieval) | Retrieve relevant passages from external store instead of fitting everything in context. | Unlimited (external DB) | Best for knowledge-intensive tasks | Production standard |
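The two RoPE-scaling rows above differ only in where the stretch is applied: linear scaling divides positions by the scale factor, while NTK-aware scaling enlarges the base frequency. A minimal sketch, using the community formulation of the NTK base adjustment (exact constants vary by implementation):

import torch

def rope_inv_freq(d_head, base=10000.0, scale=1.0, method="linear"):
    if method == "ntk":
        # Stretch the base so the lowest frequency covers scale× more positions
        base = base * scale ** (d_head / (d_head - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    def angles(positions):                  # positions: (seq_len,) tensor
        pos = positions.float()
        if method == "linear":
            pos = pos / scale               # position interpolation
        return torch.outer(pos, inv_freq)   # (seq_len, d_head/2) rotation angles
    return angles

# Usage: angles = rope_inv_freq(128, scale=4.0, method="ntk")(torch.arange(32768))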
Sliding Window + Sink Tokens
def sliding_window_attention(Q, K, V, window_size=4096, n_sink=4): """Attention with sliding window + sink tokens for streaming.""" seq_len = Q.shape[1] attn_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool) for i in range(seq_len): # Always attend to first n_sink tokens (sink tokens) attn_mask[i, :n_sink] = True # Attend to local window start = max(n_sink, i - window_size) attn_mask[i, start:i+1] = True scores = (Q @ K.transpose(-2, -1)) / math.sqrt(Q.shape[-1]) scores.masked_fill_(~attn_mask, float('-inf')) return F.softmax(scores, dim=-1) @ V # Mistral uses W=4096 sliding window # With 32 layers, effective context = 4096 × 32 = 131K # But only attends locally — long-range info degrades
14. Frontier Research Directions
The field of LLMs is rapidly evolving with new approaches addressing context length, reasoning, multimodality, and interpretability.
Chain-of-Thought & Reasoning
Models like o1 and DeepSeek-R1 generate explicit reasoning traces before answering. Trained with RL on verifiable problems, trading inference compute for accuracy.
State-Space Models (SSMs)
Mamba and RWKV achieve O(n) complexity vs O(n²) for attention. Linear scaling with sequence length, potential for mega-context.
Multimodal LLMs
Vision + language (GPT-4V, Gemini, Claude). Vision encoder + token interleaving. Audio and video modalities emerging.
Mechanistic Interpretability
Understanding neuron function, attention head roles, circuits. Sparse autoencoders decomposing activations into features.
Emerging Architectures
| Architecture | Complexity | Context | Key Innovation | Models |
|---|---|---|---|---|
| Transformer (Dense Attention) | O(n²) | <256K | Scaled dot-product attention | GPT, LLaMA, Mistral |
| State-Space Models (SSM) | O(n) | >1M | Diagonal state transitions | Mamba, RWKV |
| Sparse Attention | O(n·log n) | 256K+ | Selective attention patterns | Longformer, BigBird |
| Hybrid (Attention + SSM) | O(n²) local | 256K+ | Local attention + global SSM | Jamba, Samba |
Interpretability Techniques
- Mechanistic Interpretability: Identify circuits and features that compute specific functions
- Attention Analysis: Visualize what tokens attend to; identify important attention heads
- Probing: Train classifiers on intermediate representations to decode semantic information (minimal sketch after this list)
- Activation Decomposition: Use Sparse Autoencoders to find interpretable features in activations
- Causal Tracing: Ablate representations to find critical computation paths
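As a concrete illustration of the probing technique above, a minimal sketch: freeze the LLM, collect hidden states from one layer, and fit a linear classifier on them to decode a property (e.g. part-of-speech, sentiment, truthfulness). Names and hyperparameters here are illustrative.

import torch
import torch.nn as nn

class LinearProbe(nn.Module):
    def __init__(self, d_model, n_classes):
        super().__init__()
        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, hidden_states):       # (batch, d_model), detached from the LLM
        return self.classifier(hidden_states)

def train_probe(probe, activations, labels, epochs=10, lr=1e-3):
    opt = torch.optim.Adam(probe.parameters(), lr=lr)
    for _ in range(epochs):
        loss = nn.functional.cross_entropy(probe(activations), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return probe
# High probe accuracy is evidence the property is linearly decodable at that layer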
Future Scaling Trends
| Dimension | Past | Present | Future |
|---|---|---|---|
| Model Size | 7B (2023) | 405B (2024) | ~10T (2025-2026) |
| Data Scale | 2T tokens | 10-15T tokens | 100T+ tokens (synthetic) |
| Context Length | 2K-4K | 128K-1M | 10M+ tokens |
| Training Compute | 10^22 FLOPs | 10^24-10^25 FLOPs | 10^26+ FLOPs |
| Inference | Full precision | Quantized (INT4/8) | Speculative decoding, pruning |
Advanced Reasoning Techniques
# Chain-of-Thought prompting for better reasoning def generate_with_reasoning(model, question, max_reasoning_tokens=2000): # Phase 1: Generate reasoning trace prompt = f"""Q: {question} Let me work through this step by step: """ reasoning = model.generate( prompt, max_tokens=max_reasoning_tokens, temperature=0.7 ) # Phase 2: Generate final answer based on reasoning full_prompt = prompt + reasoning + "\nTherefore, the answer is:" answer = model.generate( full_prompt, max_tokens=100, temperature=0.3 ) return reasoning, answer # Training with reasoning supervision (o1 style): # 1. Collect examples with explicit reasoning traces # 2. SFT on reasoning_trace + answer pairs # 3. RL: reward correct final answer (only reward if reasoning ∈ valid_steps) # 4. Trade inference compute for accuracy (test-time scaling)
Model Merging & Mixture of Experts
def linear_interpolation_merge(model_a, model_b, alpha=0.5): # Merge two fine-tuned models via convex combination # θ_merged = α·θ_a + (1-α)·θ_b merged_model = copy.deepcopy(model_a) for (name_a, param_a), (name_b, param_b) in zip( model_a.named_parameters(), model_b.named_parameters() ): assert name_a == name_b merged_model.get_parameter(name_a).data = ( alpha * param_a.data + (1 - alpha) * param_b.data ) return merged_model class MixtureOfExperts(nn.Module): def __init__(self, expert_models, num_experts_per_token=2): super().__init__() self.experts = nn.ModuleList(expert_models) # Router network learns to select experts self.router = nn.Linear(4096, len(expert_models)) self.k = num_experts_per_token def forward(self, x): # x: (batch, seq_len, d_model) # Get router logits router_logits = self.router(x) # Select top-k experts per token expert_weights, expert_indices = torch.topk( torch.softmax(router_logits, dim=-1), k=self.k ) output = torch.zeros_like(x) for i in range(0, self.k): expert_id = expert_indices[..., i] weight = expert_weights[..., i].unsqueeze(-1) # Apply selected experts (per token) expert_output = apply_per_token_expert( x, expert_id, self.experts ) output = output + weight * expert_output return output
Speculative Decoding
def speculative_decoding(large_model, small_model, prompt, max_tokens=256, gamma=4):
    # gamma = number of tokens to speculate per round; assumes batch size 1
    generated = prompt.clone()
    for _ in range(max_tokens):
        # Step 1: Small model proposes gamma tokens (greedy draft)
        draft_tokens = []
        logits_small = small_model(generated)
        for _ in range(gamma):
            next_token = torch.argmax(logits_small[:, -1, :], dim=-1)
            draft_tokens.append(next_token)
            # Append to sequence for the next prediction
            logits_small = small_model(
                torch.cat([generated, torch.stack(draft_tokens).T], dim=1)
            )
        # Step 2: Large model validates/corrects the whole draft in parallel
        proposed_seq = torch.cat([generated, torch.stack(draft_tokens).T], dim=1)
        logits_large = large_model(proposed_seq)
        prefix_len = generated.shape[1]   # length before this round's drafts
        # Compare probabilities and accept/reject
        n_accepted = 0
        for i, draft_token in enumerate(draft_tokens):
            # Logits at position prefix_len + i - 1 predict the token at prefix_len + i
            large_prob = torch.softmax(logits_large[:, prefix_len + i - 1], dim=-1)
            small_prob = torch.softmax(logits_small[:, prefix_len + i - 1], dim=-1)
            # Rejection sampling: accept with probability min(1, p_large / p_small)
            accept_prob = min(1.0, (large_prob[0, draft_token] / small_prob[0, draft_token]).item())
            if torch.rand(1).item() < accept_prob:
                generated = torch.cat([generated, draft_token.unsqueeze(-1)], dim=1)
                n_accepted += 1
            else:
                break
        # If no draft token was accepted, sample one token from the large model
        if n_accepted == 0:
            logits = large_model(generated)
            next_token = torch.multinomial(
                torch.softmax(logits[:, -1, :], dim=-1), num_samples=1
            )
            generated = torch.cat([generated, next_token], dim=1)
    return generated

# Speculative decoding speedup:
# - Best case: 2-3× faster (gamma tokens generated at small-model latency)
# - Worst case: ~same as the large model (if the small model is a poor match)
# - Output distribution: matches the large model (rejection sampling)
Appendix A: Data Scaling Laws
Empirical scaling laws reveal how model size, data, and compute relate to final loss. Chinchilla (2022) showed that, for a fixed compute budget, parameters and training tokens should be scaled in roughly equal proportion, landing near D ≈ 20·N tokens per parameter at the optimum.
Chinchilla Scaling Laws
# Chinchilla-optimal scaling: D ≈ 20·N (tokens ≈ 20× parameters)
# GPT-3 (175B params, ~300B tokens) was heavily undertrained by this rule
def estimate_optimal_compute(compute_budget_flops, tokens_per_param=20):
    # Training compute ≈ 6·N·D FLOPs; with D = tokens_per_param·N,
    # C = 6·tokens_per_param·N², so N = sqrt(C / (6·tokens_per_param))
    n_params = (compute_budget_flops / (6 * tokens_per_param)) ** 0.5
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

# Example: 10^24 FLOPs budget
n, d = estimate_optimal_compute(1e24)
# → N ≈ 91B params, D ≈ 1.8T tokens
# LLaMA 3 405B (15T tokens) deliberately trains far past this optimum to buy
# better quality per parameter at inference time
Loss Scaling Models
| Model | Formula | Fit Quality | Notes |
|---|---|---|---|
| Power Law | L = A·N^(-α) + B | Excellent (2-6 orders) | Simple, asymptotic |
| Chinchilla | L = E + A·N^(-α) + B·D^(-β), α≈0.34, β≈0.28 | Excellent | Accounts for N and D |
| Kaplan (GPT-3) | L = (A/N^α) + (B/D^β) | Good | Historical baseline |
Typical Dataset Composition
| Data Source | Percentage (LLaMA 3) | Percentage (GPT-3) | Quality |
|---|---|---|---|
| Web content (filtered) | 50% | 60% | Variable |
| Books & literature | 25% | 15% | High |
| Code (GitHub, etc.) | 15% | 10% | High |
| Academic papers | 8% | 5% | Very High |
| Other (math, reasoning) | 2% | 10% | High |
Appendix B: Advanced Decoding Methods
Beam Search
def beam_search(model, prompt, beam_width=5, max_tokens=100):
    # Assumes batch size 1; prompt is a (1, seq_len) tensor of token IDs
    # Start from a single hypothesis and let expansion fill the beam
    hypotheses = [{"tokens": prompt, "score": 0.0, "finished": False}]
    for step in range(max_tokens):
        all_candidates = []
        for hyp in hypotheses:
            if hyp["finished"]:
                all_candidates.append(hyp)
                continue
            # Get log-probabilities for the next token
            logits = model(hyp["tokens"])
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)
            # Expand: try all vocab tokens
            for token_id in range(vocab_size):
                new_hyp = {
                    "tokens": torch.cat(
                        [hyp["tokens"], torch.tensor([[token_id]])], dim=1
                    ),
                    "score": hyp["score"] + log_probs[0, token_id].item(),
                    "finished": token_id == EOS_TOKEN,
                }
                all_candidates.append(new_hyp)
        # Keep top-k by length-normalized score
        all_candidates.sort(
            key=lambda x: x["score"] / x["tokens"].shape[1], reverse=True
        )
        hypotheses = all_candidates[:beam_width]
        # Stop if all hypotheses are finished
        if all(h["finished"] for h in hypotheses):
            break
    return hypotheses[0]["tokens"]

# Beam search properties:
# - Explores multiple paths (not greedy)
# - Finds higher-probability sequences than greedy decoding
# - Slower: O(beam_width · vocab_size) candidates scored per step
# - Trade-off: beam_width=1 (greedy), beam_width=10 (thorough but slow)
Contrastive Search
def contrastive_search_sampling(model, prompt, alpha=0.6, top_k=10, max_tokens=100):
    # Balance model confidence and diversity (assumes batch size 1)
    # Score = α·log P(token|context) + (1-α)·contrast(token)
    generated = prompt.clone()
    for _ in range(max_tokens):
        # Get model log-probabilities for the next token
        logits = model(generated)
        log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)
        # Get top-k candidates
        topk_log_probs, topk_indices = torch.topk(log_probs, k=top_k)
        # Get embeddings for the current context and the candidates
        context_emb = model.get_last_hidden_state(generated)[:, -1, :]
        candidate_embs = model.embedding(topk_indices)
        # Contrast term: penalize candidates too similar to the context
        similarity = torch.nn.functional.cosine_similarity(
            context_emb.unsqueeze(1), candidate_embs, dim=-1
        )
        contrast = 1.0 - similarity
        # Combined score
        scores = alpha * topk_log_probs + (1 - alpha) * contrast
        # Select the candidate with the highest score
        best_idx = torch.argmax(scores)
        next_token = topk_indices[0, best_idx]
        generated = torch.cat([generated, next_token.view(1, 1)], dim=-1)
        if next_token.item() == EOS_TOKEN:
            break
    return generated

# Contrastive search benefits:
# - Avoids repetition and dull text
# - More diverse outputs than top-p sampling
# - Slight quality trade-off vs model confidence alone
Appendix C: Implementation Details & Gotchas
Common Implementation Pitfalls
| Issue | Symptom | Cause | Fix |
|---|---|---|---|
| Gradient explosion | NaN loss, divergence | Large learning rate or weight initialization | Gradient clipping, proper init (Kaiming) |
| Dead ReLU | Many activations = 0 | Learning rate too high, weights become negative | Use ReLU variants (GELU, SiLU) or negative slope |
| Causal mask bugs | Impossible to learn (information leak) | Upper triangular not properly applied | Test mask: [0,0] can't attend to [0,1] |
| KV cache dimension mismatch | Shape errors during inference | Cache shape doesn't match head dimension | Debug: cache.shape == (batch, seq, n_heads, d_k) |
| Quantization accuracy loss | 5-20% accuracy drop | Calibration data not representative | Use same distribution as training data |
Efficient Attention Implementation
import math
import torch

# Memory-efficient attention: avoid materializing the full seq_len × seq_len
# score matrix by processing queries in blocks
def efficient_attention(Q, K, V, block_size=128):
    # Q, K, V: (seq_len, d_k)
    seq_len, d_k = Q.shape
    output = torch.zeros_like(Q)
    for i in range(0, seq_len, block_size):
        Q_block = Q[i:i+block_size]
        # Causal: this query block can only see keys up to its last position
        K_ctx = K[:i+block_size]
        V_ctx = V[:i+block_size]
        scores = Q_block @ K_ctx.T / math.sqrt(d_k)
        # Mask future positions inside the diagonal block
        q_pos = torch.arange(i, i + Q_block.shape[0]).unsqueeze(1)
        k_pos = torch.arange(K_ctx.shape[0]).unsqueeze(0)
        scores = scores.masked_fill(k_pos > q_pos, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        output[i:i+block_size] = weights @ V_ctx
    return output

# Note: FlashAttention goes further and also tiles over keys, using an online
# softmax (running max and sum) so peak memory stays O(block_size²)
Proper Weight Initialization
def init_transformer_weights(module): # Kaiming initialization for better convergence if isinstance(module, nn.Linear): # He/Kaiming init: std = √(2/n_in) nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): # Normal dist for embeddings nn.init.normal_(module.weight, std=0.02) elif isinstance(module, (nn.LayerNorm, RMSNorm)): # Initialize scale to 1.0, no bias needed nn.init.ones_(module.weight) if hasattr(module, 'bias') and module.bias is not None: nn.init.zeros_(module.bias) else: if hasattr(module, 'weight'): nn.init.normal_(module.weight, std=0.02)
15. Multimodality: Vision, Audio & Beyond
Modern LLMs extend beyond text to process images, audio, video, and other modalities. This section covers the architectures and fusion strategies that enable foundation models to understand and reason across diverse data types.
Multimodal Encoder Architectures
Vision Transformers (ViT)
Split images into patches, embed linearly, then process with transformer layers.
# Image to patches (224x224 → 196 patches) image: [B, 3, 224, 224] patches: [B, 196, 768] # 16×16 patches with learnable [CLS] token: [B, 197, 768] → Transformer encoder → embeddings: [B, 768]
Encoder-Decoder Fusion (DALL-E)
Vision encoder produces image embeddings, language decoder generates text.
# DALL-E: image → tokens
image: [B, 3, 256, 256]
→ VAE encoder
→ latent: [B, 32, 32, 4]
→ discrete tokens (1-8192)
→ Transformer decoder
Cross-Modal Architectures
| Architecture | Vision | Language | Fusion | Use Case |
|---|---|---|---|---|
| CLIP | ViT | Text Transformer | Contrastive (image-text pairs) | Zero-shot classification |
| Flamingo | Vision Encoder | LLM | Gated cross-attention | Visual Q&A, captioning |
| LLaVA | ViT | LLaMA | Linear projection + LoRA | Instruction following |
| GPT-4V | Vision Encoder | LLM | Token interleaving | General visual reasoning |
| ALIGN | EfficientNet | BERT | Dual-encoder (similarity) | Image-text retrieval |
Audio & Speech Modalities
Wav2Vec 2.0 (Speech)
Self-supervised learning on raw waveforms with contrastive losses.
- Raw audio → CNN encoder (25ms frames)
- Temporal convolutions capture phonetic structure
- Contrastive learning: distinguish true next frame from negatives
- Fine-tune on labeled speech for ASR
AudioLM (Audio Generation)
Generate arbitrary audio: speech, music, sound effects, silence.
- Tokenize audio with vector-quantized VAE
- Learn discrete vocabulary (4096 tokens)
- Train language model on audio token sequences
- Stable at 30-second generation lengths
Multimodal Training Strategies
- Scale imbalance: Vision datasets often 100×+ larger than high-quality aligned pairs
- Modality mismatch: Images and text consume very different token budgets (an encoded image ranges from a few dozen resampler/query tokens to several hundred patch tokens; text length varies widely)
- Training efficiency: Vision encoders often frozen; only language decoder unfrozen to save compute
- Temporal alignment: Video requires explicit temporal reasoning not needed for static images
Practical Implementation: Vision-Language Adapter
class VisionLanguageAdapter(nn.Module):
    def __init__(self, vision_dim=1024, llm_dim=4096):
        super().__init__()
        self.vision_encoder = ViT(output_dim=vision_dim)
        self.adapter = nn.Sequential(
            nn.Linear(vision_dim, llm_dim // 2),
            nn.GELU(),
            nn.Linear(llm_dim // 2, llm_dim)
        )
        self.llm = LlamaForCausalLM.from_pretrained("llama-7b")
        # Freeze vision encoder, train adapter + LoRA on LLM
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def forward(self, images, text_ids, text_mask):
        # 1. Encode images
        vision_features = self.vision_encoder(images)        # [B, N_patches, 1024]
        # 2. Project to LLM space
        vision_embeddings = self.adapter(vision_features)    # [B, N_patches, 4096]
        # 3. Embed text tokens
        text_embeddings = self.llm.embed_tokens(text_ids)
        # 4. Prepend vision tokens to the text sequence
        combined = torch.cat([vision_embeddings, text_embeddings], dim=1)
        # Vision tokens are always attended to, so extend the mask with ones
        vision_mask = torch.ones(
            vision_embeddings.shape[:2], dtype=text_mask.dtype, device=text_mask.device
        )
        combined_mask = torch.cat([vision_mask, text_mask], dim=1)
        # 5. Forward through the LLM on the combined embeddings
        output = self.llm(inputs_embeds=combined, attention_mask=combined_mask)
        return output.logits
16. RAG Integration: Retrieval-Augmented Generation
RAG augments language models with external knowledge by retrieving relevant documents and conditioning generation on them. This section covers retrieval pipelines, vector databases, and orchestration frameworks.
The RAG Pipeline
RAG Frameworks & Tools
| Framework | Vector DB Support | Strengths | Best For |
|---|---|---|---|
| LangChain | 30+ (Pinecone, Weaviate, Faiss, etc.) | Chains, agents, memory management | General-purpose applications |
| LlamaIndex | 20+ (optimized for LLM context) | Document indexing, structured context | Document QA, summarization |
| Haystack | Elasticsearch, Weaviate, Pinecone | Pipeline composition, dense retrieval | Production search systems |
| FastRAG (IBM) | Custom dense retrieval | Speed-optimized, multilingual | Low-latency systems |
| pdfplumber-based pipelines | In-memory vector store | PDF extraction accuracy | Document-heavy domains |
RAG Patterns
Simple RAG
Query → Embed → Retrieve Top-K → Fuse → Generate
# Basic pattern query_emb = embed(query) docs = db.search(query_emb, k=5) context = "\n".join([d.text for d in docs]) prompt = f"Context: {context}\nQ: {query}" answer = llm.generate(prompt)
Multi-Stage RAG
Retrieve → Rerank → Fuse Multiple Sources
- Retrieve candidate docs (100-500)
- Rerank with cross-encoder (sketch after this list)
- Fuse top 3-5 ranked docs
- Generate with multi-doc context
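A minimal sketch of the reranking step, using a public cross-encoder checkpoint from sentence-transformers; the first-stage retrieval is assumed to have already produced the candidates list.

from sentence_transformers import CrossEncoder

def rerank(query, candidates, top_n=5):
    # Cross-encoder scores each (query, passage) pair jointly
    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = reranker.predict([(query, doc) for doc in candidates])
    ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in ranked[:top_n]]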
Agentic RAG
Loop: Generate → Decide → Retrieve → Repeat
- LLM decides if more info needed
- Formulates new retrieval queries
- Iterates until answer is confident
- Handles follow-up questions naturally
Practical Implementation: Basic RAG
import langchain from langchain.vectorstores import Pinecone from langchain.embeddings import OpenAIEmbeddings from langchain.chains import RetrievalQA # 1. Initialize components embeddings = OpenAIEmbeddings(model="text-embedding-3-large") vectorstore = Pinecone.from_documents(docs, embeddings, index_name="my-index") retriever = vectorstore.as_retriever(search_kwargs={"k": 5}) # 2. Create QA chain qa = RetrievalQA.from_chain_type( llm=ChatOpenAI(model="gpt-4"), chain_type="stuff", # stuff, map_reduce, refine retriever=retriever ) # 3. Query response = qa.run("What are the key findings?")
- Reduces hallucinations by grounding in documents
- Scales to dynamic, large knowledge bases
- Enables private data without model training
- Improves factuality and citation
17. Data Pipeline: From Raw Text to Training
Preparing high-quality data is critical to LLM performance. This section covers text cleaning, filtering, deduplication, and synthetic data generation—often 30-50% of project effort.
Data Processing Stages
1. Acquisition & Ingestion
- Common sources: Web crawl (CommonCrawl), books (Project Gutenberg), code (GitHub), scientific papers (arXiv)
- Volume: 2T+ tokens typical for frontier models
- Quality varies: Web ~0.5-1% high-quality, academic 10-20%
2. Text Cleaning & Normalization
import re
import unicodedata
import ftfy

def clean_text(text):
    # Remove null bytes, control characters
    text = text.replace("\x00", "")
    # Fix encoding artifacts (mojibake)
    text = ftfy.fix_text(text)
    # Normalize unicode (NFD → NFC)
    text = unicodedata.normalize("NFC", text)
    # Remove excessive whitespace
    text = re.sub(r"\s+", " ", text).strip()
    return text
3. Filtering & Quality Scoring
def quality_score(document): score = 0 # Language detection: keep en only if detect(document) != "en": return 0 # Length heuristic if len(document.split()) < 50: return 0.1 score += 0.3 # Lexical diversity (unique unigrams / total) words = document.split() diversity = len(set(words)) / len(words) if diversity < 0.3: return 0 # Repetitive score += diversity * 0.3 # Perplexity filter (reference LM) ppl = compute_perplexity(document) if ppl > 10000: return 0 # Gibberish score += 0.4 return score
4. Deduplication
Exact dedup: Hash documents, remove identical copies.
Near-dedup: MinHash (LSH) to find similar documents—typical 10-20% reduction.
# MinHash for near-duplicate detection from datasketch import MinHash def dedup(documents): minhashes = {} for doc_id, text in documents: m = MinHash(num_perm=128) for token in text.split(): m.update(token.encode()) # Find similar existing documents duplicates = [d for d, mh in minhashes.items() if m.jaccard(mh) > 0.8] if not duplicates: minhashes[doc_id] = m yield doc_id, text
5. PII Removal & Privacy
- Detect: Regex, NER, and transformer models for SSNs, emails, phone numbers, addresses (regex-only sketch after this list)
- Redact: Hash, mask, remove entirely
- Tools: Presidio (Microsoft), NLP-based detectors
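A regex-only redaction sketch for the detect/redact steps above; patterns like these only catch well-formatted identifiers, so production pipelines layer NER-based tools such as Presidio on top.

import re

PII_PATTERNS = {
    "EMAIL": re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+"),
    "SSN":   re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),
    "PHONE": re.compile(r"\b(?:\+?1[-. ]?)?\(?\d{3}\)?[-. ]?\d{3}[-. ]?\d{4}\b"),
}

def redact_pii(text):
    # Replace each match with a typed placeholder
    for label, pattern in PII_PATTERNS.items():
        text = pattern.sub(f"[{label}]", text)
    return text

# redact_pii("Call 555-123-4567 or mail a@b.com") → "Call [PHONE] or mail [EMAIL]"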
6. Corpus Blending & Reweighting
Different data sources have different quality/relevance. Optimal blends are task-dependent.
| Source | Proportion (typical) | Quality | Rationale |
|---|---|---|---|
| Web | 80-90% | Medium | Scale and diversity |
| Code | 5-10% | High | Logic, structure |
| Books/Academic | 2-5% | Very High | Depth, coherence |
| Conversation | 1-2% | Medium | Dialogue skills |
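Blending is usually implemented as weighted sampling over per-source data streams. A minimal sketch, with weights mirroring the typical proportions above (illustrative, not any specific model's recipe):

import random

SOURCE_WEIGHTS = {"web": 0.85, "code": 0.08, "books_academic": 0.05, "conversation": 0.02}

def sample_source(weights=SOURCE_WEIGHTS):
    sources, probs = zip(*weights.items())
    return random.choices(sources, weights=probs, k=1)[0]

def blended_stream(readers, weights=SOURCE_WEIGHTS):
    # readers: dict mapping source name → iterator over documents
    while True:
        yield next(readers[sample_source(weights)])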
Synthetic Data Generation
Training repeatedly on LLM-generated synthetic data risks model collapse: each generation drifts further from the real data distribution. Mitigations: always mix in original human data, cap the synthetic proportion (e.g. ~10%), and use diverse generator models.
Reference: NVIDIA NeMo-Curator
# Large-scale data curation pipeline from nemo_curator import Curator curator = Curator( input_path="/data/raw", output_path="/data/processed", language="en", batch_size=10000, workers=128 ) # Multi-stage pipeline curator.run([ ("text_cleaning", {}), ("language_detection", {"lang": "en"}), ("quality_scoring", {"threshold": 0.5}), ("exact_dedup", {}), ("near_dedup", {"jaccard_threshold": 0.85}), ("pii_removal", {}) ])
18. Deployment & Serving: From Lab to Production
Deploying LLMs requires careful consideration of latency, throughput, cost, and reliability. This section covers infrastructure, optimization frameworks, and CI/CD practices.
Deployment Strategies
| Strategy | Setup | Latency | Cost | Use Case |
|---|---|---|---|---|
| Cloud Managed | Minutes (OpenAI API, Anthropic) | 100-500ms | $0.01-1/1K tokens | Prototyping, low volume |
| Self-Hosted VMs | Hours (vLLM on AWS/GCP) | 50-200ms | $0.10-1/hour GPU | Medium volume, custom models |
| Kubernetes | Days (Helm charts, Kserve) | 30-150ms | $0.05-0.5/hour (reserved) | Production, auto-scaling |
| On-Premise | Weeks (infra setup) | 10-100ms | CapEx + OpEx | Regulated, private data |
| Edge/Mobile | Days (ONNX export, optimization) | 200-2000ms | Device only | Offline, privacy-critical |
| Serverless | Minutes (AWS Lambda, Google Cloud Functions) | 1-5s (cold start) | $0.0000167/GB-second | Bursty, low SLA |
Serving Engines & Optimization Frameworks
| Engine | Latency Optimization | Throughput | Best For |
|---|---|---|---|
| vLLM | PagedAttention (KV cache) | Up to 24× vs naive HF Transformers serving | Open-source, high throughput |
| TensorRT-LLM | Kernel fusion, int8 weight, FP8 activation | 20-30× vs PyTorch | NVIDIA hardware |
| NVIDIA Triton | Multi-model, batching, ensemble | Multi-GPU support | Complex inference pipelines |
| Hugging Face TGI | Continuous batching, speculative decoding | 10-20× throughput | Model hub integration |
| Ollama | CPU inference, quantized models | Low latency on edge | Local/edge deployment |
Practical Example: vLLM Deployment
from vllm import LLM, SamplingParams # 1. Load model with paged attention llm = LLM( model="meta-llama/Llama-2-70b-chat-hf", tensor_parallel_size=2, # 2 GPUs gpu_memory_utilization=0.9, # 90% VRAM dtype="float16", max_model_len=4096 ) # 2. Batch inference prompts = ["What is AI?", "Explain transformers."] sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256) outputs = llm.generate(prompts, sampling_params) for output in outputs: print(output.outputs[0].text)
CI/CD & Model Registry
Model Registry
- Hugging Face Hub: 100K+ models, version control
- Model Zoo: Internal registry with access control
- Metadata: License, benchmark scores, known issues
- Versioning: Semantic (v1.0.1) + commit hash
Deployment Pipeline
- Model → Registry (HF Hub)
- Automated benchmarking (latency, accuracy)
- Canary deployment (5% traffic)
- Monitor metrics → Roll out (100%) or revert
- A/B testing for user-facing metrics
Monitoring & Observability
- Latency: p50, p95, p99 (ms) per batch size (computation sketched after this list)
- Throughput: tokens/second, queries/second
- Utilization: GPU%, memory%, network%
- Quality: Output length, BLEU, human rating
- Errors: OOM rate, timeout rate, crash rate
- Cost: $/1K tokens, $/user, ROI per experiment
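The latency metrics above are computed over a rolling window of per-request timings; a minimal sketch:

import numpy as np

def latency_percentiles(latencies_ms, window=1000):
    # Keep only the most recent `window` requests
    recent = np.asarray(latencies_ms[-window:])
    return {
        "p50": float(np.percentile(recent, 50)),
        "p95": float(np.percentile(recent, 95)),
        "p99": float(np.percentile(recent, 99)),
    }

# latency_percentiles([120, 95, 430, 88]) → {"p50": ..., "p95": ..., "p99": ...}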
19. Failure Modes: Understanding & Mitigating LLM Risks
LLMs are powerful but fallible. Understanding common failure modes and mitigations is critical for responsible deployment.
Hallucinations
What: Confident generation of false information
- Root cause: Model memorizes but doesn't understand; extrapolates plausibly but incorrectly
- Incidence: 5-15% of outputs (varies by task)
- Severity: High for factual tasks (dates, names); low for creative
Mitigations
- RAG: Ground in retrieved documents
- Fine-tuning: Emphasize factual accuracy in RLHF
- Prompt engineering: "Only use provided context" or "I don't know" option
- Ensemble: Multiple runs + consistency checking
- Uncertainty quantification: Return confidence scores
Bias & Toxicity
What: Systematic unfairness, harmful stereotypes, offensive content
- Root cause: Training data biases; tokenizer artifacts
- Incidence: 1-5% of outputs (gender, race, religion biases)
- Severity: High for hiring, lending, moderation
Mitigations
- Data curation: Filter biased training data
- RLHF: Reward fairness, penalize stereotypes
- Adversarial testing: Red team before release
- Post-hoc filtering: Block known harmful patterns
- Transparency: Document known biases, limitations
Cost Overruns
What: Unexpectedly high inference costs at scale
- Causes: Long context, high traffic, inefficient serving, debugging/logging
- Example: 1B requests/day at $0.001/request = $1M/day
Mitigations
- Query optimization: Preprocess, deduplicate, cache
- Model selection: Smaller models (7B vs 70B) for cost-sensitive tasks
- Routing: Use cheaper model for easy queries
- Batching: Dynamic batching to maximize GPU utilization
- Monitoring: Track cost per feature, flag anomalies
Model Collapse (Repeated RLHF)
What: Reward hacking → model behaves deceptively or loses capabilities
- Root cause: RLHF incentivizes gaming the reward, not genuine improvement
- Incidence: After 2-3 rounds of RLHF (months of deployment)
- Symptom: Model gives short, confident answers even when unsure
Mitigations
- Reward model robustness: Ensemble of raters, adversarial probes
- KL regularization: Penalize divergence from base model
- Periodic retraining: Every 3-6 months on fresh data
- Gradient clipping: Prevent extreme policy updates
Data Leakage & Privacy
What: Training data extracted via prompting (e.g., membership inference and extraction attacks)
- Incidence: <5% of training data can be extracted from frontier models
- Severity: High for sensitive data (medical, financial)
Mitigations
- Differential privacy: Formal privacy guarantees during training
- Data deduplication: Ensure no PII in training set
- Access controls: Fine-grained API monitoring
- Model sharding: Smaller models with subset of data
Latency Spikes & SLA Violations
What: p99 latency exceeds SLA, causing cascading failures
- Causes: GPU contention, thermal throttling, memory fragmentation, network jitter
- Impact: Timeouts, cascading retries, crash
Mitigations
- Resource isolation: CPU/GPU affinity, memory reservations
- Load shedding: Reject requests if queue > threshold
- Speculative execution: Compute common follow-ups proactively
- Hardware redundancy: Multiple GPUs, multi-region
- Request prioritization: Low-latency queues for short contexts
Summary: Failure Mode Matrix
| Failure Mode | Severity | Detectability | Prevention Cost |
|---|---|---|---|
| Hallucination | High | Medium (requires fact-check) | High (RAG, RLHF) |
| Bias/Toxicity | High | High (automated filtering) | Medium (red teaming) |
| Cost overruns | High | High (automatic alerts) | Low (monitoring) |
| Model collapse | Medium | Low (slow drift) | High (RLHF careful tuning) |
| Data leakage | High | Low (passive attacks) | Very High (differential privacy) |
| Latency spikes | Medium | High (alerts) | Medium (infra investment) |
20. Implementation Roadmap: From Idea to Production
Building an LLM system is complex. This roadmap breaks it into five phases with realistic timelines and checkpoints.
Phase 1: Requirements & Evaluation (Weeks 1-2)
Objectives
- Define success metrics (latency, accuracy, cost)
- Identify data sources and privacy requirements
- Map model size, inference cost, training time
Deliverables
- Product requirements document (PRD)
- Benchmarking results (3-5 baseline models)
- Cost-benefit analysis (buy vs build)
Key Questions
- Can off-the-shelf API (OpenAI, Anthropic) meet requirements?
- If not, what's the cost/benefit of fine-tuning vs pretraining?
- What's the privacy/data sensitivity?
Phase 2: Prototyping (Weeks 3-6)
Objectives
- Implement MVP with off-the-shelf components
- Evaluate prompt engineering vs fine-tuning
- Set up evaluation framework
Deliverables
- Working prototype (prompt + API wrapper)
- Evaluation dataset (100-1000 examples)
- Baseline metrics (BLEU, ROUGE, human eval on 50 examples)
Tech Stack
# Typical prototype stack
- Model: OpenAI GPT-4 or open-source (Llama 2)
- Framework: LangChain or Anthropic SDK
- Vector DB: Pinecone or Weaviate (for RAG)
- Eval: DeepEval, RAGAS, custom metrics
- Logging: Weights & Biases or Langfuse
Phase 3: Optimization & Tuning (Weeks 7-12)
Objectives
- Fine-tune or customize model for domain
- Optimize inference (batching, quantization, caching)
- Improve evaluation to 500+ examples with human ratings
Deliverables
- Fine-tuned checkpoint (if applicable)
- Serving infrastructure (vLLM, Triton, or managed API)
- Evaluation results: metrics vs cost trade-offs
Typical Improvements
- +5-15% accuracy from fine-tuning
- 10-50× throughput improvement from optimization
- 50-80% cost reduction via quantization + batching
Phase 4: Safety & Alignment (Weeks 10-14, parallel)
Objectives
- Red team: identify failure modes (hallucination, bias, toxicity)
- RLHF (optional): fine-tune for safety & alignment
- Legal & compliance review
Deliverables
- Adversarial test results (1000+ adversarial prompts)
- Safety metrics: hallucination rate, toxicity, factuality
- Privacy assessment & data handling plan
- Model card & usage guidelines
Reference
HELM (Holistic Evaluation of Language Models): Comprehensive benchmark across 16 scenarios, 1000s of test cases.
Phase 5: Production Deployment (Weeks 13-16)
Objectives
- Deploy to production infrastructure (K8s, cloud, or on-prem)
- Set up monitoring, logging, alerting
- Implement feedback loop for continuous improvement
Deliverables
- Production API with SLA (latency, availability)
- Monitoring dashboard (latency, cost, quality metrics)
- Runbook for escalation and troubleshooting
- Feedback collection + weekly metric reviews
Deployment Checklist
- Load balancing & auto-scaling configured
- Rate limiting & quota management
- Canary deployment (5% traffic → 25% → 100%)
- Rollback plan & version management
- Security: API keys, data encryption, audit logs
- Cost monitoring & alerts ($X/day threshold)
- On-call runbook for incidents
Timeline & Dependencies
Resource & Skill Requirements
| Phase | Role | Effort | Key Skills |
|---|---|---|---|
| 1-2 | PM, ML Engineer (1-2) | 2-4 weeks | LLM APIs, prompting, evaluation |
| 3 | ML Engineer (2-3), Data Engineer | 4-6 weeks | PyTorch, serving frameworks, infra |
| 4 | Safety specialist, lawyers | 2-4 weeks (parallel) | Red teaming, bias testing, compliance |
| 5 | DevOps, ML engineer, on-call | 2-4 weeks | K8s, monitoring, incident response |
Appendix D: Glossary of Terms
Key Terminology
| Term | Definition | Context |
|---|---|---|
| Autoregressive | Generating output one token at a time, conditioning on all previous tokens | Inference, training objective |
| BF16 (bfloat16) | 16-bit float with 8-bit exponent (FP32 range) but 7-bit mantissa (reduced precision) | Mixed precision training |
| Causal mask | Upper-triangular mask preventing tokens from attending to future positions | Decoder-only models |
| d_model (hidden dimension) | Width of the residual stream; typically 4096–18432 | Model architecture |
| Embedding | Dense vector representation of a discrete token in semantic space | Input layer |
| FLOP | Floating-point operation; FLOPs = total compute, FLOP/s = throughput | Compute efficiency |
| Gradient checkpointing | Recompute activations during backward instead of storing them (memory-compute trade-off) | Training optimization |
| Hallucination | Confident generation of false information (not grounded in training data) | Safety, evaluation |
| In-context learning | Model learning from examples in the prompt without weight updates | Few-shot prompting |
| KV cache | Cached key and value tensors from previous tokens to avoid recomputation in autoregressive generation | Inference optimization |
| Logits | Raw unnormalized scores from the final linear layer before softmax | Output, sampling |
| MFU (Model FLOPs Utilization) | Fraction of peak hardware FLOPs actually achieved; typically 30–60% for training | Training efficiency |
| Perplexity | exp(average negative log-likelihood); lower is better; the geometric-mean inverse per-token probability | Evaluation metric |
| Residual connection | Skip connection that adds input to output: output = f(norm(x)) + x | Architecture design |
| RoPE (Rotary Position Embedding) | Modern positional encoding using complex rotations; naturally extends to longer sequences | Positional encoding |
| Softmax | softmax(z_i) = exp(z_i) / Σ exp(z_j); converts logits to probability distribution | Probability normalization |
| Token | Atomic unit of text (subword, word, or character) produced by tokenizer | Input processing |
| Transformer | Neural network based on self-attention (Vaswani et al., 2017) | Architecture |
| Weight tying | Sharing W_embed and W_unembed matrices; saves 2.1B params for LLaMA-3 405B | Parameter efficiency |
| Softmax temperature | Scaling factor for logits before softmax; temp > 1 increases diversity, temp < 1 increases determinism | Sampling control |
| Nucleus sampling (Top-P) | Dynamic filtering to keep cumulative probability mass ≤ p; usually p ∈ [0.9, 0.95] | Decoding strategy |
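The last two entries, temperature and nucleus sampling, combine into a single decoding step. A minimal sketch over one logits vector:

import torch

def sample_top_p(logits, temperature=0.8, top_p=0.9):
    # Temperature scaling, then nucleus (top-p) filtering, then sampling
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest set of tokens whose cumulative probability ≥ top_p
    cutoff = int(torch.searchsorted(cumulative, torch.tensor([top_p])).item()) + 1
    kept_probs = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
    return sorted_idx[torch.multinomial(kept_probs, num_samples=1)].item()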
Appendix E: Memory & Compute Requirements
Training Memory Breakdown
# Memory consumption breakdown for training a 405B model
def estimate_training_memory(n_params=405e9, batch_size=4, seq_len=4096):
    # 1. Model weights (FP32): 4 bytes per parameter
    weight_memory = n_params * 4 / 1e9            # GB
    # 2. Optimizer state (AdamW: first and second moments in FP32)
    optimizer_memory = n_params * 8 / 1e9          # 2× FP32 per param
    # 3. Activations kept for the backward pass
    # Roughly: batch · seq_len · d_model · num_layers · 4 bytes
    activation_memory = (batch_size * seq_len * 16384 * 126 * 4) / 1e9
    # 4. Gradients (same shape as weights)
    gradient_memory = n_params * 4 / 1e9
    total = weight_memory + optimizer_memory + activation_memory + gradient_memory
    return {
        "weights": weight_memory,
        "optimizer": optimizer_memory,
        "activations": activation_memory,
        "gradients": gradient_memory,
        "total_gb": total,
    }

# Estimate (this configuration):
#   Weights:     ~1.6TB
#   Optimizer:   ~3.2TB
#   Activations: ~0.14TB (far more without gradient checkpointing)
#   Gradients:   ~1.6TB
#   Total:       ~6.6TB — far beyond any single GPU
# Solution: ZeRO-2/ZeRO-3 (or FSDP) shards weights, gradients, and optimizer
# state across the cluster so each GPU holds only a slice
Inference Memory & Latency
def estimate_inference_memory( n_params=405e9, batch_size=1, seq_len=4096, precision="FP32", use_kv_cache=True ): # Model weights bytes_per_param = {"FP32": 4, "FP16": 2, "INT8": 1, "INT4": 0.5} weight_bytes = n_params * bytes_per_param[precision] # KV cache (seq_len × batch × h × d_k) # Assume h=128, d_k=128 for 405B kv_cache_bytes = 0 if use_kv_cache: kv_cache_bytes = seq_len * batch_size * 128 * 128 * 2 * 126 * 2 # 2 for K and V # Total total_gb = (weight_bytes + kv_cache_bytes) / 1e9 return total_gb # Example: 405B with INT4 quantization # Weights: 405B × 0.5 = 202.5GB (fits on 8× 40GB A100s) # KV cache: ~40GB (for seq_len=4096)
FLOPs Calculation
def calculate_training_flops(n_params, n_tokens):
    # Forward pass: ~2·N·T FLOPs (one multiply-add per parameter per token)
    forward_flops = 2 * n_params * n_tokens
    # Backward pass: ~2× the forward cost
    backward_flops = 4 * n_params * n_tokens
    # Total ≈ 6·N·T
    total_flops = forward_flops + backward_flops
    return {
        "forward": forward_flops,
        "backward": backward_flops,
        "total": total_flops,
    }

# Example: LLaMA 3 405B on 15T tokens
#   Total: 6 × 405e9 × 15e12 ≈ 3.6×10^25 FLOPs
#   On 16K H100s @ 1.5 PFLOP/s peak each ≈ 24 EFLOP/s aggregate
#   Time at peak: 3.6×10^25 / (2.4×10^19) ≈ 1.5×10^6 s ≈ 17 days
#   (real runs achieve ~40-50% MFU, so wall-clock time is 1-2 months)
Latency Analysis
| Component | FP32 | FP16/BF16 | INT8 | INT4 |
|---|---|---|---|---|
| Attention computation | 100% | 50% | 25% | 12% |
| FFN computation | 100% | 50% | 25% | 12% |
| Memory bandwidth (HBM) | 1× | 2× | 4× | 8× |
| Overall latency | 1.0 ms/token | 0.6 ms/token | 0.35 ms/token | 0.25 ms/token |
Appendix F: Model Compression Techniques
Knowledge Distillation
def knowledge_distillation_loss( student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7 ): # Distillation Loss = α·CE(student, labels) + (1-α)·KL(soft_targets) # Hard target loss (standard cross-entropy) ce_loss = torch.nn.functional.cross_entropy(student_logits, labels) # Soft target loss (KL divergence with temperature) # Temperature > 1 softens the distributions student_probs = torch.log_softmax(student_logits / temperature, dim=-1) teacher_probs = torch.softmax(teacher_logits / temperature, dim=-1) kl_loss = torch.nn.functional.kl_div( student_probs, teacher_probs, reduction='batchmean' ) # Combined loss total_loss = alpha * ce_loss + (1 - alpha) * (temperature ** 2) * kl_loss return total_loss # Benefits: # - Smaller student model can match 95%+ of teacher performance # - Temperature parameter controls knowledge transfer # - Effective for any teacher-student size ratio
Layer Pruning
def prune_layers(model, prune_ratio=0.2): # Remove least important layers (compute layer importance) n_layers = len(model.layers) n_to_prune = int(n_layers * prune_ratio) # Compute layer importance via gradient norm or activation magnitude layer_importance = [] for i, layer in enumerate(model.layers): # Importance ≈ ||weight gradient|| × ||activation|| importance = torch.norm(layer.weight.grad) * torch.norm(layer.activation) layer_importance.append((i, importance)) # Sort by importance and remove least important layer_importance.sort(key=lambda x: x[1]) layers_to_remove = [x[0] for x in layer_importance[:n_to_prune]] # Create pruned model pruned_model = copy.deepcopy(model) pruned_model.layers = nn.ModuleList([ layer for i, layer in enumerate(model.layers) if i not in layers_to_remove ]) return pruned_model # Results: # - Removing 20% of layers: ~1-2% accuracy loss # - 15-20% speedup from fewer layer computations
Weight Pruning & Sparsity
| Pruning Method | Sparsity % | Accuracy Loss | Speedup | Implementation |
|---|---|---|---|---|
| Magnitude pruning | 50-90% | 2-5% | Varies (HW dependent) | Easy (mask weights) |
| Structured pruning | 20-50% | 1-3% | 2-5× | Easier HW support |
| Gradient-based pruning | 50-80% | 1-2% | Varies | Requires gradients |
| Lottery ticket hypothesis | 90%+ | Near 0% | Depends on structure | Very expensive to find |
Appendix G: Fine-tuning Recipes & Best Practices
SFT (Supervised Fine-Tuning) Recipe
# High-quality SFT training configuration """ Base model: Mistral 7B Data: 100K high-quality instruction-response pairs Optimization: - Learning rate: 5e-5 - Warmup steps: 500 - Max gradient norm: 1.0 - Optimizer: AdamW (β1=0.9, β2=0.999) - Weight decay: 0.01 - Batch size: 64 (per-device) - Gradient accumulation: 1 - Num epochs: 3 Hardware: - 8x A100 80GB - Total training time: ~6 hours Results: - MMLU: 63.2% → 72.8% (+9.6pp) - MT-Bench: 6.2 → 7.8 (+1.6pp) """ def train_sft(model, train_dataset, val_dataset): # Prepare data loaders train_loader = DataLoader( train_dataset, batch_size=64, shuffle=True ) val_loader = DataLoader( val_dataset, batch_size=64, shuffle=False ) # Setup optimizer optimizer = torch.optim.AdamW( model.parameters(), lr=5e-5, weight_decay=0.01 ) # Learning rate scheduler scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=len(train_loader) * 3 ) for epoch in range(3): for batch in train_loader: logits = model(batch["input_ids"]) loss = compute_loss(logits, batch["labels"]) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad() # Validation val_loss = evaluate(model, val_loader) print(f"Epoch {epoch} val_loss: {val_loss:.3f}")
RLHF Pipeline (Simplified)
def rlhf_pipeline(base_model, sft_model, dataset): # Step 1: SFT → already have sft_model # Step 2: Train reward model reward_model = train_reward_model( base_model, dataset.comparisons ) # Step 3: RL training with PPO rl_model = copy.deepcopy(sft_model) optimizer = torch.optim.AdamW(rl_model.parameters(), lr=1e-5) for epoch in range(5): for prompt in dataset.prompts: # Generate completions completions = rl_model.generate( prompt, num_return_sequences=4, temperature=1.0 ) # Get rewards rewards = reward_model(completions) # Compute advantages (REINFORCE or A2C) advantages = rewards - torch.mean(rewards) # Policy gradient update logprobs = rl_model.compute_logprob(completions) loss = -torch.mean(logprobs * advantages) loss.backward() torch.nn.utils.clip_grad_norm_(rl_model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() return rl_model # RLHF complexity: # - Requires reward model training (expensive) # - RL training is unstable (needs careful tuning) # - But produces best human alignment results
Common Hyperparameter Ranges
| Parameter | Small Models (7B) | Medium (13-34B) | Large (70B+) | Notes |
|---|---|---|---|---|
| Learning rate | 1e-4 to 5e-5 | 5e-5 to 1e-5 | 1e-5 to 5e-6 | Smaller for bigger models |
| Batch size | 64-128 | 32-64 | 16-32 | Larger for compute efficiency |
| Warmup steps | 500-1000 | 1000-2000 | 2000-5000 | Prevents instability |
| Weight decay | 0.01 | 0.01 | 0.01 | Typical L2 regularization |
| Gradient clip | 1.0 | 1.0 | 1.0 | Prevents explosion |
| Epochs | 2-5 | 2-3 | 1-2 | Fewer for large data |
Appendix H: Prompt Engineering & In-Context Learning
Few-Shot Prompting Strategies
# Zero-shot prompt question = """Q: Solve this math problem step by step. Problem: If a train travels 100 km in 2 hours, what is its speed? A:""" # Few-shot prompt (2 examples) question_fewshot = """ Example 1: Q: A car travels 60 km in 1 hour. What is its speed? A: The car travels 60 km in 1 hour. Speed = distance/time = 60/1 = 60 km/h. Example 2: Q: A boat travels 150 km in 3 hours. What is its speed? A: The boat travels 150 km in 3 hours. Speed = distance/time = 150/3 = 50 km/h. Now solve this: Q: A train travels 100 km in 2 hours. What is its speed? A:""" # Chain-of-Thought (explicit reasoning) cot_prompt = """Let's solve this step by step. Q: If a train travels 100 km in 2 hours, what is its speed? A: Let me break this down: - We know distance = 100 km - We know time = 2 hours - Speed = distance / time - Speed = 100 / 2 = 50 km/h Therefore, the train's speed is 50 km/h."""
Prompt Optimization Techniques
| Technique | Description | When to Use | Typical Improvement |
|---|---|---|---|
| Role-playing | "You are an expert mathematician..." | Tasks requiring expertise | +5-15% |
| Few-shot examples | Include 2-5 worked examples | Complex reasoning tasks | +10-30% |
| Chain-of-Thought | "Let me think step by step" | Math, logic, reasoning | +10-40% |
| Self-consistency | Generate multiple CoT paths, vote | Hard reasoning problems | +5-10% |
| Structured output | JSON, XML, or specific format | Parsing or structured tasks | Reliability +20-30% |
| Explicit constraints | "Do NOT mention...", "Must include..." | Controlled generation | Constraint satisfaction +15-25% |
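Self-consistency from the table above is straightforward to implement: sample several reasoning paths and majority-vote the parsed answers. A minimal sketch, where generate and extract_answer are placeholders for a sampling call and an answer parser (e.g. grabbing the text after "the answer is"):

from collections import Counter

def self_consistency(generate, extract_answer, prompt, n_samples=8, temperature=0.7):
    answers = []
    for _ in range(n_samples):
        completion = generate(prompt, temperature=temperature)
        answers.append(extract_answer(completion))
    winner, votes = Counter(answers).most_common(1)[0]
    return winner, votes / n_samples    # answer plus agreement ratio as confidence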
In-Context Learning Example
# Demonstrate that models learn from context without weight updates prompt = """ You will now learn a pattern from examples. Apply it to the new input. Example 1: "hello" → "olleh" Example 2: "world" → "dlrow" Example 3: "test" → "tset" Now apply the pattern to: "machine" Output:""" # The model has learned the pattern from examples: # Pattern: reverse the string # Output: "enihcam" # This is in-context learning: # - No weight updates # - Learning from prompt examples alone # - Emergent ability in large models (>1B params) # - Accuracy improves with more examples (scaling in prompts)
Appendix I: Benchmarking & Evaluation Methodology
Proper Evaluation Setup
def evaluate_model_proper(model, benchmark_dataset): # 1. Ensure test set is separate (no data leakage) assert benchmark_dataset.split == "test" # 2. Set to evaluation mode (disable dropout, etc.) model.eval() # 3. Disable gradients (faster inference) with torch.no_grad(): predictions = [] for batch in benchmark_dataset: logits = model(batch["input_ids"]) pred = torch.argmax(logits, dim=-1) predictions.append(pred) # 4. Compute metrics accuracy = (predictions == benchmark_dataset.labels).mean() f1 = compute_f1(predictions, benchmark_dataset.labels) # 5. Report with confidence intervals return { "accuracy": accuracy, "f1": f1, "n_samples": len(benchmark_dataset) } # Common mistakes to avoid: # 1. Using training data in evaluation (leakage) # 2. Tuning on test set (overfitting) # 3. Not reporting uncertainty/variance # 4. Cherry-picking metrics (report multiple) # 5. Not controlling for randomness (set seed, multiple runs)
Metric Definitions
| Metric | Formula | Interpretation | When to Use |
|---|---|---|---|
| Accuracy | (TP+TN) / (TP+TN+FP+FN) | % correct predictions | Balanced classes, simple classification |
| Precision | TP / (TP+FP) | % predicted positives that are correct | Minimize false positives (spam detection) |
| Recall | TP / (TP+FN) | % actual positives that are predicted | Minimize false negatives (disease detection) |
| F1 Score | 2·(Precision·Recall)/(Precision+Recall) | Harmonic mean of precision/recall | Imbalanced classes, balance false errors |
| ROUGE-L | LCS-based overlap (0-1) | Summary quality | Summarization, translation evaluation |
| BLEU | Precision of n-grams | Translation quality | Machine translation (outdated for LLMs) |
| Perplexity | exp(-1/N Σ log P(word_i)) | Prediction likelihood (lower=better) | Language modeling evaluation |
Computing Perplexity
import math
import torch
from torch.utils.data import DataLoader

def compute_perplexity(model, dataset, batch_size=32):
    # Perplexity = exp(average cross-entropy loss per token)
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=batch_size):
            logits = model(batch["input_ids"])
            # Shift: predict token i+1 from token i
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = batch["input_ids"][:, 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            total_loss += loss.item() * shift_labels.numel()
            total_tokens += shift_labels.numel()
    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)

# Interpretation:
# - Perplexity = 10: on average, the model is as uncertain as picking uniformly
#   among ~10 equally likely tokens at each step
# - GPT-2 (small): ~20-30 on test sets
# - GPT-3: ~8-15 on test sets
# - State-of-the-art: single digits on the same corpora
Appendix J: Model Deployment & Serving
Efficient Model Serving
# Model serving with batching and KV cache management import asyncio class ModelServer: def __init__(self, model, batch_size=32, max_wait_ms=10): self.model = model self.batch_size = batch_size self.max_wait_ms = max_wait_ms self.request_queue = asyncio.Queue() async def add_request(self, prompt, max_tokens=256): # Add request to queue future = asyncio.Future() await self.request_queue.put({ "prompt": prompt, "max_tokens": max_tokens, "future": future }) return await future async def batch_processing_loop(self): # Process batches of requests while True: # Wait for batch or timeout batch = [] start_time = time.time() try: while len(batch) < self.batch_size: timeout = self.max_wait_ms / 1000.0 request = await asyncio.wait_for( self.request_queue.get(), timeout=timeout ) batch.append(request) except asyncio.TimeoutError: pass if not batch: continue # Process batch in parallel prompts = [r["prompt"] for r in batch] outputs = await asyncio.gather(*[ self.model.generate(p, 256) for p in prompts ]) # Return results to requesters for request, output in zip(batch, outputs): request["future"].set_result(output) # Batching benefits: # - GPUs are 10-50× faster with batching # - Trade-off: latency vs throughput # - Dynamic batching: request-level optimization
Deployment Hardware Considerations
| Hardware | Memory | FP32 Model Size | INT4 Model Size | Inference Speed (tok/s) |
|---|---|---|---|---|
| A100 40GB | 40GB | ~10B params | ~80B params | 100-300 |
| A100 80GB | 80GB | ~20B params | ~160B params | 100-300 |
| H100 80GB | 80GB | ~20B params | ~160B params | 300-1000 |
| A10 24GB | 24GB | ~6B params | ~50B params | 30-100 |
| CPU (Xeon) | System RAM (100s of GB) | Any size (slow) | Any size (slow) | 1-10 |
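The per-device model-size columns follow directly from bytes per parameter; a rough sizing helper for weights only (the 20% overhead allowance for KV cache, activations, and framework state is a ballpark assumption, not taken from the table):

def model_memory_gb(params_billion, bits_per_param=16, overhead=1.2):
    # Weights-only footprint in GB, plus a multiplicative allowance
    # for KV cache, activations, and framework overhead.
    bytes_per_param = bits_per_param / 8
    return params_billion * bytes_per_param * overhead

# Examples (overhead=1.0, weights only):
# model_memory_gb(7,  16, overhead=1.0)  → 14 GB   (fits an A10 24GB)
# model_memory_gb(70, 16, overhead=1.0)  → 140 GB  (needs 2× A100 80GB)
# model_memory_gb(70, 4,  overhead=1.0)  → 35 GB   (fits a single A100 40GB)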
Appendix K: Advanced Architecture Variants
Vision Transformer (ViT) Integration
import torch
import torch.nn as nn

class MultimodalLLM(nn.Module):
    def __init__(self, vision_encoder, language_model):
        super().__init__()
        self.vision_encoder = vision_encoder   # ViT
        self.language_model = language_model   # LLM
        # Projection from vision space (e.g. 768-d ViT) to language space (e.g. 4096-d LLM)
        self.vision_to_lang = nn.Linear(768, 4096)

    def forward(self, images, text_ids):
        # Process images with the vision encoder
        # images: (batch, 3, 224, 224)
        image_embeddings = self.vision_encoder(images)
        # → (batch, num_patches, 768)

        # Project into the language model's embedding space
        image_features = self.vision_to_lang(image_embeddings)
        # → (batch, num_patches, 4096)

        # Embed the text tokens
        text_embeddings = self.language_model.embedding(text_ids)

        # Option 1 (prefix): prepend image tokens before the text tokens
        combined = torch.cat([image_features, text_embeddings], dim=1)

        # Run the combined sequence through the transformer
        output = self.language_model.transformer(combined)
        logits = self.language_model.head(output)
        return logits

# Multimodal integration strategies:
# 1. Prefix: image tokens before text (simplest)
# 2. Interleaved: alternate image and text tokens (most flexible)
# 3. Cross-attention: separate vision stream attended to by the text stream
# 4. Adapter: small projection network between frozen encoders
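A shape walk-through of one forward pass, assuming a ViT-B-style encoder on 224-pixel images (196 patches of 768 dimensions; exact counts depend on the encoder and whether a [CLS] token is kept):

# images:            (2, 3, 224, 224)
# vision_encoder  →  (2, 196, 768)      # 14×14 = 196 patches for ViT-B/16
# vision_to_lang  →  (2, 196, 4096)
# text_ids:          (2, 12)
# embedding       →  (2, 12, 4096)
# concat (prefix) →  (2, 208, 4096)     # 196 image tokens + 12 text tokens
# transformer     →  (2, 208, 4096)
# lm head         →  (2, 208, vocab_size)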
Mamba SSM Architecture
import torch
import torch.nn as nn

class MambaBlock(nn.Module):
    # Simplified state-space model (SSM): a linear-time alternative to attention
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # Per-channel state transition (learned); B and C are data-dependent
        self.A = nn.Parameter(torch.randn(d_model, d_state))
        self.B = nn.Linear(d_model, d_state)
        self.C = nn.Linear(d_model, d_state)
        # Skip connection and output projection
        self.D_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        batch, seq_len, _ = x.shape
        # Data-dependent input/output matrices (the "selective" part of Mamba)
        B = self.B(x)   # (batch, seq, d_state)
        C = self.C(x)   # (batch, seq, d_state)
        # Squash A into (0, 1) so this simplified recurrence stays stable;
        # real Mamba uses a more careful discretization of a continuous-time A
        A = torch.sigmoid(self.A)   # (d_model, d_state)

        # Recurrence per channel:  h_t = A ⊙ h_{t-1} + B_t ⊗ x_t,   y_t = ⟨C_t, h_t⟩
        h = torch.zeros(batch, self.d_model, self.d_state, device=x.device)
        outputs = []
        for t in range(seq_len):
            # State update: (batch, d_model, d_state)
            h = h * A.unsqueeze(0) + x[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1)
            # Readout: contract over the state dimension → (batch, d_model)
            y = (h * C[:, t].unsqueeze(1)).sum(dim=-1)
            outputs.append(y)

        output = torch.stack(outputs, dim=1)   # (batch, seq, d_model)
        output = output + self.D_proj(x)       # skip connection
        return self.out_proj(output)

# Mamba advantages:
# - O(n) time and constant state per step vs O(n²) for full attention
# - Linear scaling with sequence length
# - Attractive for very long sequences (>1M tokens)
# - Trade-off: quality can lag transformers on some tasks
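A quick smoke test of the block above with random weights, mainly to confirm shapes and illustrate the linear-time claim (the sizes are arbitrary):

import torch

block = MambaBlock(d_model=64, d_state=16)
x = torch.randn(2, 128, 64)     # (batch, seq_len, d_model)
y = block(x)
print(y.shape)                  # torch.Size([2, 128, 64])
# Doubling seq_len (128 → 256) roughly doubles the work,
# versus roughly 4× for full self-attention.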
Appendix L: Active Research Directions
Research Areas & Recent Work
| Area | Key Papers/Models | Goal | Current State |
|---|---|---|---|
| Long Context | ALiBi, LongLoRA, Ring Attention | Extend to 1M+ tokens | Promising (128K-1M practical) |
| Efficient Training | LIMA, LoRA, Adapter modules | Fine-tune with a fraction of the data and parameters | Mature (routinely used) |
| Reasoning | o1, DeepSeek-R1, Qwen-QwQ | Better at math/logic | Rapidly improving (RL on reasoning traces scales well) |
| Multimodal | GPT-4V, Claude 3, Gemini | Understanding images/video | Improving rapidly (video next) |
| Interpretability | Transformer circuits, SAEs | Understand model internals | Early stage (limited scale) |
| Scaling Laws | Chinchilla, Compute Optimal | Predict performance with compute | Mature (guides training) |
Training Objective Innovations
# Next-token prediction + auxiliary losses
import torch

def compute_hybrid_loss(logits, labels, hidden_states,
                        boundary_head, boundary_labels):
    # Main objective: cross-entropy on the next token
    ce_loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
    )

    # Auxiliary 1: contrastive loss on hidden representations,
    # encouraging similar tokens to have similar representations
    sim_matrix = torch.nn.functional.cosine_similarity(
        hidden_states.unsqueeze(1),
        hidden_states.unsqueeze(0),
        dim=-1,
    )
    contrastive_loss = compute_nt_xent_loss(sim_matrix, labels)  # NT-Xent helper assumed

    # Auxiliary 2: extra prediction head, e.g. sentence/document boundaries
    boundary_pred = boundary_head(hidden_states)
    boundary_loss = torch.nn.functional.binary_cross_entropy(
        boundary_pred, boundary_labels
    )

    # Weighted combination of the objectives
    total_loss = ce_loss + 0.1 * contrastive_loss + 0.05 * boundary_loss
    return total_loss

# Benefits of auxiliary losses:
# - Faster convergence (on the order of ~10% fewer steps)
# - Better generalization on downstream tasks
# - Help with long-range dependencies
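The boundary_head and boundary_labels above are assumed rather than defined; a minimal version of such an auxiliary head is just a linear probe on the hidden states (a sketch under those assumptions, not the method of any particular paper):

import torch
import torch.nn as nn

class BoundaryHead(nn.Module):
    # Predicts, per token, whether it ends a sentence/document segment.
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq, d_model) → (batch, seq) probabilities
        return torch.sigmoid(self.proj(hidden_states)).squeeze(-1)

# boundary_labels would be a (batch, seq) tensor of 0/1 targets derived
# from the training corpus (e.g. sentence-final tokens marked as 1).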