LLM Model Distillation Techniques
Knowledge Transfer for Cost-Efficient RAG Systems
A comprehensive guide to distilling large language models for retrieval-augmented generation pipelines. Learn theory, training recipes, evaluation strategies, and production deployment patterns.
What is Model Distillation?
Teacher-Student Paradigm for Knowledge Transfer
Why Distillation Matters for RAG
- ✓ 10-50x latency reduction in production
- ✓ 70-90% cost savings on inference
- ✓ Deploy to edge, mobile, low-latency APIs
- ✓ Reduce hallucination via better grounding
- ✓ Enable real-time retrieval augmentation
Key Concepts
- • Teacher: Large, accurate model
- • Student: Smaller, faster model
- • Knowledge: Outputs, features, attention
- • Loss: KL divergence + task loss
- • Temperature: Controls softness of targets
Distillation transfers the teacher's knowledge (logits, intermediate features, attention patterns) into a smaller student model, typically preserving 90-98% of quality while delivering a 10-50x speedup. That trade-off is a natural fit for RAG, where pipelines need fast, cost-efficient components.
Distillation Fundamentals
Core Techniques & Loss Functions
Logit Distillation
Transfer teacher's output probability distribution via KL divergence. Student learns soft targets from teacher at temperature T.
L_KL = T² × KL(teacher_p, student_p)
Feature Distillation
Match intermediate layer representations. Student learns hidden states from teacher's layers via MSE loss.
L_feat = MSE(student_h, teacher_h)
Attention Transfer
Align attention weights between teacher and student. Guides student to focus on same tokens.
L_attn = MSE(student_A, teacher_A)
Contrastive Distillation
Use contrastive learning to align teacher-student embeddings. Useful for embedding model distillation.
L_cont = -log(exp(sim) / Σexp(sims))
Temperature Scaling
Temperature (T) controls how soft the teacher's output distribution is. Higher T → softer targets → more information about the relative likelihood of non-target classes. At inference, use T=1 (standard softmax); during distillation training, use T=3-20.
soft_probs = softmax(logits / T)
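A quick illustration with a made-up logit vector: raising T spreads probability mass onto the non-target classes.
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 1.0, 0.5])       # made-up teacher logits
print(F.softmax(logits / 1.0, dim=-1))       # T=1: ~[0.93, 0.05, 0.03] (near one-hot)
print(F.softmax(logits / 4.0, dim=-1))       # T=4: ~[0.53, 0.25, 0.22] (soft targets)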
Combined Loss Function
# Combined distillation loss
import torch.nn.functional as F

def distillation_loss(logits_student, logits_teacher, labels, T=4, alpha=0.7):
    # Soft targets from the teacher at temperature T
    teacher_soft = F.softmax(logits_teacher / T, dim=-1)
    student_soft = F.log_softmax(logits_student / T, dim=-1)
    # KL divergence between soft distributions (logit distillation)
    kl_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
    # Cross-entropy on hard labels (task loss)
    ce_loss = F.cross_entropy(logits_student, labels)
    # Weighted combination; the T^2 factor rescales gradients of the soft loss
    loss = alpha * (T ** 2) * kl_loss + (1 - alpha) * ce_loss
    return loss
Temperature (T): 3-20 typical. Higher = softer knowledge transfer. Alpha (α): 0.5-0.9 typical. Higher = more weight on distillation vs task loss. Learning rate: 2-5x lower than standard fine-tuning.
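A quick shape-level sanity check of the loss above with dummy tensors (batch of 8, 10 classes), just to show the expected call signature:
import torch

logits_teacher = torch.randn(8, 10)
logits_student = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = distillation_loss(logits_student, logits_teacher, labels, T=4, alpha=0.7)
print(loss.item())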
RAG Distillation Landscape
Where & When to Apply Distillation in RAG Pipeline
| Component | What to Distill | Expected Speedup | Quality Loss | Cost Savings |
|---|---|---|---|---|
| Embedding | Embeddings, similarity scores | 10-20x | 2-8% | 70-85% |
| Reranker | Scores, attention | 15-30x | 1-5% | 80-90% |
| Generator | Logits, hidden states | 5-15x | 3-10% | 70-95% |
| Query Transform | Logits, hidden states | 20-50x | 2-6% | 85-95% |
High-volume components (retrieval, reranking): distill aggressively; at scale, the speedup quickly pays for the training effort. Single-call components (generation): distill moderately, since quality is critical. Cascaded components: distill each stage independently, then validate end-to-end.
Embedding Model Distillation
Efficient Retrievers via Dimension Reduction & Matryoshka Loss
Problem
- • text-embedding-3-large: 3072d, slow
- • Memory: 12GB+ for inference
- • Latency: 100-500ms per query
- • Cost: $0.13 per 1M tokens
Solution
- • Distill to 384d embedding
- • Memory: 500MB
- • Latency: 5-10ms per query
- • Cost: 100x cheaper
Techniques
Dimension Reduction
Project 3072d → 384d via linear layer. Student learns to map teacher embeddings to lower dimension while preserving semantic relationships.
Matryoshka Embeddings
Train with a multi-scale loss so that the first d dimensions of each embedding are meaningful on their own at several nested sizes (e.g. 64, 128, 256, 384). Enables flexible dimension selection at inference time, as sketched below.
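A minimal sketch of the multi-scale idea, assuming a hypothetical `base_loss` callable that already captures the targets (e.g. teacher embeddings projected to each size): the same loss is applied to nested prefixes of the student embedding and averaged.
def matryoshka_loss(student_emb, base_loss, dims=(64, 128, 256, 384)):
    # Apply the per-scale loss to nested prefixes of the embedding so a
    # truncated 64/128/256-dim vector remains useful on its own.
    losses = [base_loss(student_emb[:, :d]) for d in dims]
    return sum(losses) / len(losses)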
from sentence_transformers import SentenceTransformer, models
import torch
import torch.nn as nn

# Teacher: text-embedding-3-large (3072d) is an API model, so its embeddings
# are precomputed offline and loaded as tensors rather than run locally.
teacher_dim = 3072
student_dim = 384

# Create student with dimension reduction (any small 384d encoder works here)
word_embedding_model = models.Transformer("microsoft/MiniLM-L12-H384-uncased")
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
dense = models.Dense(
    in_features=word_embedding_model.get_word_embedding_dimension(),
    out_features=student_dim,
    activation_function=nn.Tanh(),
)
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense])

# Project teacher embeddings down to the student dimension so the contrastive
# loss below can compare the two directly.
project_teacher = nn.Linear(teacher_dim, student_dim, bias=False)

# Contrastive distillation loss (both inputs are [batch, student_dim])
def embedding_distillation_loss(teacher_emb, student_emb, temperature=0.07):
    # Normalize embeddings to unit length
    teacher_emb = nn.functional.normalize(teacher_emb, p=2, dim=1)
    student_emb = nn.functional.normalize(student_emb, p=2, dim=1)
    # Cosine similarity matrix between student and teacher embeddings
    sim_matrix = torch.matmul(student_emb, teacher_emb.t()) / temperature
    # Contrastive loss (NT-Xent): each student row should match its own teacher row
    labels = torch.arange(student_emb.size(0), device=student_emb.device)
    loss = nn.CrossEntropyLoss()(sim_matrix, labels)
    return loss
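A minimal training-step sketch tying these pieces together, assuming `batch_texts` and a matching tensor of precomputed 3072d teacher embeddings (hypothetical variable names):
import torch

optimizer = torch.optim.AdamW(
    list(student_model.parameters()) + list(project_teacher.parameters()), lr=2e-5
)

def training_step(batch_texts, teacher_emb_3072):
    # Student forward pass through the SentenceTransformer modules
    features = student_model.tokenize(batch_texts)
    student_emb = student_model(features)["sentence_embedding"]   # [batch, 384]
    teacher_emb = project_teacher(teacher_emb_3072)               # [batch, 384]
    loss = embedding_distillation_loss(teacher_emb, student_emb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()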
| Model | Dimension | Latency (ms) | Recall@1 | Recall@10 | Cost per 1M tokens |
|---|---|---|---|---|---|
| text-embedding-3-large | 3072 | 250-500 | 92.5% | 98.1% | $0.13 |
| text-embedding-3-small | 1536 | 100-200 | 90.2% | 97.4% | $0.02 |
| Distilled MiniLM (384d) | 384 | 5-10 | 89.8% | 96.9% | $0.001 |
Data: Use same domain as production retrieval data. Temperature: 0.05-0.07 for embeddings. Batch size: 256-512 to maximize contrastive signal. Training steps: 50K-100K for high-quality distillation.
Reranker Distillation
Cross-Encoder to Bi-Encoder & Score Distillation
Cross-Encoder Reranker
- • Model: DeBERTa-large, 434M params
- • Input: Concatenate [Q, SEP, D]
- • Output: Relevance score (0-1)
- • Speed: 200-500ms per doc
- • NDCG@10: 0.625
Bi-Encoder Reranker
- • Model: MiniLM, 22M params
- • Input: Embed Q & D separately
- • Output: Dot product similarity
- • Speed: 5-10ms per doc
- • NDCG@10: 0.615 (98% quality)
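To make the architectural difference concrete, here is a small scoring sketch using off-the-shelf sentence-transformers checkpoints (chosen for illustration; not the exact models benchmarked below):
from sentence_transformers import CrossEncoder, SentenceTransformer, util

query = "how does model distillation work?"
doc = "Distillation trains a small student model to mimic a large teacher model."

# Cross-encoder: query and document are scored jointly in a single forward pass
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
ce_score = cross_encoder.predict([(query, doc)])[0]

# Bi-encoder: query and document are embedded separately, then compared
bi_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
q_emb = bi_encoder.encode(query, convert_to_tensor=True)
d_emb = bi_encoder.encode(doc, convert_to_tensor=True)
be_score = util.cos_sim(q_emb, d_emb).item()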
Distillation Strategy: Score Margin Loss
import torch
import torch.nn.functional as F

# Hinge-style score margin loss: the student's positive-negative margin should be
# at least as large as the teacher's, minus a slack of `margin`.
def score_margin_loss(student_scores, teacher_scores, margin=0.5):
    # Scores are [batch_size, num_docs]; index 0 is the relevant doc.
    # For simplicity only the first negative (index 1) is used here.
    pos_score_student = student_scores[:, 0]
    neg_score_student = student_scores[:, 1]
    pos_score_teacher = teacher_scores[:, 0]
    neg_score_teacher = teacher_scores[:, 1]
    student_margin = pos_score_student - neg_score_student
    teacher_margin = pos_score_teacher - neg_score_teacher
    loss = F.relu(teacher_margin - student_margin + margin)
    return loss.mean()
# Alternative: KL divergence on the ranking softmax (listwise distillation)
def listwise_kl_loss(student_scores, teacher_scores, temperature=4):
    teacher_probs = F.softmax(teacher_scores / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_scores / temperature, dim=-1)
    loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return loss * (temperature ** 2)
ColBERT: Late Interaction Distillation
Hybrid approach: compute embeddings separately (like bi-encoder), then match at interaction layer (like cross-encoder). Much faster than full cross-encoder, more accurate than simple bi-encoder.
score = Σ_i max_j sim(Q_i, D_j)   # MaxSim: best-matching document token for each query token, summed
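A minimal MaxSim sketch in PyTorch, assuming L2-normalized per-token embeddings for one query and one document:
import torch

def colbert_maxsim(q_emb, d_emb):
    # q_emb: [num_query_tokens, dim], d_emb: [num_doc_tokens, dim], L2-normalized
    sim = q_emb @ d_emb.t()               # token-to-token cosine similarities
    return sim.max(dim=1).values.sum()    # best doc token per query token, summed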
| Model | Architecture | Latency/doc | NDCG@10 | MRR@10 | Cost (1M queries) |
|---|---|---|---|---|---|
| DeBERTa-large (Teacher) | Cross-encoder | 300ms | 0.625 | 0.758 | $150 |
| ColBERT (Distilled) | Late interaction | 15ms | 0.618 | 0.752 | $8 |
| MiniLM Bi-encoder | Bi-encoder | 2ms | 0.615 | 0.745 | $2 |
Reranker distillation is sensitive to the margin parameter: start with margin=0.3 and increase gradually. Use in-batch negatives to stabilize training, and monitor the margin distribution across training iterations.
Generator Distillation for RAG
Distilling GPT-4 & Claude into Smaller Models
Teacher: GPT-4/Claude
- • Model: 1T+ params (estimated)
- • Quality: 95%+ factually correct
- • Cost: $30-60 per 1M tokens
- • Latency: 2-5 sec
- • Hallucination: ~5%
Student: Llama-3-8B
- • Model: 8B params
- • Quality: 88-92% (after distillation)
- • Cost: $0.50 per 1M tokens
- • Latency: 200-400ms
- • Hallucination: ~12% (with context)
Output Distillation (Synthetic Data)
# Step 1: Generate synthetic training data with the teacher
def generate_distillation_data(queries, documents, teacher_model, num_examples=5000):
    data = []
    for query, doc in list(zip(queries, documents))[:num_examples]:
        # Teacher generates a grounded response from the retrieved context
        prompt = f"""Answer the question based on the context.
Context: {doc}
Question: {query}
Answer:"""
        response = teacher_model.generate(prompt)  # e.g. GPT-4 via API
        data.append({
            "query": query,
            "context": doc,
            "response": response,
        })
    return data
# Step 2: Fine-tune the student on the synthetic data
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

training_args = TrainingArguments(
    output_dir="./distilled-llama",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    gradient_accumulation_steps=4,
)

trainer = Trainer(
    model=student_model,
    args=training_args,
    train_dataset=dataset,  # tokenized dataset built from the Step 1 pairs (see below)
)
trainer.train()
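The `dataset` passed to the Trainer above is assumed to be already tokenized. A minimal sketch of building it from the Step 1 pairs with the Hugging Face datasets library (the `synthetic_pairs` name and prompt format are illustrative):
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token

def to_text(example):
    # Same prompt format the teacher saw, with the teacher's answer as the target
    return {"text": f"Context: {example['context']}\n\nQuestion: {example['query']}\n\nAnswer: {example['response']}"}

def tokenize(example):
    tokens = tokenizer(example["text"], truncation=True, padding="max_length", max_length=1024)
    # Causal-LM labels mirror the inputs; padding positions are ignored by the loss
    tokens["labels"] = [t if t != tokenizer.pad_token_id else -100 for t in tokens["input_ids"]]
    return tokens

dataset = Dataset.from_list(synthetic_pairs).map(to_text)
dataset = dataset.map(tokenize, remove_columns=["query", "context", "response", "text"])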
Chain-of-Thought Distillation
Distill not just final answer but reasoning steps. Teacher outputs step-by-step reasoning, student learns to generate intermediate thoughts. Improves factuality and makes errors more traceable.
[Thought 1] → [Thought 2] → ... → [Answer]
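A sketch of the teacher-side prompt and target under this scheme (prompt wording and the `teacher_model.generate` interface are illustrative, as in the earlier data-generation code):
# Ask the teacher for reasoning steps plus a tagged final answer, then train the
# student on the full reasoning + answer string instead of the answer alone.
COT_PROMPT = """Answer the question using only the context. Think step by step,
then give the final answer after the line "Answer:".

Context: {context}

Question: {question}
"""

def build_cot_target(teacher_model, context, question):
    response = teacher_model.generate(COT_PROMPT.format(context=context, question=question))
    return response  # contains both the reasoning chain and the final answer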
| Model | Params | Latency | RAGAS Faithfulness | RAGAS Relevance | Cost per 1K tokens |
|---|---|---|---|---|---|
| GPT-4 (Teacher) | 1T+ | 3-5s | 0.94 | 0.92 | $0.06 |
| Llama-3-70B | 70B | 1-2s | 0.88 | 0.87 | $0.01 |
| Llama-3-8B (Distilled) | 8B | 200-400ms | 0.89 | 0.86 | $0.0005 |
Context: Always include the retrieved documents in the prompt; this grounds the student. Diversity: Mix easy and hard examples. Temperature: Sample the teacher at ~0.3 so outputs are near-deterministic during distillation. Validation: Use RAGAS metrics to track faithfulness.
Query Transformer Distillation
Fast Query Expansion & Rewrite
Use Case: Query Expansion
Teacher (GPT-4) generates 3-5 reformulations of the user query to improve retrieval coverage. Student (T5-small) learns to do the same in ~5ms.
"basketball" → ["basketball game", "NBA", "court sport", ...]
Use Case: Query Rewrite
Teacher rewrites conversational queries to standalone form. Improves multi-turn retrieval.
"What about alternatives?" → "What are alternatives to X?"
Training Recipe
# Generate query expansion training data
def generate_query_expansion_data(original_queries, teacher_model):
    training_pairs = []
    for query in original_queries:
        prompt = f"""Generate 3 alternative search queries for: {query}
Format: query1 ||| query2 ||| query3"""
        expansions = teacher_model.generate(prompt)
        training_pairs.append({
            "input": query,
            "target": expansions,
        })
    return training_pairs
# Fine-tune T5-small on the seq2seq expansion task
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")

def compute_loss(batch):
    inputs = tokenizer(batch["input"], max_length=128, padding="max_length",
                       truncation=True, return_tensors="pt")
    targets = tokenizer(batch["target"], max_length=256, padding="max_length",
                        truncation=True, return_tensors="pt")
    # Mask padding positions so they are ignored by the seq2seq loss
    labels = targets.input_ids.masked_fill(targets.input_ids == tokenizer.pad_token_id, -100)
    outputs = model(input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    labels=labels)
    return outputs.loss
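At query time the distilled expander is served directly; a sketch assuming the same `|||` delimiter format used in training:
# Generate expansions with the distilled model and split on the delimiter
query = "basketball"
inputs = tokenizer(f"Generate 3 alternative search queries for: {query}", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64)
expansions = [e.strip() for e in tokenizer.decode(output_ids[0], skip_special_tokens=True).split("|||")]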
| Model | Params | Latency | Expansion Quality | Recall Lift |
|---|---|---|---|---|
| GPT-4 (Teacher) | 1T+ | 2-3s | 95% | +18% |
| T5-small (Distilled) | 60M | 5-10ms | 92% | +17% |
Data: Use production queries plus relevance judgments. Target diversity: Generate 3-5 expansions per query. Evaluation: Measure recall lift, not exact match against the teacher. Integration: Combine expansions via union retrieval, as sketched below.
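A minimal union-retrieval sketch, assuming a hypothetical `retriever.search(query, top_k)` interface that returns documents with an `id` field:
def union_retrieve(query, expansions, retriever, top_k=10):
    seen_ids, results = set(), []
    for q in [query] + expansions:
        for doc in retriever.search(q, top_k=top_k):
            if doc.id not in seen_ids:        # deduplicate across expansions
                seen_ids.add(doc.id)
                results.append(doc)
    return results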