LLM Model Distillation — Techniques, Training & Deployment
A comprehensive guide to knowledge distillation for large language models — from embedding compression through LoRA/QLoRA fine-tuning to quantization-aware training. Covers theory, training recipes, evaluation strategies, deployment patterns, and production optimization.
What is Model Distillation?
Teacher-Student Paradigm for Knowledge Transfer
Why Distillation Matters in Production
- ✓ 10-50x latency reduction in production
- ✓ 70-90% cost savings on inference
- ✓ Deploy to edge, mobile, low-latency APIs
- ✓ Reduce hallucination via better grounding
- ✓ Enable real-time retrieval augmentation
Key Concepts
- • Teacher: Large, accurate model
- • Student: Smaller, faster model
- • Knowledge: Outputs, features, attention
- • Loss: KL divergence + task loss
- • Temperature: Controls softness of targets
Distillation transfers the teacher's knowledge (logits, intermediate features, attention patterns) into a student model, preserving 90-98% of quality while achieving 10-50x speedup. Perfect for production systems where you need fast, cost-efficient components.
Distillation Fundamentals
Core Techniques & Loss Functions
Logit Distillation
Transfer teacher's output probability distribution via KL divergence. Student learns soft targets from teacher at temperature T.
L_KL = T² × KL(teacher_p, student_p)
Feature Distillation
Match intermediate layer representations. Student learns hidden states from teacher's layers via MSE loss.
L_feat = MSE(student_h, teacher_h)
Attention Transfer
Align attention weights between teacher and student. Guides student to focus on same tokens.
L_attn = MSE(student_A, teacher_A)
Contrastive Distillation
Use contrastive learning to align teacher-student embeddings. Useful for embedding model distillation.
L_cont = -log(exp(sim) / Σexp(sims))
Temperature Scaling
Temperature (T) controls softness of teacher's output distribution. Higher T → softer targets → more information about wrong classes. At inference, use T=1 (normal softmax). During distillation training, use T=3-20.
soft_probs = softmax(logits / T)
Combined Loss Function
# Combined distillation loss (PyTorch)
import torch.nn.functional as F

def distillation_loss(logits_student, logits_teacher, labels, T=4, alpha=0.7):
    # Soft targets from teacher
    teacher_soft = F.softmax(logits_teacher / T, dim=-1)
    student_soft = F.log_softmax(logits_student / T, dim=-1)
    # KL divergence (logit distillation)
    kl_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
    # Cross-entropy on hard targets
    ce_loss = F.cross_entropy(logits_student, labels)
    # Combined loss; the T² factor keeps soft-target gradients comparable in scale
    loss = alpha * (T ** 2) * kl_loss + (1 - alpha) * ce_loss
    return loss
Temperature (T): 3-20 typical. Higher = softer knowledge transfer. Alpha (α): 0.5-0.9 typical. Higher = more weight on distillation vs task loss. Learning rate: 2-5x lower than standard fine-tuning.
Distillation Landscape
Where & When to Apply Distillation in ML Pipelines
| Component | What to Distill | Expected Speedup | Quality Loss | Cost Savings |
|---|---|---|---|---|
| Embedding | Logits, intermediate layers | 10-20x | 2-8% | 70-85% |
| Reranker | Scores, attention | 15-30x | 1-5% | 80-90% |
| Generator | Logits, hidden states | 5-15x | 3-10% | 70-95% |
| Query Transform | Logits, hidden states | 20-50x | 2-6% | 85-95% |
High-volume components (retrieval, reranking): Distill aggressively. Speedup pays for latency. Single-call components (generation): Moderate distillation. Quality is critical. Cascaded components: Distill each stage independently, then validate end-to-end.
Embedding Model Distillation
Efficient Retrievers via Dimension Reduction & Matryoshka Loss
Problem
- • text-embedding-3-large: 3072d, slow
- • Memory: 12GB+ for inference
- • Latency: 100-500ms per query
- • Cost: $0.13 per 1M tokens
Solution
- • Distill to 384d embedding
- • Memory: 500MB
- • Latency: 5-10ms per query
- • Cost: 100x cheaper
Techniques
Dimension Reduction
Project 3072d → 384d via linear layer. Student learns to map teacher embeddings to lower dimension while preserving semantic relationships.
Matryoshka Embeddings
Train with a multi-scale loss: enforce meaningful embeddings at nested prefix dimensions (e.g. 64, 128, 256, 384). Enables flexible dimension selection at inference time.
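A minimal sketch of the multi-scale loss, assuming teacher embeddings have already been projected to the student's 384d space (the `dims` prefixes and the cosine objective here are illustrative, not a specific library API):

```python
import torch
import torch.nn.functional as F

def matryoshka_distillation_loss(student_emb, teacher_emb, dims=(64, 128, 256, 384)):
    """Average a cosine-distance loss over truncated prefixes of the embedding.

    student_emb: [batch, 384] student outputs
    teacher_emb: [batch, 384] teacher targets (already projected to 384d)
    """
    losses = []
    for d in dims:
        s = F.normalize(student_emb[:, :d], p=2, dim=1)
        t = F.normalize(teacher_emb[:, :d], p=2, dim=1)
        # Cosine distance between truncated student and teacher embeddings
        losses.append((1 - (s * t).sum(dim=1)).mean())
    return torch.stack(losses).mean()
```

Because every prefix is trained to be useful on its own, a 384d student can later be truncated to 64d or 128d with a known, graceful quality trade-off.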
from sentence_transformers import SentenceTransformer, models
import torch
import torch.nn as nn

# Teacher: text-embedding-3-large is an API-only model (3072d), so in practice
# you precompute teacher embeddings via the API rather than loading it locally
teacher_dim = 3072
student_dim = 384

# Create student with dimension reduction
# (checkpoint name illustrative; e.g. nreimers/MiniLM-L6-H384-uncased on the Hub)
word_embedding_model = models.Transformer("microsoft/MiniLM-L6-H384")
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
dense = models.Dense(
    in_features=word_embedding_model.get_word_embedding_dimension(),
    out_features=student_dim,
    activation_function=nn.Tanh()
)
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense])

# Contrastive distillation loss (NT-Xent against teacher embeddings)
# Note: teacher_emb must be projected to student_dim first (e.g. via a learned
# linear head or PCA) so the similarity matmul is well-defined
def embedding_distillation_loss(teacher_emb, student_emb, temperature=0.07):
    # Normalize embeddings
    teacher_emb = nn.functional.normalize(teacher_emb, p=2, dim=1)
    student_emb = nn.functional.normalize(student_emb, p=2, dim=1)
    # Cosine similarity matrix
    sim_matrix = torch.matmul(student_emb, teacher_emb.t()) / temperature
    # Contrastive loss: each student embedding should match its own teacher row
    labels = torch.arange(student_emb.size(0), device=student_emb.device)
    loss = nn.CrossEntropyLoss()(sim_matrix, labels)
    return loss
| Model | Dimension | Latency (ms) | Recall@1 | Recall@10 | Cost per 1M tokens |
|---|---|---|---|---|---|
| text-embedding-3-large | 3072 | 250-500 | 92.5% | 98.1% | $0.13 |
| text-embedding-3-small | 1536 | 100-200 | 90.2% | 97.4% | $0.02 |
| Distilled MiniLM (384d) | 384 | 5-10 | 89.8% | 96.9% | $0.001 |
Data: Use same domain as production retrieval data. Temperature: 0.05-0.07 for embeddings. Batch size: 256-512 to maximize contrastive signal. Training steps: 50K-100K for high-quality distillation.
Reranker Distillation
Cross-Encoder to Bi-Encoder & Score Distillation
Cross-Encoder Reranker
- • Model: DeBERTa-large, 434M params
- • Input: Concatenate [Q, SEP, D]
- • Output: Relevance score (0-1)
- • Speed: 200-500ms per doc
- • NDCG@10: 0.625
Bi-Encoder Reranker
- • Model: MiniLM, 22M params
- • Input: Embed Q & D separately
- • Output: Dot product similarity
- • Speed: 5-10ms per doc
- • NDCG@10: 0.615 (98% quality)
Distillation Strategy: Score Margin Loss
import torch
import torch.nn.functional as F

# Cross-encoder teacher scores relevant & irrelevant docs
def score_margin_loss(student_scores, teacher_scores, margin=0.5):
    # scores: [batch_size, num_docs]; column 0 = positive, columns 1: = negatives
    pos_score_student = student_scores[:, 0]  # Positive doc
    neg_score_student = student_scores[:, 1]  # First negative doc
    pos_score_teacher = teacher_scores[:, 0]
    neg_score_teacher = teacher_scores[:, 1]
    # Hinge loss: penalize when the student's pos-neg margin falls more than
    # `margin` below the teacher's
    student_margin = pos_score_student - neg_score_student
    teacher_margin = pos_score_teacher - neg_score_teacher
    loss = F.relu(teacher_margin - student_margin + margin)
    return loss.mean()

# Alternative: KL divergence on ranking softmax (listwise distillation)
def listwise_kl_loss(student_scores, teacher_scores, temperature=4):
    teacher_probs = F.softmax(teacher_scores / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_scores / temperature, dim=-1)
    loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return loss * (temperature ** 2)
ColBERT: Late Interaction Distillation
Hybrid approach: compute embeddings separately (like bi-encoder), then match at interaction layer (like cross-encoder). Much faster than full cross-encoder, more accurate than simple bi-encoder.
score = Σ_i max_j sim(Q_i, D_j)  # MaxSim: best document token per query token
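The MaxSim score above can be sketched directly in PyTorch (token embeddings assumed L2-normalized so the dot products are cosine similarities):

```python
import torch

def colbert_score(query_tokens, doc_tokens):
    """Late-interaction MaxSim: for each query token embedding, take its best
    match among document token embeddings, then sum over query tokens.

    query_tokens: [num_q, dim], doc_tokens: [num_d, dim]
    """
    sim = query_tokens @ doc_tokens.t()    # [num_q, num_d] similarity matrix
    return sim.max(dim=1).values.sum()     # max over doc tokens, sum over query
```

Because document token embeddings can be precomputed and indexed offline, only the cheap matmul and max happen at query time.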
| Model | Architecture | Latency/doc | NDCG@10 | MRR@10 | Cost (1M queries) |
|---|---|---|---|---|---|
| DeBERTa-large (Teacher) | Cross-encoder | 300ms | 0.625 | 0.758 | $150 |
| ColBERT (Distilled) | Late interaction | 15ms | 0.618 | 0.752 | $8 |
| MiniLM Bi-encoder | Bi-encoder | 2ms | 0.615 | 0.745 | $2 |
Reranker distillation is sensitive to margin parameter. Start with margin=0.3, increase gradually. Use in-batch negatives to stabilize training. Monitor margin distribution across iterations.
Generator Distillation for Production
Distilling GPT-4 & Claude into Smaller Models
Teacher: GPT-4/Claude
- • Model: 1T+ params (estimated)
- • Quality: 95%+ factually correct
- • Cost: $30-60 per 1M tokens
- • Latency: 2-5 sec
- • Hallucination: ~5%
Student: Llama-3-8B
- • Model: 8B params
- • Quality: 88-92% (after distillation)
- • Cost: $0.50 per 1M tokens
- • Latency: 200-400ms
- • Hallucination: ~12% (with context)
Output Distillation (Synthetic Data)
# Step 1: Generate synthetic training data with teacher
def generate_distillation_data(queries, documents, teacher_model, num_examples=5000):
    data = []
    for query, doc in zip(queries, documents):
        # Teacher generates response with context
        prompt = f"""Answer the question based on the context.
Context: {doc}
Question: {query}
Answer:"""
        response = teacher_model.generate(prompt)  # e.g., GPT-4 via API
        data.append({
            "query": query,
            "context": doc,
            "response": response,
        })
    return data

# Step 2: Fine-tune student on synthetic data
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
training_args = TrainingArguments(
    output_dir="./distilled-llama",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    gradient_accumulation_steps=4,
)
trainer = Trainer(
    model=student_model,
    args=training_args,
    train_dataset=dataset,  # tokenized synthetic data from step 1
)
trainer.train()
Chain-of-Thought Distillation
Distill not just final answer but reasoning steps. Teacher outputs step-by-step reasoning, student learns to generate intermediate thoughts. Improves factuality and makes errors more traceable.
[Thought 1] → [Thought 2] → ... → [Answer]
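One way to assemble such training targets, sketched here with an illustrative `[Thought i]`/`[Answer]` tag convention (the helper name and format are ours, not a library API):

```python
def build_cot_target(reasoning_steps, answer):
    """Format teacher reasoning + final answer into one training target so the
    student learns to emit intermediate thoughts before answering."""
    thoughts = "\n".join(
        f"[Thought {i}] {step}" for i, step in enumerate(reasoning_steps, start=1)
    )
    return f"{thoughts}\n[Answer] {answer}"
```

At inference you can strip everything before `[Answer]` for the user while logging the thoughts for error analysis.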
| Model | Params | Latency | RAGAS Faithfulness | RAGAS Relevance | Cost per 1K tokens |
|---|---|---|---|---|---|
| GPT-4 (Teacher) | 1T+ | 3-5s | 0.94 | 0.92 | $0.06 |
| Llama-3-70B | 70B | 1-2s | 0.88 | 0.87 | $0.01 |
| Llama-3-8B (Distilled) | 8B | 200-400ms | 0.89 | 0.86 | $0.0005 |
Context: Always include retrieved documents in prompt. This grounds student. Diversity: Mix easy & hard examples. Temperature: 0.3 for deterministic outputs during distillation. Validation: Use RAGAS metrics to track faithfulness.
Query Transformer Distillation
Fast Query Expansion & Rewrite
Use Case: Query Expansion
Teacher (GPT-4) generates 3-5 reformulations of user query to improve retrieval coverage. Student (T5-small) learns to do same in 5ms.
"basketball" → ["basketball game", "NBA", "court sport", ...]
Use Case: Query Rewrite
Teacher rewrites conversational queries to standalone form. Improves multi-turn retrieval.
"What about alternatives?" → "What are alternatives to X?"
Training Recipe
# Generate query expansion training data
def generate_query_expansion_data(original_queries, teacher_model):
    training_pairs = []
    for query in original_queries:
        prompt = f"""Generate 3 alternative search queries for: {query}
Format: query1 ||| query2 ||| query3"""
        expansions = teacher_model.generate(prompt)
        training_pairs.append({
            "input": query,
            "target": expansions,
        })
    return training_pairs

# Fine-tune T5-small on the seq2seq task
from transformers import T5ForConditionalGeneration, AutoTokenizer

model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

def compute_loss(batch):
    inputs = tokenizer(batch["input"], max_length=128, padding="max_length",
                       truncation=True, return_tensors="pt")
    labels = tokenizer(batch["target"], max_length=256, padding="max_length",
                       truncation=True, return_tensors="pt")
    outputs = model(input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    labels=labels.input_ids)
    return outputs.loss
| Model | Params | Latency | Expansion Quality | Recall Lift |
|---|---|---|---|---|
| GPT-4 (Teacher) | 1T+ | 2-3s | 95% | +18% |
| T5-small (Distilled) | 60M | 5-10ms | 92% | +17% |
Data: Use production queries + relevance judgments. Target diversity: Generate 3-5 expansions per query. Evaluation: Measure recall lift, not exact match. Integration: Combine expansions via union retrieval.
End-to-End Training Recipes
Complete Pipelines with Data Prep & Hypertuning
Data Preparation
Step 1: Generate Teacher Labels
Use teacher model to label training data. For embeddings: pair queries with hard negatives. For rerankers: score documents. For generators: generate answers with context.
Step 2: Filter & Balance
Remove ambiguous examples (low teacher confidence). Balance difficulty. For reranker: ensure positive scores much higher than negative.
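A minimal confidence filter along these lines, assuming each example carries teacher scores for one positive and several negatives (the schema and threshold are illustrative):

```python
def filter_by_teacher_confidence(examples, min_margin=0.2):
    """Drop ambiguous examples: keep only those where the teacher's positive
    score clearly exceeds its best negative score.

    Each example: {"pos_score": float, "neg_scores": [float, ...]}
    """
    kept = []
    for ex in examples:
        margin = ex["pos_score"] - max(ex["neg_scores"])
        if margin >= min_margin:
            kept.append(ex)
    return kept
```

Examples the teacher itself is unsure about mostly inject noise into the student, so filtering them tends to improve distillation quality per training step.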
# Complete training pipeline with HuggingFace Trainer
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification
import torch

# Training arguments
training_args = TrainingArguments(
    output_dir="./distilled-model",
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    gradient_accumulation_steps=4,
)

# Load student model (checkpoint name illustrative)
model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/MiniLM-L6-H384-uncased",
    num_labels=1
)

# Custom distillation trainer
class DistillationTrainer(Trainer):
    def __init__(self, teacher_model=None, temperature=4, alpha=0.7, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher.eval()  # freeze teacher: no dropout, no gradients
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # (**kwargs absorbs extra arguments newer Trainer versions pass)
        # Student forward
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        # Teacher forward
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits
        # Distillation loss (KL term only meaningful with multi-class logits;
        # for single-score rerankers the MSE term below carries the signal)
        soft_targets = torch.nn.functional.softmax(teacher_logits / self.temperature, dim=-1)
        student_log_soft = torch.nn.functional.log_softmax(student_logits / self.temperature, dim=-1)
        kl_loss = torch.nn.functional.kl_div(student_log_soft, soft_targets, reduction='batchmean')
        # Task-specific loss (MSE against teacher scores)
        task_loss = torch.nn.functional.mse_loss(student_logits, teacher_logits)
        # Combined loss
        loss = self.alpha * (self.temperature ** 2) * kl_loss + (1 - self.alpha) * task_loss
        return (loss, student_outputs) if return_outputs else loss

# Initialize trainer
trainer = DistillationTrainer(
    model=model,
    teacher_model=teacher,
    temperature=4,
    alpha=0.7,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train
trainer.train()
Hardware & Optimization
| Setup | GPU Memory | Batch Size | Training Time (100K examples) | Cost |
|---|---|---|---|---|
| Single A100 (40GB) | 40GB | 32 | 2-4 hours | $8-16 |
| 4× A100 (40GB) DDP | 160GB | 256 | 30-45 min | $20-30 |
| Single A40 (48GB) with DeepSpeed | 48GB | 64 | 1.5-2 hours | $5-8 |
Learning rate: 1e-5 to 5e-5. Lower for larger models. Temperature: 3-20. Higher = more soft targets. Alpha: 0.5-0.9. Higher = more distillation vs task. Warmup: 5-10% of steps. Weight decay: 0.01-0.1.
Evaluation & Benchmarking
Measuring Distillation Quality by Component
Embedding Model Evaluation
MTEB Benchmark Metrics
- • Recall@K: Fraction of relevant items in top-K
- • NDCG@K: Normalized discounted cumulative gain
- • MAP: Mean average precision across queries
- • MRR: Mean reciprocal rank
# Evaluate embedding model with MTEB (classic MTEB interface shown)
from mteb import MTEB

tasks = ["STS12", "STS13", "STS14", "STS15", "STS16",
         "STSBenchmark", "SummEval"]
evaluation = MTEB(tasks=tasks, task_langs=["en"])
results = evaluation.run(model, output_folder="results")

# Retrieval benchmarks (BEIR tasks as named in MTEB)
retrieval_tasks = ["TRECCOVID", "DBPedia", "SciFact"]
results = MTEB(tasks=retrieval_tasks).run(model)

# Per-task scores (recall@k, ndcg@10, etc.) are written as JSON under
# output_folder; compare student vs teacher from those files
Reranker Evaluation
| Metric | Definition | Target Threshold |
|---|---|---|
| NDCG@10 | Discounted gain at position 10 | >95% of teacher |
| MRR@10 | Reciprocal rank of first relevant | >95% of teacher |
| MAP@1000 | Mean average precision across ranking | >93% of teacher |
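For reference, the two rank-based metrics can be computed by hand; this sketch assumes binary relevance flags for MRR and graded gains for NDCG:

```python
import math

def mrr_at_k(ranked_relevance, k=10):
    """ranked_relevance: 0/1 flags in ranked order (1 = relevant)."""
    for i, rel in enumerate(ranked_relevance[:k], start=1):
        if rel:
            return 1.0 / i
    return 0.0

def ndcg_at_k(ranked_gains, k=10):
    """ranked_gains: graded relevance in ranked order; ideal = sorted descending."""
    dcg = sum(g / math.log2(i + 1) for i, g in enumerate(ranked_gains[:k], start=1))
    ideal = sorted(ranked_gains, reverse=True)
    idcg = sum(g / math.log2(i + 1) for i, g in enumerate(ideal[:k], start=1))
    return dcg / idcg if idcg > 0 else 0.0
```

Comparing these per-query between student and teacher (rather than only in aggregate) quickly surfaces the query types where distillation lost the most.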
Generator Evaluation (RAGAS)
# Evaluate generation quality with RAGAS
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import faithfulness, answer_relevancy, context_recall

# Prepare evaluation dataset (RAGAS expects a HuggingFace Dataset)
rag_results = Dataset.from_dict({
    "question": [...],
    "answer": [...],       # Generated by student model
    "contexts": [...],     # Retrieved documents (list of strings per question)
    "ground_truth": [...], # Reference answers
})

# Compute metrics
score = evaluate(
    rag_results,
    metrics=[faithfulness, answer_relevancy, context_recall],
)
print(f"Faithfulness: {score['faithfulness']:.3f}")
print(f"Answer Relevancy: {score['answer_relevancy']:.3f}")
print(f"Context Recall: {score['context_recall']:.3f}")
A/B Testing in Production
Canary deployment: Route 5-10% of traffic to distilled model. Monitor latency, cost, quality metrics. If stable, increase to 50%, then 100%.
Metrics to track: Latency P50/P95/P99, hallucination rate (human review sample), user satisfaction (thumbs up/down), business KPIs (conversions, retention).
Rollback trigger: >5% regression in any critical metric. Keep teacher model running in parallel for 48 hours.
Set alerts for >2% drop in key metrics. Use sequential probability ratio tests (SPRT) for early stopping. Monitor distribution shift—if data changes significantly, re-distill.
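The regression thresholds above (>2% alert, >5% rollback) can be checked with a simple guard; a sketch, assuming metrics where higher is better:

```python
def should_rollback(baseline, candidate, critical_drop=0.05, alert_drop=0.02):
    """Compare candidate (student) metrics against baseline (teacher).

    Returns (rollback, alerts): rollback if any metric regresses more than
    critical_drop; alerts lists metrics down more than alert_drop.
    """
    alerts, rollback = [], False
    for name, base in baseline.items():
        drop = (base - candidate[name]) / base  # relative regression
        if drop > critical_drop:
            rollback = True
        if drop > alert_drop:
            alerts.append(name)
    return rollback, alerts
```

Running this on each monitoring window keeps the rollback decision mechanical rather than ad hoc during an incident.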
LoRA & Parameter-Efficient Distillation
Combine PEFT Methods with Knowledge Transfer
LoRA (Low-Rank Adaptation)
Add low-rank matrices to attention layers. Train only 0.1-1% of parameters. Compatible with distillation.
W = W₀ + (α/r) · B·A  (W₀ frozen; A, B low-rank)
QLoRA (Quantized LoRA)
Quantize base model to 4-bit. Add LoRA on top. Fit 70B model in 48GB GPU.
70B model → ~35GB VRAM for 4-bit weights (fits on one 48GB GPU with LoRA)
LoRA + Distillation Training
from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
import torch.nn.functional as F

# QLoRA config: 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load student model with quantization
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantization_config=bnb_config,
    device_map="auto"
)

# LoRA config: target attention weights
lora_config = LoraConfig(
    r=8,  # LoRA rank
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # ~0.1% of 8B

# Distillation loss with LoRA
def lora_distillation_loss(student_logits, teacher_logits, temperature=4):
    teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
    student_soft = F.log_softmax(student_logits / temperature, dim=-1)
    kl_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
    return kl_loss * (temperature ** 2)

# Train with minimal memory (only LoRA params receive gradients)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
    optimizer.zero_grad()
    student_out = model(**batch)
    with torch.no_grad():
        teacher_out = teacher_model(**batch)
    loss = lora_distillation_loss(student_out.logits, teacher_out.logits)
    loss.backward()
    optimizer.step()
| Method | GPU Memory | Trainable Params | Distillation Quality | Training Speed |
|---|---|---|---|---|
| Full Fine-tune | 80GB | 100% | Best | 1x |
| LoRA (r=8) | 48GB | 0.2% | 98% | 0.95x |
| QLoRA (r=8, 4-bit) | 16GB | 0.2% | 97% | 0.8x |
Rank (r): 4-16 typical. Higher rank = more capacity but slower. Target modules: Attention weights (q_proj, v_proj) usually sufficient. Initialization: A is initialized randomly (Gaussian/Kaiming), B to zero, so ΔW starts at 0. Merging: After training, merge LoRA weights into base model for inference.
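The merge step can be sketched as folding the low-rank update back into the frozen weight, so inference needs no extra matmul (with the peft library, `merge_and_unload()` performs this on a whole model):

```python
import torch

def merge_lora_weights(W, A, B, alpha=16, r=8):
    """Fold a trained LoRA update into the frozen base weight:
    W' = W + (alpha/r) * (B @ A)

    W: [out, in] frozen base weight
    A: [r, in]   LoRA down-projection
    B: [out, r]  LoRA up-projection
    """
    return W + (alpha / r) * (B @ A)
```

After merging, the adapter matrices can be discarded and the model served exactly like a vanilla checkpoint.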
Quantization + Distillation Synergy
Maximum Compression via Combined Techniques
Quantization-Aware Distillation
Simulate quantization during training. Student learns to be robust to quantization noise. Better quality than post-hoc quantization.
Execution Order
1. Distill → 2. Quantize
vs.
1. Distill aware of quantization (better quality)
Quantization Techniques
INT8
8-bit integer weights. 4x compression. Minimal quality loss. Easy integration.
INT4
4-bit quantization. 8x compression. Requires distillation for quality.
GPTQ/AWQ
Weight-only quantization. Fast inference. Good for LLMs.
# Quantization-aware distillation with fake quantization
import torch
import torch.nn.functional as F
import torch.quantization as quant

# Prepare model for QAT (Quantization Aware Training)
model.qconfig = quant.get_default_qat_qconfig('fbgemm')
quant.prepare_qat(model, inplace=True)

# Training step with fake quantization
def qat_distillation_step(student, teacher, batch, temperature=4):
    # Student forward (includes fake-quant observers)
    student_out = student(**batch)
    # Teacher forward (no quantization)
    with torch.no_grad():
        teacher_out = teacher(**batch)
    # Distillation loss
    teacher_soft = F.softmax(teacher_out.logits / temperature, dim=-1)
    student_log = F.log_softmax(student_out.logits / temperature, dim=-1)
    loss = F.kl_div(student_log, teacher_soft, reduction='batchmean') * (temperature ** 2)
    return loss

# Post-training: convert to INT8
quant.convert(model, inplace=True)

# Using GPTQ for 4-bit LLM quantization (auto-gptq library)
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

gptq_config = BaseQuantizeConfig(
    bits=4,  # 4-bit quantization
    group_size=128,
    desc_act=False,
)
model = AutoGPTQForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantize_config=gptq_config
)
| Technique | Bits | Compression | Quality Loss | Inference Speed | Combined with Distill |
|---|---|---|---|---|---|
| Original FP32 | 32 | 1x | — | 1x | — |
| INT8 Only | 8 | 4x | 1-2% | 2x | ↓ Loss to 0.5% |
| INT4 Only | 4 | 8x | 5-10% | 4x | ↓ Loss to 2-3% |
| Distilled + GPTQ 4-bit | 4 | 80x | 3-5% | 20x | Combined |
Step 1: Distill teacher to student (90-95% quality). Step 2: Quantize student with QAT (distill with fake quant active). Step 3: Convert to INT4/GPTQ. Result: 80x compression with 93-97% quality.
Production Deployment
Serving Distilled Models at Scale
Inference Frameworks
vLLM
High-throughput LLM inference. Paged attention, continuous batching. 10-50x faster than vanilla HuggingFace.
TensorRT
NVIDIA's inference optimizer. Optimized kernels, automatic optimization. Best for NVIDIA GPUs.
ONNX Runtime
Cross-platform, cross-hardware. CPU/GPU/mobile. Good for edge deployment.
Triton Inference Server
Multi-model serving. Dynamic batching, ensemble pipelines. For production LLM endpoints.
A/B Testing & Canary Deployment
# A/B testing setup with vLLM
from vllm import LLM, SamplingParams
import random
import time

# Load teacher and student models (model names illustrative)
teacher_model = LLM(model="gpt2", gpu_memory_utilization=0.8)
student_model = LLM(model="distilled-gpt2", gpu_memory_utilization=0.8)
sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)

# Canary deployment: 10% student, 90% teacher
def generate_with_ab_test(prompt, canary_rate=0.1):
    if random.random() < canary_rate:
        model, variant = student_model, "student"  # distilled
    else:
        model, variant = teacher_model, "teacher"  # original
    start = time.time()
    outputs = model.generate([prompt], sampling_params=sampling_params)
    latency = time.time() - start
    # Log for analysis (log_metric = your metrics client, e.g. StatsD/Prometheus)
    log_metric({
        "prompt": prompt,
        "variant": variant,
        "output": outputs[0].outputs[0].text,
        "latency": latency,
    })
    return outputs[0].outputs[0].text

# Increase canary rate over time
canary_schedule = {
    "hour_0": 0.05,   # 5%
    "hour_2": 0.10,   # 10%
    "hour_6": 0.25,   # 25%
    "hour_12": 0.50,  # 50%
    "hour_24": 1.00,  # 100% full cutover
}
Monitoring & Alerting
| Metric | Normal | Warning | Critical |
|---|---|---|---|
| P50 Latency | <50ms | 50-100ms | >100ms |
| P99 Latency | <200ms | 200-500ms | >500ms |
| Quality (NDCG) | >0.60 | 0.57-0.60 | <0.57 |
| Error Rate | <0.1% | 0.1-0.5% | >0.5% |
Automatic: If error rate >1% or quality drops >5%, auto-rollback to teacher. Manual: Keep teacher running in parallel for 48 hours. Hotfix: Disable distilled model, re-train if data distribution changed significantly.
Cost Analysis & ROI
When Distillation Pays Off
Per-Component Cost Breakdown
| Component | Teacher Model | Distilled Model | Cost per 1M Requests | Savings (1M req/day) |
|---|---|---|---|---|
| Embedding (1M docs) | text-embedding-3-large: $0.13 | MiniLM-384d: $0.001 | $130 → $1 | $129/day = $47K/yr |
| Reranker (10K docs/req) | DeBERTa-large: $150 | ColBERT: $8 | $150 → $8 | $142/day = $52K/yr |
| Generator (512 tokens) | GPT-4: $30 | Llama-8B: $0.50 | $30 → $0.50 | $29.50/day = $10.8K/yr |
| TOTAL (Full Pipeline) | $180 per 1M | $9.50 per 1M | $180 → $9.50 | $170.50/day = $62.2K/yr |
Training Investment vs. Savings
One-Time Training Cost
- • 4× A100 GPU: 24 hours
- • GPU cost: $1,500
- • Labeling/data prep: $2,000
- • Validation/testing: $1,000
- • Total: $4,500
Break-Even Analysis
- • Savings per day: $170.50
- • Break-even: $4,500 / $170.50 ≈ 26 days
- • Monthly savings: ~$5,115 (≈1.1× the training cost)
- • Year 1 ROI: ~13.8x
- • 1M requests/day: YES
Distillation Makes Sense When:
- ✓ High volume: >100K requests/day per component
- ✓ Cost-sensitive: Inference cost is significant (>10% of budget)
- ✓ Latency-critical: Need <100ms P99 latency
- ✓ Stable workload: Data distribution doesn't change rapidly
- ✓ Quality tolerance: Can accept 2-5% quality drop
Volume-Based ROI Calculator
Monthly Cost (No Distill) = daily_cost_teacher × 30
Monthly Cost (Distilled) = daily_cost_distilled × 30
Monthly Savings = Cost(No Distill) − Cost(Distilled)
Payback Period (months) = Training Cost / Monthly Savings
Example: 1M requests/day for production LLM systems
• Current cost: $180/day → $5,400/month
• Distilled cost: $9.50/day → $285/month
• Savings: $5,115/month
• Training cost: $4,500 (one-time)
• Payback: 0.88 months (27 days)
• Year 1 ROI: ~13.6x ($5,115 × 12 / $4,500)
Month 1: Distill components (est. $4.5K one-time). Months 2-12: Save ~$5.1K/month. Year 2: Pure savings ($61K+). Total 2-year net savings: ≈$118K after the investment.
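The worked example above can be verified with a few lines (inputs in USD/day; the function name is ours):

```python
def distillation_roi(daily_cost_teacher, daily_cost_student, training_cost):
    """Payback period (days) and first-year ROI from daily inference costs."""
    daily_savings = daily_cost_teacher - daily_cost_student
    payback_days = training_cost / daily_savings
    year1_roi = (daily_savings * 365) / training_cost
    return payback_days, year1_roi
```

Plugging in the pipeline totals ($180/day teacher, $9.50/day distilled, $4,500 training) reproduces the ~27-day payback and ~14x first-year ROI.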
Real-World Case Studies
Production Success Stories
Case 1: E-Commerce Document Retrieval
Challenge
E-commerce platform with 10M product descriptions. text-embedding-3-large too slow (300ms/query). Cost: $40K/month.
Solution
Distill to MiniLM-384d using contrastive loss on 100K product pairs.
Results:
- • Latency: 300ms → 8ms (37x faster)
- • Recall@10: 94.2% → 92.8% (98.5% quality)
- • Cost: $40K/mo → $1.2K/mo (97% savings)
- • Training: 2 A100-days ($400) + labeling ($1K)
- • Payback: 9 days | Year 1 ROI: 405x
Case 2: SaaS Question-Answering System
Challenge
Support chatbot using GPT-4 with retrieval-augmented generation. High latency (3s), high cost ($100K/month). Users frustrated with wait times.
Solution
Distill GPT-4 to Llama-3-8B using 50K QA pairs + output distillation + LoRA.
Results:
- • Latency: 3000ms → 350ms (8.6x faster)
- • RAGAS Faithfulness: 0.94 → 0.89 (94.7% quality)
- • Cost: $100K/mo → $2.5K/mo (97.5% savings)
- • Training: 4× A100 × 2 days ($3K) + labeling ($5K)
- • Payback: 4 days | Year 1 ROI: 315x
Case 3: Search Ranking with Reranker
Challenge
Cross-encoder reranker (DeBERTa) bottleneck. Must rerank top-100 per query. 200ms per request. P99 latency: 500ms.
Solution
Distill to ColBERT (late interaction). Score distillation + margin loss.
Results:
- • Latency: 200ms → 12ms (16.7x faster)
- • NDCG@10: 0.625 → 0.618 (98.8% quality)
- • P99 latency: 500ms → 35ms (14x improvement)
- • Hardware: 4 GPUs → 1 GPU (75% cost)
- • Training: 1 A100-day ($200) + 10K labeled pairs ($2K)
- • Payback: 5 days | Year 1 ROI: 200x
1. High volume: All cases had 1M+ requests/day. 2. Clear bottleneck: One slow component identified. 3. High-quality teacher: Started with strong model (GPT-4, DeBERTa). 4. Fast payback: Most broke even <2 weeks. 5. Conservative rollout: Canary to 100% over 48 hours.
Production Checklist
De-Risk Your Distillation Rollout
Data Quality
- ☐ Collected 100K+ training examples
- ☐ Verified label distribution matches production
- ☐ Removed outliers/ambiguous examples
- ☐ Split: 80/10/10 train/val/test
- ☐ Validated teacher consistency
Training
- ☐ Ran hyperparameter sweep
- ☐ Logged learning curves
- ☐ Confirmed convergence on val set
- ☐ Saved checkpoints every 500 steps
- ☐ Tested inference latency
Evaluation
- ☐ Computed metrics on test set
- ☐ Verified >93% quality vs teacher
- ☐ Human evaluation (100 examples)
- ☐ Error analysis completed
- ☐ Documented regressions
Deployment
- ☐ Converted to ONNX/TensorRT
- ☐ Optimized model for target hardware
- ☐ Benchmarked P50/P95/P99
- ☐ Set up model serving (vLLM/Triton)
- ☐ Containerized application
A/B Testing
- ☐ Setup canary at 5%
- ☐ Monitoring dashboards ready
- ☐ Alerts configured (>2% regression)
- ☐ Rollback procedure documented
- ☐ Run for 24+ hours at each %
Monitoring
- ☐ Latency: P50/P95/P99 tracked
- ☐ Quality: Daily metric dashboard
- ☐ Error rate: <0.5% acceptable
- ☐ User feedback collected
- ☐ Weekly review of metrics
Cost Tracking
- ☐ Calculated baseline cost
- ☐ Projected monthly savings
- ☐ Break-even timeline confirmed
- ☐ ROI tracker setup
- ☐ Monthly cost report
Documentation
- ☐ Training recipe documented
- ☐ Hyperparameters recorded
- ☐ Reproducibility verified
- ☐ Inference pipeline documented
- ☐ Runbooks for support team
Go/No-Go criteria: Quality ≥93% of teacher + P95 latency <150ms + Error rate <0.5% + Cost savings >50% + All checklist items completed. If all met, proceed with 5% canary.
Post-Launch (Day 1-30)
- ☐ Day 1: 5% traffic to distilled model. Monitor every 15 min.
- ☐ Day 2: If stable, increase to 10%. Check latency P99.
- ☐ Day 3: 25% traffic. Validate quality with sample review.
- ☐ Day 4-7: 50% traffic. Full monitoring active.
- ☐ Day 8-14: 75-90% traffic. Collect user feedback.
- ☐ Day 15-30: 100% traffic. Daily metrics review.
- ☐ Day 30: Finalize ROI report. Decommission teacher model if stable.