LLM Model Distillation — Techniques, Training & Deployment
A comprehensive guide to knowledge distillation for large language models: theory, training recipes, evaluation strategies, and production deployment, from embedding compression through LoRA/QLoRA fine-tuning to quantization-aware training.
What is Model Distillation?
Teacher-Student Paradigm for Knowledge Transfer
Why Distillation Matters in Production
- ✓ 10-50x latency reduction in production
- ✓ 70-90% cost savings on inference
- ✓ Deploy to edge, mobile, low-latency APIs
- ✓ Reduce hallucination via better grounding
- ✓ Enable real-time retrieval augmentation
Key Concepts
- • Teacher: Large, accurate model
- • Student: Smaller, faster model
- • Knowledge: Outputs, features, attention
- • Loss: KL divergence + task loss
- • Temperature: Controls softness of targets
Distillation transfers the teacher's knowledge (logits, intermediate features, attention patterns) into a student model, preserving 90-98% of quality while achieving 10-50x speedup. Perfect for production systems where you need fast, cost-efficient components.
Distillation Fundamentals
Core Techniques & Loss Functions
Logit Distillation
Transfer teacher's output probability distribution via KL divergence. Student learns soft targets from teacher at temperature T.
L_KL = T² × KL(p_teacher ∥ p_student)
Feature Distillation
Match intermediate layer representations. Student learns hidden states from teacher's layers via MSE loss.
L_feat = MSE(student_h, teacher_h)
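Because the student's hidden size usually differs from the teacher's, feature distillation needs a learned projection. A minimal sketch (the 384/1024 dimensions and names are illustrative):
import torch.nn as nn
import torch.nn.functional as F

proj = nn.Linear(384, 1024)  # maps student hidden dim -> teacher hidden dim

def feature_loss(student_hidden, teacher_hidden):
    # MSE between projected student states and frozen teacher states
    return F.mse_loss(proj(student_hidden), teacher_hidden)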
Attention Transfer
Align attention weights between teacher and student. Guides student to focus on same tokens.
L_attn = MSE(student_A, teacher_A)
Contrastive Distillation
Use contrastive learning to align teacher-student embeddings. Useful for embedding model distillation.
L_cont = -log( exp(sim(s_i, t_i)/τ) / Σ_j exp(sim(s_i, t_j)/τ) )
Temperature Scaling
Temperature (T) controls softness of teacher's output distribution. Higher T → softer targets → more information about wrong classes. At inference, use T=1 (normal softmax). During distillation training, use T=3-20.
soft_probs = softmax(logits / T)
Combined Loss Function
# Combined distillation loss
import torch.nn.functional as F

def distillation_loss(logits_student, logits_teacher, labels, T=4, alpha=0.7):
    # Soft targets from teacher (temperature-scaled)
    teacher_soft = F.softmax(logits_teacher / T, dim=-1)
    student_soft = F.log_softmax(logits_student / T, dim=-1)
    # KL divergence (logit distillation); T² rescales gradients to match the CE term
    kl_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
    # Cross-entropy on hard targets
    ce_loss = F.cross_entropy(logits_student, labels)
    # Weighted combination
    loss = alpha * (T ** 2) * kl_loss + (1 - alpha) * ce_loss
    return loss
Temperature (T): 3-20 typical. Higher = softer knowledge transfer. Alpha (α): 0.5-0.9 typical. Higher = more weight on distillation vs task loss. Learning rate: 2-5x lower than standard fine-tuning.
Distillation Landscape
Where & When to Apply Distillation in ML Pipelines
| Component | What to Distill | Expected Speedup | Quality Loss | Cost Savings |
|---|---|---|---|---|
| Embedding | Logits, intermediate layers | 10-20x | 2-8% | 70-85% |
| Reranker | Scores, attention | 15-30x | 1-5% | 80-90% |
| Generator | Logits, hidden states | 5-15x | 3-10% | 70-95% |
| Query Transform | Logits, hidden states | 20-50x | 2-6% | 85-95% |
High-volume components (retrieval, reranking): distill aggressively; the speedup quickly repays the effort. Single-call components (generation): distill moderately; quality is critical. Cascaded components: distill each stage independently, then validate end-to-end.
Embedding Model Distillation
Efficient Retrievers via Dimension Reduction & Matryoshka Loss
Problem
- • text-embedding-3-large: 3072d, slow
- • Memory: 12GB+ for inference
- • Latency: 100-500ms per query
- • Cost: $0.13 per 1M tokens
Solution
- • Distill to 384d embedding
- • Memory: 500MB
- • Latency: 5-10ms per query
- • Cost: 100x cheaper
Techniques
Dimension Reduction
Project 3072d → 384d via linear layer. Student learns to map teacher embeddings to lower dimension while preserving semantic relationships.
Matryoshka Embeddings
Train with a multi-scale loss that enforces meaningful embeddings on nested prefixes of the vector (e.g., the first 64, 128, 256, and 384 dimensions). Enables flexible dimension selection at inference time; a minimal sketch follows.
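A sketch of the multi-scale objective, assuming teacher embeddings have already been projected to the student's 384 dimensions (dimension choices are illustrative):
import torch.nn.functional as F

def matryoshka_distill_loss(student_emb, teacher_emb, dims=(64, 128, 256, 384)):
    # Match teacher and student on each nested prefix so any truncation
    # of the student embedding remains usable on its own
    loss = 0.0
    for d in dims:
        s = F.normalize(student_emb[:, :d], dim=-1)
        t = F.normalize(teacher_emb[:, :d], dim=-1)
        loss = loss + (1 - (s * t).sum(dim=-1)).mean()  # cosine distance at scale d
    return loss / len(dims)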
from sentence_transformers import SentenceTransformer, models
import torch
import torch.nn as nn

# Teacher: text-embedding-3-large is an API model, so in practice you precompute
# its embeddings via the OpenAI API (a strong local teacher works the same way)
teacher_dim = 3072
student_dim = 384

# Create student with dimension reduction
# (model id is an example; any small encoder such as nreimers/MiniLM-L6-H384-uncased works)
word_embedding_model = models.Transformer("nreimers/MiniLM-L6-H384-uncased")
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
dense = models.Dense(
    in_features=word_embedding_model.get_word_embedding_dimension(),
    out_features=student_dim,
    activation_function=nn.Tanh()
)
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense])

# Contrastive distillation loss
# (teacher_emb must first be projected to student_dim, e.g. via a linear layer or PCA)
def embedding_distillation_loss(teacher_emb, student_emb, temperature=0.07):
    # Normalize embeddings
    teacher_emb = nn.functional.normalize(teacher_emb, p=2, dim=1)
    student_emb = nn.functional.normalize(student_emb, p=2, dim=1)
    # Cosine similarity matrix
    sim_matrix = torch.matmul(student_emb, teacher_emb.t()) / temperature
    # Contrastive loss (NT-Xent): each student row should match its own teacher row
    labels = torch.arange(student_emb.size(0), device=student_emb.device)
    loss = nn.CrossEntropyLoss()(sim_matrix, labels)
    return loss
| Model | Dimension | Latency (ms) | Recall@1 | Recall@10 | Cost per 1M tokens |
|---|---|---|---|---|---|
| text-embedding-3-large | 3072 | 250-500 | 92.5% | 98.1% | $0.13 |
| text-embedding-3-small | 1536 | 100-200 | 90.2% | 97.4% | $0.02 |
| Distilled MiniLM (384d) | 384 | 5-10 | 89.8% | 96.9% | $0.001 |
Data: Use same domain as production retrieval data. Temperature: 0.05-0.07 for embeddings. Batch size: 256-512 to maximize contrastive signal. Training steps: 50K-100K for high-quality distillation.
Reranker Distillation
Cross-Encoder to Bi-Encoder & Score Distillation
Cross-Encoder Reranker
- • Model: DeBERTa-large, 434M params
- • Input: Concatenate [Q, SEP, D]
- • Output: Relevance score (0-1)
- • Speed: 200-500ms per doc
- • NDCG@10: 0.625
Bi-Encoder Reranker
- • Model: MiniLM, 22M params
- • Input: Embed Q & D separately
- • Output: Dot product similarity
- • Speed: 5-10ms per doc
- • NDCG@10: 0.615 (98% quality)
Distillation Strategy: Score Margin Loss
import torch
import torch.nn.functional as F

# Cross-encoder teacher scores one relevant and one irrelevant doc per query
def margin_hinge_loss(student_scores, teacher_scores, margin=0.5):
    # scores: [batch_size, num_docs]; column 0 = positive doc, column 1 = hardest negative
    pos_score_student = student_scores[:, 0]
    neg_score_student = student_scores[:, 1]
    pos_score_teacher = teacher_scores[:, 0]
    neg_score_teacher = teacher_scores[:, 1]
    # Hinge on margins: penalize the student when its positive-negative margin
    # falls more than `margin` below the teacher's margin
    student_margin = pos_score_student - neg_score_student
    teacher_margin = pos_score_teacher - neg_score_teacher
    loss = F.relu(teacher_margin - student_margin + margin)
    return loss.mean()
# Alternative: KL divergence on ranking softmax
def listwise_kl_loss(student_scores, teacher_scores, temperature=4):
teacher_probs = F.softmax(teacher_scores / temperature, dim=-1)
student_log_probs = F.log_softmax(student_scores / temperature, dim=-1)
loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
return loss * (temperature ** 2)
ColBERT: Late Interaction Distillation
Hybrid approach: compute embeddings separately (like bi-encoder), then match at interaction layer (like cross-encoder). Much faster than full cross-encoder, more accurate than simple bi-encoder.
score(Q, D) = Σ_i max_j sim(Q_i, D_j)   # MaxSim: best doc token per query token, summed over query tokens
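The MaxSim score in PyTorch, as a small sketch (assumes token embeddings are already L2-normalized):
import torch

def colbert_maxsim(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    # q_emb: [num_query_tokens, dim], d_emb: [num_doc_tokens, dim]
    sim = q_emb @ d_emb.T                # token-level cosine similarities
    return sim.max(dim=1).values.sum()   # best doc token per query token, summed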
| Model | Architecture | Latency/doc | NDCG@10 | MRR@10 | Cost (1M queries) |
|---|---|---|---|---|---|
| DeBERTa-large (Teacher) | Cross-encoder | 300ms | 0.625 | 0.758 | $150 |
| ColBERT (Distilled) | Late interaction | 15ms | 0.618 | 0.752 | $8 |
| MiniLM Bi-encoder | Bi-encoder | 2ms | 0.615 | 0.745 | $2 |
Reranker distillation is sensitive to margin parameter. Start with margin=0.3, increase gradually. Use in-batch negatives to stabilize training. Monitor margin distribution across iterations.
Generator Distillation for Production
Distilling GPT-4 & Claude into Smaller Models
Teacher: GPT-4/Claude
- • Model: 1T+ params (estimated)
- • Quality: 95%+ factually correct
- • Cost: $30-60 per 1M tokens
- • Latency: 2-5 sec
- • Hallucination: ~5%
Student: Llama-3-8B
- • Model: 8B params
- • Quality: 88-92% (after distillation)
- • Cost: $0.50 per 1M tokens
- • Latency: 200-400ms
- • Hallucination: ~12% (with context)
Output Distillation (Synthetic Data)
# Step 1: Generate synthetic training data with the teacher
# (teacher_model.generate is a placeholder for your API client, e.g. a GPT-4 call)
def generate_distillation_data(queries, documents, teacher_model, num_examples=5000):
    data = []
    for query, doc in zip(queries, documents):
        # Teacher generates a grounded response from the retrieved context
        prompt = f"""Answer the question based on the context.
Context: {doc}
Question: {query}
Answer:"""
        response = teacher_model.generate(prompt)  # e.g., GPT-4
        data.append({
            "query": query,
            "context": doc,
            "response": response,
        })
    return data
# Step 2: Fine-tune student on the synthetic data
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
training_args = TrainingArguments(
    output_dir="./distilled-llama",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    gradient_accumulation_steps=4,
)
trainer = Trainer(
    model=student_model,
    args=training_args,
    train_dataset=dataset,  # tokenized dataset built from the synthetic pairs
)
trainer.train()
Chain-of-Thought Distillation
Distill not just final answer but reasoning steps. Teacher outputs step-by-step reasoning, student learns to generate intermediate thoughts. Improves factuality and makes errors more traceable.
[Thought 1] → [Thought 2] → ... → [Answer]
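An illustrative training record for CoT distillation (field names and content are invented for illustration):
cot_example = {
    "query": "Which plan includes SSO?",
    "context": "...retrieved docs...",
    "target": (
        "[Thought 1] The question asks about SSO availability.\n"
        "[Thought 2] Doc 2 states SSO is included in the Enterprise plan.\n"
        "[Answer] SSO is available on the Enterprise plan."
    ),
}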
| Model | Params | Latency | RAGAS Faithfulness | RAGAS Relevance | Cost per 1K tokens |
|---|---|---|---|---|---|
| GPT-4 (Teacher) | 1T+ | 3-5s | 0.94 | 0.92 | $0.06 |
| Llama-3-70B | 70B | 1-2s | 0.88 | 0.87 | $0.01 |
| Llama-3-8B (Distilled) | 8B | 200-400ms | 0.89 | 0.86 | $0.0005 |
Context: Always include retrieved documents in prompt. This grounds student. Diversity: Mix easy & hard examples. Temperature: 0.3 for deterministic outputs during distillation. Validation: Use RAGAS metrics to track faithfulness.
Query Transformer Distillation
Fast Query Expansion & Rewrite
Use Case: Query Expansion
Teacher (GPT-4) generates 3-5 reformulations of user query to improve retrieval coverage. Student (T5-small) learns to do same in 5ms.
"basketball" → ["basketball game", "NBA", "court sport", ...]
Use Case: Query Rewrite
Teacher rewrites conversational queries to standalone form. Improves multi-turn retrieval.
"What about alternatives?" → "What are alternatives to X?"
Training Recipe
# Generate query expansion training data
def generate_query_expansion_data(original_queries, teacher_model):
    training_pairs = []
    for query in original_queries:
        prompt = f"""Generate 3 alternative search queries for: {query}
Format: query1 ||| query2 ||| query3"""
        expansions = teacher_model.generate(prompt)  # placeholder for your API client
        training_pairs.append({
            "input": query,
            "target": expansions,
        })
    return training_pairs

# Fine-tune T5-small on the seq2seq task
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")

def compute_loss(batch):
    inputs = tokenizer(batch["input"], max_length=128, padding="max_length",
                       truncation=True, return_tensors="pt")
    labels = tokenizer(batch["target"], max_length=256, padding="max_length",
                       truncation=True, return_tensors="pt")
    outputs = model(input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    labels=labels.input_ids)
    return outputs.loss
| Model | Params | Latency | Expansion Quality | Recall Lift |
|---|---|---|---|---|
| GPT-4 (Teacher) | 1T+ | 2-3s | 95% | +18% |
| T5-small (Distilled) | 60M | 5-10ms | 92% | +17% |
Data: Use production queries + relevance judgments. Target diversity: Generate 3-5 expansions per query. Evaluation: Measure recall lift, not exact match. Integration: Combine expansions via union retrieval.
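A sketch of the union-retrieval integration mentioned above; retrieve() is an assumed helper that returns docs with an "id" field:
def union_retrieve(query, expansions, retrieve, k=10):
    # Retrieve for the original query plus each expansion, de-duplicating by doc id
    seen, merged = set(), []
    for q in [query, *expansions]:
        for doc in retrieve(q, k=k):
            if doc["id"] not in seen:
                seen.add(doc["id"])
                merged.append(doc)
    return merged[: 2 * k]  # cap the merged candidate pool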
End-to-End Training Recipes
Complete Pipelines with Data Prep & Hypertuning
Data Preparation
Step 1: Generate Teacher Labels
Use teacher model to label training data. For embeddings: pair queries with hard negatives. For rerankers: score documents. For generators: generate answers with context.
Step 2: Filter & Balance
Remove ambiguous examples (low teacher confidence). Balance difficulty. For reranker: ensure positive scores much higher than negative.
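An illustrative filtering pass for Step 2; the field names (teacher_confidence, pos_score, neg_score) are assumptions about your label format:
def filter_examples(examples, min_confidence=0.7, min_margin=0.2):
    kept = []
    for ex in examples:
        if ex["teacher_confidence"] < min_confidence:
            continue  # teacher itself is unsure: drop the ambiguous example
        if ex["pos_score"] - ex["neg_score"] < min_margin:
            continue  # positive not clearly better than the negative
        kept.append(ex)
    return kept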
# Complete training pipeline with HuggingFace Trainer
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification
import torch
# Training arguments
training_args = TrainingArguments(
output_dir="./distilled-model",
num_train_epochs=3,
learning_rate=2e-5,
per_device_train_batch_size=32,
per_device_eval_batch_size=64,
warmup_steps=500,
weight_decay=0.01,
logging_steps=100,
eval_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
gradient_accumulation_steps=4,
)
# Load student model
model = AutoModelForSequenceClassification.from_pretrained(
"microsoft/MiniLM-L6-H384-uncased",
num_labels=1
)
# Custom distillation trainer
class DistillationTrainer(Trainer):
    def __init__(self, teacher_model=None, temperature=4, alpha=0.7, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Student forward
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        # Teacher forward (no gradients)
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits
        # KL distillation term (meaningful for multi-class heads; with num_labels=1
        # the softmax is degenerate and the MSE term below does the work)
        soft_targets = torch.nn.functional.softmax(teacher_logits / self.temperature, dim=-1)
        student_log_soft = torch.nn.functional.log_softmax(student_logits / self.temperature, dim=-1)
        kl_loss = torch.nn.functional.kl_div(student_log_soft, soft_targets, reduction='batchmean')
        # Task-specific loss (MSE for scoring)
        task_loss = torch.nn.functional.mse_loss(student_logits, teacher_logits)
        # Combined loss
        loss = self.alpha * (self.temperature ** 2) * kl_loss + (1 - self.alpha) * task_loss
        return (loss, student_outputs) if return_outputs else loss

# Initialize trainer (teacher is your loaded teacher model)
trainer = DistillationTrainer(
    model=model,
    teacher_model=teacher,
    temperature=4,
    alpha=0.7,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# Train
trainer.train()
Hardware & Optimization
| Setup | GPU Memory | Batch Size | Training Time (100K examples) | Cost |
|---|---|---|---|---|
| Single A100 (40GB) | 40GB | 32 | 2-4 hours | $8-16 |
| 4× A100 (40GB) DDP | 160GB | 256 | 30-45 min | $20-30 |
| Single A40 (48GB) with DeepSpeed | 48GB | 64 | 1.5-2 hours | $5-8 |
Learning rate: 1e-5 to 5e-5. Lower for larger models. Temperature: 3-20. Higher = more soft targets. Alpha: 0.5-0.9. Higher = more distillation vs task. Warmup: 5-10% of steps. Weight decay: 0.01-0.1.
Evaluation & Benchmarking
Measuring Distillation Quality by Component
Embedding Model Evaluation
MTEB Benchmark Metrics
- • Recall@K: Fraction of relevant items in top-K
- • NDCG@K: Normalized discounted cumulative gain
- • MAP: Mean average precision across queries
- • MRR: Mean reciprocal rank
# Evaluate embedding model with MTEB
from mteb import MTEB

tasks = ["STS12", "STS13", "STS14", "STS15", "STS16",
         "STSBenchmark", "SummEval"]
evaluation = MTEB(tasks=tasks, task_langs=["en"])
results = evaluation.run(model, output_folder="results")

# Retrieval benchmark
retrieval_tasks = ["TREC-COVID", "DBpedia", "SCIFACT"]
results = MTEB(tasks=retrieval_tasks).run(model, output_folder="results")

# run() returns per-task results and writes JSON files to output_folder;
# inspect those files (or each result's scores) for metrics such as recall@1 / ndcg@10
Reranker Evaluation
| Metric | Definition | Target Threshold |
|---|---|---|
| NDCG@10 | Discounted gain at position 10 | >95% of teacher |
| MRR@10 | Reciprocal rank of first relevant | >95% of teacher |
| MAP@1000 | Mean average precision across ranking | >93% of teacher |
Generator Evaluation (RAGAS)
# Evaluate generation quality with RAGAS
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import faithfulness, answer_relevancy, context_recall

# Prepare evaluation dataset (ragas expects a HuggingFace Dataset)
rag_results = Dataset.from_dict({
    "question": [...],
    "answer": [...],        # Generated by student model
    "contexts": [...],      # Retrieved documents (list of lists of strings)
    "ground_truth": [...],  # Reference answers
})

# Compute metrics
score = evaluate(
    rag_results,
    metrics=[faithfulness, answer_relevancy, context_recall]
)
print(f"Faithfulness: {score['faithfulness']:.3f}")
print(f"Answer Relevancy: {score['answer_relevancy']:.3f}")
print(f"Context Recall: {score['context_recall']:.3f}")
A/B Testing in Production
Canary deployment: Route 5-10% of traffic to distilled model. Monitor latency, cost, quality metrics. If stable, increase to 50%, then 100%.
Metrics to track: Latency P50/P95/P99, hallucination rate (human review sample), user satisfaction (thumbs up/down), business KPIs (conversions, retention).
Rollback trigger: >5% regression in any critical metric. Keep teacher model running in parallel for 48 hours.
Set alerts for >2% drop in key metrics. Use sequential probability ratio tests (SPRT) for early stopping. Monitor distribution shift—if data changes significantly, re-distill.
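The rollback trigger as a minimal sketch, assuming teacher-baseline and student metrics arrive as dicts with matching keys:
def should_rollback(teacher_metrics, student_metrics, max_regression=0.05):
    # Fire if any critical metric (where higher is better) regresses
    # by more than 5% against the teacher baseline
    for name, teacher_value in teacher_metrics.items():
        student_value = student_metrics[name]
        if (teacher_value - student_value) / teacher_value > max_regression:
            return True
    return False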
LoRA & Parameter-Efficient Distillation
Combine PEFT Methods with Knowledge Transfer
LoRA (Low-Rank Adaptation)
Add low-rank matrices to attention layers. Train only 0.1-1% of parameters. Compatible with distillation.
W = W_frozen + (α/r)(B × A)
QLoRA (Quantized LoRA)
Quantize base model to 4-bit. Add LoRA on top. Fit 70B model in 48GB GPU.
70B model → ~40GB VRAM (4-bit weights)
LoRA + Distillation Training
import torch
import torch.nn.functional as F
from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# QLoRA config: 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load student model with quantization
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantization_config=bnb_config,
    device_map="auto"
)

# LoRA config: target attention weights
lora_config = LoraConfig(
    r=8,              # LoRA rank
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # ~0.1% of 8B

# Distillation loss with LoRA
def lora_distillation_loss(student_logits, teacher_logits, temperature=4):
    teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
    student_soft = F.log_softmax(student_logits / temperature, dim=-1)
    kl_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
    return kl_loss * (temperature ** 2)

# Train with minimal memory
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
    optimizer.zero_grad()
    student_out = model(**batch)
    with torch.no_grad():                 # teacher is frozen
        teacher_out = teacher_model(**batch)
    loss = lora_distillation_loss(student_out.logits, teacher_out.logits)
    loss.backward()
    optimizer.step()
| Method | GPU Memory | Trainable Params | Distillation Quality | Training Speed |
|---|---|---|---|---|
| Full Fine-tune | 80GB | 100% | Best | 1x |
| LoRA (r=8) | 48GB | 0.2% | 98% | 0.95x |
| QLoRA (r=8, 4-bit) | 16GB | 0.2% | 97% | 0.8x |
Rank (r): 4-16 typical. Higher rank = more capacity but slower. Target modules: Attention weights (q_proj, v_proj) usually sufficient. Initialization: A is initialized randomly, B to zero, so training starts from the unmodified base model. Merging: After training, merge LoRA weights into the base model for inference.
Quantization + Distillation Synergy
Maximum Compression via Combined Techniques
Quantization-Aware Distillation
Simulate quantization during training. Student learns to be robust to quantization noise. Better quality than post-hoc quantization.
Execution Order
Option A: 1. Distill → 2. Quantize (post-hoc)
Option B: Distill with quantization simulated during training (better quality)
Quantization Techniques
INT8
8-bit integer weights. 4x compression. Minimal quality loss. Easy integration.
INT4
4-bit quantization. 8x compression. Requires distillation for quality.
GPTQ/AWQ
Weight-only quantization. Fast inference. Good for LLMs.
# Quantization-aware distillation with fake quantization
import torch
import torch.nn.functional as F
import torch.quantization as quant

# Prepare model for QAT (Quantization Aware Training)
model.train()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')
quant.prepare_qat(model, inplace=True)

# Training loop with fake quantization
def qat_distillation_step(student, teacher, batch, temperature=4):
    # Student forward (includes fake-quant ops)
    student_out = student(**batch)
    # Teacher forward (full precision, no gradients)
    with torch.no_grad():
        teacher_out = teacher(**batch)
    # Distillation loss
    teacher_soft = F.softmax(teacher_out.logits / temperature, dim=-1)
    student_log = F.log_softmax(student_out.logits / temperature, dim=-1)
    loss = F.kl_div(student_log, teacher_soft, reduction='batchmean') * (temperature ** 2)
    return loss

# Post-training: convert to true INT8
model.eval()
quant.convert(model, inplace=True)
# Using GPTQ for 4-bit LLM quantization
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

gptq_config = BaseQuantizeConfig(
    bits=4,          # 4-bit quantization
    group_size=128,
    desc_act=False,
)
model = AutoGPTQForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantize_config=gptq_config
)
# Then model.quantize(examples) with a small calibration set, and model.save_quantized(...)
| Technique | Bits | Compression | Quality Loss | Inference Speed | Combined with Distill |
|---|---|---|---|---|---|
| Original FP32 | 32 | 1x | — | 1x | — |
| INT8 Only | 8 | 4x | 1-2% | 2x | ↓ Loss to 0.5% |
| INT4 Only | 4 | 8x | 5-10% | 4x | ↓ Loss to 2-3% |
| Distilled + GPTQ 4-bit | 4 | 80x | 3-5% | 20x | Combined |
Step 1: Distill teacher to student (90-95% quality). Step 2: Quantize student with QAT (distill with fake quant active). Step 3: Convert to INT4/GPTQ. Result: 80x compression with 93-97% quality.
Production Deployment
Serving Distilled Models at Scale
Inference Frameworks
vLLM
High-throughput LLM inference. Paged attention, continuous batching. 10-50x faster than vanilla HuggingFace.
TensorRT
NVIDIA's inference optimizer. Optimized kernels, automatic optimization. Best for NVIDIA GPUs.
ONNX Runtime
Cross-platform, cross-hardware. CPU/GPU/mobile. Good for edge deployment.
Triton Inference Server
Multi-model serving. Dynamic batching, ensemble pipelines. For production LLM endpoints.
A/B Testing & Canary Deployment
# A/B testing setup with vLLM
from vllm import LLM, SamplingParams
import random

# Load teacher and student models (model names are placeholders; in practice run
# each in its own process/GPU — two LLMs at 0.8 memory utilization won't share one GPU)
teacher_model = LLM(model="gpt2", gpu_memory_utilization=0.8)
student_model = LLM(model="distilled-gpt2", gpu_memory_utilization=0.8)
sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)

# Canary deployment: 10% student, 90% teacher
def generate_with_ab_test(prompt, canary_rate=0.1):
    if random.random() < canary_rate:
        model, variant = student_model, "student"   # distilled
    else:
        model, variant = teacher_model, "teacher"   # original
    outputs = model.generate([prompt], sampling_params=sampling_params)
    # Log for analysis (log_metric is your metrics client; wrap the call above
    # with a timer to record latency per variant)
    log_metric({
        "prompt": prompt,
        "variant": variant,
        "output": outputs[0].outputs[0].text,
    })
    return outputs[0].outputs[0].text
# Increase canary rate over time
canary_schedule = {
"hour_0": 0.05, # 5%
"hour_2": 0.10, # 10%
"hour_6": 0.25, # 25%
"hour_12": 0.50, # 50%
"hour_24": 1.00, # 100% full cutover
}
Monitoring & Alerting
| Metric | Normal | Warning | Critical |
|---|---|---|---|
| P50 Latency | <50ms | 50-100ms | >100ms |
| P99 Latency | <200ms | 200-500ms | >500ms |
| Quality (NDCG) | >0.60 | 0.57-0.60 | <0.57 |
| Error Rate | <0.1% | 0.1-0.5% | >0.5% |
Automatic: If error rate >1% or quality drops >5%, auto-rollback to teacher. Manual: Keep teacher running in parallel for 48 hours. Hotfix: Disable distilled model, re-train if data distribution changed significantly.
Cost Analysis & ROI
When Distillation Pays Off
Per-Component Cost Breakdown
| Component | Teacher Model | Distilled Model | Cost per 1M Requests | Savings (1M req/day) |
|---|---|---|---|---|
| Embedding (1M docs) | text-embedding-3-large: $0.13 | MiniLM-384d: $0.001 | $130 → $1 | $129/day = $47K/yr |
| Reranker (10K docs/req) | DeBERTa-large: $150 | ColBERT: $8 | $150 → $8 | $142/day = $52K/yr |
| Generator (512 tokens) | GPT-4: $30 | Llama-8B: $0.50 | $30 → $0.50 | $29.50/day = $10.8K/yr |
| TOTAL (Full Pipeline) | $180 per 1M | $9.50 per 1M | $180 → $9.50 | $170.50/day = $62.2K/yr |
Training Investment vs. Savings
One-Time Training Cost
- • 4× A100 GPU: 24 hours
- • GPU cost: $1,500
- • Labeling/data prep: $2,000
- • Validation/testing: $1,000
- • Total: $4,500
Break-Even Analysis
- • Savings per day: $170
- • Break-even: $4,500 / $170 = 26 days
- • Monthly savings: ~$5,100 (≈1.1× the training cost)
- • Year 1 ROI: ~13.8x ($62.2K saved / $4.5K invested)
- • At 1M requests/day: clearly worth it
Distillation Makes Sense When:
- ✓ High volume: >100K requests/day per component
- ✓ Cost-sensitive: Inference cost is significant (>10% of budget)
- ✓ Latency-critical: Need <100ms P99 latency
- ✓ Stable workload: Data distribution doesn't change rapidly
- ✓ Quality tolerance: Can accept 2-5% quality drop
Volume-Based ROI Calculator
Monthly Cost (No Distill) = daily_cost × 30
Monthly Cost (Distilled) = distilled_daily_cost × 30
Monthly Savings = Cost(No Distill) - Cost(Distilled)
Payback Period (months) = Training Cost / Monthly Savings
Example: 1M requests/day for production LLM systems
• Current cost: $180/day → $5,400/month
• Distilled cost: $9.50/day → $285/month
• Savings: $5,115/month
• Training cost: $4,500 (one-time)
• Payback: 0.88 months (27 days)
• Year 1 ROI: ~13.6x ($61.4K saved / $4.5K one-time)
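The same arithmetic as a small helper, using the example's numbers:
def distillation_roi(daily_cost, distilled_daily_cost, training_cost, days_per_month=30):
    monthly_savings = (daily_cost - distilled_daily_cost) * days_per_month
    payback_months = training_cost / monthly_savings
    year1_roi = (monthly_savings * 12) / training_cost
    return monthly_savings, payback_months, year1_roi

# 1M requests/day: $180/day teacher, $9.50/day distilled, $4,500 one-time training
print(distillation_roi(180, 9.50, 4500))  # (5115.0, ~0.88 months, ~13.6x)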
Cloud GPU Pricing Reference (March 2026)
| GPU | VRAM | On-Demand $/hr | Spot $/hr | Use for Distillation |
|---|---|---|---|---|
| A100 80GB | 80GB | $2.00-3.00 | $1.00-1.80 | Embedding/reranker distillation (single GPU) |
| 4× A100 80GB | 320GB | $8.00-12.00 | $4.00-7.00 | Generator distillation (8B student) |
| H100 SXM | 80GB | $2.40-4.00 | $1.50-2.50 | Fast teacher inference + student training |
Distill vs Fine-Tune vs API — Total Cost of Ownership
| Approach | Setup Cost | Monthly (1M req/day) | Annual | Latency |
|---|---|---|---|---|
| API (GPT-4o) | $0 | $7,500 | $90,000 | 500-2000ms |
| API (GPT-4o-mini) | $0 | $1,125 | $13,500 | 200-800ms |
| Fine-tuned 8B (self-host) | $2-5K | $5,760 | $71,620 | 50-200ms |
| Distilled 3B (self-host) | $4.5K | $1,440 | $21,780 | 10-50ms |
| Distilled 3B + INT4 | $5K | $576 | $11,912 | 5-20ms |
Month 1: Distill components (est. $4.5K). Months 2-12: Save $5.1K/month vs API. Year 1 total savings: ~$51.6K after investment. Year 2: Pure savings ($61K+). Distilled + quantized: 87% cheaper than GPT-4o API and ~6x cheaper annual TCO than a self-hosted fine-tuned 8B.
Real-World Case Studies
Production Success Stories
Case 1: E-Commerce Document Retrieval
Challenge
E-commerce platform with 10M product descriptions. text-embedding-3-large too slow (300ms/query). Cost: $40K/month.
Solution
Distill to MiniLM-384d using contrastive loss on 100K product pairs.
Results:
- • Latency: 300ms → 8ms (37x faster)
- • Recall@10: 94.2% → 92.8% (98.5% quality)
- • Cost: $40K/mo → $1.2K/mo (97% savings)
- • Training server: Vast.ai 1× A100 40GB spot — $1.10/hr (~$53 for 2 days)
- • Serving: RunPod 1× RTX 4090 — $0.39/hr ($280/month)
- • Payback: 9 days | Year 1 ROI: 405x
Case 2: SaaS Question-Answering System
Challenge
Support chatbot using GPT-4 with retrieval-augmented generation. High latency (3s), high cost ($100K/month). Users frustrated with wait times.
Solution
Distill GPT-4 to Llama-3-8B using 50K QA pairs + output distillation + LoRA.
Results:
- • Latency: 3000ms → 350ms (8.6x faster)
- • RAGAS Faithfulness: 0.94 → 0.89 (94.7% quality)
- • Cost: $100K/mo → $2.5K/mo (97.5% savings)
- • Training server: RunPod 4× A100 80GB — $10.28/hr (~$494 for 2 days)
- • Alternative: Lambda Labs 4× A100 — $12.00/hr or Vast.ai spot — $8.50/hr
- • Serving: RunPod 1× A10G 24GB — $0.50/hr ($360/month)
- • Payback: 4 days | Year 1 ROI: 315x
Case 3: Search Ranking with Reranker
Challenge
Cross-encoder reranker (DeBERTa) bottleneck. Must rerank top-100 per query. 200ms per request. P99 latency: 500ms.
Solution
Distill to ColBERT (late interaction). Score distillation + margin loss.
Results:
- • Latency: 200ms → 12ms (16.7x faster)
- • NDCG@10: 0.625 → 0.618 (98.8% quality)
- • P99 latency: 500ms → 35ms (14x improvement)
- • Hardware: 4 GPUs → 1 GPU (75% cost)
- • Training server: Vast.ai 1× A100 80GB spot — $1.80/hr (~$43 for 1 day)
- • Alternative: RunPod 1× A100 80GB — $2.57/hr
- • Serving: RunPod 1× RTX 4090 — $0.39/hr (handles 10K+ qps)
- • Payback: 5 days | Year 1 ROI: 200x
Embedding distillation (small models): Vast.ai 1× A100 40GB spot — $1.10/hr (cheapest GPU cloud).
LLM fine-tune/distill (7-8B): RunPod 4× A100 80GB — $10.28/hr or Vast.ai spot 4× A100 — $8.50/hr.
Reranker training: Vast.ai 1× A100 80GB spot — $1.80/hr (single-GPU sufficient).
Production serving: RunPod 1× RTX 4090 — $0.39/hr (best price-performance for inference).
1. High volume: All cases had 1M+ requests/day.
2. Clear bottleneck: One slow component identified.
3. High-quality teacher: Started with a strong model (GPT-4, DeBERTa).
4. Fast payback: Most broke even in <2 weeks.
5. Conservative rollout: Canary to 100% over 48 hours.
Use Case 1: RAG Chatbot — Full Distillation Walkthrough
Distill GPT-4o → Llama-3.1-8B-Instruct on RunPod (4× A100 80GB)
You run a SaaS product with 50K internal docs (Confluence, Notion, PDFs). Your RAG chatbot uses GPT-4o at $120K/month. You want to distill it to a self-hosted Llama-3.1-8B that handles 90%+ of queries at 1/50th the cost, with <500ms latency.
Step 1: Provision the Training Server
Server: RunPod — 4× A100 80GB SXM
| Component | Specification | Why This Choice |
|---|---|---|
| Provider | RunPod (on-demand pod) | $10.28/hr for 4× A100 — cheapest for short runs |
| GPU | 4× NVIDIA A100 80GB SXM4 | 320GB total VRAM — fits Llama 8B in full precision + large batches |
| CPU | 32 vCPUs (AMD EPYC) | Data preprocessing parallelism |
| RAM | 256GB DDR4 | Hold full dataset in memory |
| Storage | 500GB NVMe SSD | Model weights + datasets + checkpoints |
| Network | 10 Gbps | Fast model download from HuggingFace |
| Image | runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel | CUDA 12.1 for Flash Attention 2 |
| Est. Cost | ~$250 total (24h training) | One-time cost, saves $120K/month |
# Alternative servers (if RunPod unavailable):
# Lambda Labs: 4× A100 80GB — $12.00/hr (on-demand)
# Vast.ai: 4× A100 80GB — $8.50-11/hr (spot, can be preempted)
# AWS p4d.24xlarge: 8× A100 40GB — $32.77/hr (overkill but always available)
# GCP a2-ultragpu-4g: 4× A100 80GB — $29.39/hr (expensive, use spot)
# CoreWeave: 4× A100 80GB — $9.36/hr (good for long runs, reserved)
Step 2: Environment Setup (SSH into server)
# SSH into your RunPod instance
ssh root@{your-runpod-ip} -p 22 -i ~/.ssh/runpod_key
# Verify GPUs are visible
nvidia-smi
# Should show: 4× A100 80GB, CUDA 12.1, Driver 535+
# Install dependencies
pip install torch==2.2.0 transformers==4.44.0 datasets==2.21.0 \
accelerate==0.33.0 peft==0.12.0 trl==0.9.6 \
bitsandbytes==0.43.0 flash-attn==2.6.3 \
wandb==0.17.0 vllm==0.5.5 sentencepiece protobuf
# Login to HuggingFace (for gated models like Llama)
huggingface-cli login --token hf_YOUR_TOKEN_HERE
# Login to Weights & Biases (training monitoring)
wandb login YOUR_WANDB_API_KEY
# Create project directory
mkdir -p /workspace/rag-distillation/{data,models,scripts,checkpoints}
cd /workspace/rag-distillation
Step 3: Generate Training Data from Teacher (GPT-4o)
# generate_training_data.py
# Run this BEFORE provisioning GPU server (use your local machine + API)
# Cost: ~$300-500 for 50K examples at GPT-4o pricing
import json, os, asyncio
from openai import AsyncOpenAI
client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
SYSTEM_PROMPT = """You are a helpful assistant for [YourProduct].
Answer the user's question using ONLY the provided context.
If the context doesn't contain the answer, say "I don't have enough
information to answer that." Always cite which document you used."""
async def generate_example(query, retrieved_chunks):
context = "\n\n".join([
f"[Doc {i+1}: {c['title']}]\n{c['text']}"
for i, c in enumerate(retrieved_chunks[:5])
])
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
]
response = await client.chat.completions.create(
model="gpt-4o",
messages=messages,
temperature=0.3,
max_tokens=1024
)
return {
"system": SYSTEM_PROMPT,
"query": query,
"context": context,
"response": response.choices[0].message.content,
"model": "gpt-4o"
}
# Process queries from your production logs
async def main():
queries = json.load(open("production_queries.json")) # 50K queries
chunks_db = json.load(open("retrieved_chunks.json"))
semaphore = asyncio.Semaphore(50) # 50 concurrent requests
async def bounded_generate(q):
async with semaphore:
return await generate_example(q["text"], chunks_db[q["id"]])
tasks = [bounded_generate(q) for q in queries]
results = await asyncio.gather(*tasks)
with open("data/training_data_gpt4o.jsonl", "w") as f:
for r in results:
f.write(json.dumps(r) + "\n")
print(f"Generated {len(results)} training examples")
asyncio.run(main())
Step 4: Format Data for SFT Training
# prepare_dataset.py — Convert to Llama 3.1 chat format
import json
from datasets import Dataset, DatasetDict
def format_for_llama31(example):
"""Convert to Llama 3.1 chat template format"""
conversation = [
{"role": "system", "content": example["system"]},
{"role": "user", "content": f"Context:\n{example['context']}\n\nQuestion: {example['query']}"},
{"role": "assistant", "content": example["response"]}
]
return {"conversations": conversation}
# Load and split
data = [json.loads(line) for line in open("data/training_data_gpt4o.jsonl")]
dataset = Dataset.from_list(data).map(format_for_llama31)
split = dataset.train_test_split(test_size=0.1, seed=42)
split = DatasetDict({
"train": split["train"],
"validation": split["test"]
})
split.save_to_disk("data/rag_chatbot_dataset")
print(f"Train: {len(split['train'])}, Val: {len(split['validation'])}")
Step 5: Train with QLoRA (the actual training script)
# train_rag_chatbot.py — Main training script
import torch
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, TrainingArguments
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import load_from_disk
# === CONFIG ===
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
OUTPUT_DIR = "checkpoints/rag-chatbot-llama31-8b"
DATASET_PATH = "data/rag_chatbot_dataset"
# === Load tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# === 4-bit quantization config (QLoRA) ===
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True # nested quantization
)
# === Load base model ===
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
quantization_config=bnb_config,
device_map="auto", # spread across 4× A100
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2"
)
model = prepare_model_for_kbit_training(model)
# === LoRA config ===
lora_config = LoraConfig(
r=64, # rank (higher = more capacity)
lora_alpha=128, # scaling factor
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# → trainable params: 167M/8.03B (2.08% of model)
# === Dataset ===
dataset = load_from_disk(DATASET_PATH)
def formatting_func(example):
return tokenizer.apply_chat_template(
example["conversations"],
tokenize=False,
add_generation_prompt=False
)
# === Training arguments ===
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=3,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=4, # effective batch = 4×4×4 = 64
learning_rate=2e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.03,
weight_decay=0.01,
bf16=True,
logging_steps=10,
eval_strategy="steps",
eval_steps=200,
save_strategy="steps",
save_steps=200,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
report_to="wandb",
run_name="rag-chatbot-llama31-8b-qlora",
gradient_checkpointing=True,
max_grad_norm=0.3,
dataloader_num_workers=4
)
# === Trainer ===
trainer = SFTTrainer(
model=model,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
tokenizer=tokenizer,
args=training_args,
formatting_func=formatting_func,
max_seq_length=4096,
packing=True # pack short examples together
)
# === Train ===
trainer.train()
trainer.save_model(f"{OUTPUT_DIR}/final")
tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
Step 6: Launch Training
# Launch on 4× A100 with accelerate
# Note: device_map="auto" shards the model across the 4 GPUs in one process;
# if you switch to multi-process DDP, pin device_map to each local rank instead
accelerate launch --num_processes 4 --mixed_precision bf16 \
  scripts/train_rag_chatbot.py
# Expected output:
# Epoch 1/3: loss=1.42 → 0.89 (45K examples, ~4h)
# Epoch 2/3: loss=0.89 → 0.71 (~4h)
# Epoch 3/3: loss=0.71 → 0.64 (~4h)
# Total time: ~12-14 hours on 4× A100
# Total cost: ~$130 on RunPod ($10.28/hr × 13h)
Step 7: Merge LoRA + Quantize + Deploy
# merge_and_export.py
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load base + LoRA, merge weights
base = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "checkpoints/rag-chatbot-llama31-8b/final")
merged = model.merge_and_unload()
merged.save_pretrained("models/rag-chatbot-merged")
# Quantize to GPTQ 4-bit for production (saves 75% VRAM)
# pip install auto-gptq
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
quantized = AutoGPTQForCausalLM.from_pretrained(
"models/rag-chatbot-merged", quantize_config
)
quantized.quantize(calibration_dataset)  # ~30 min; a few hundred tokenized domain examples
quantized.save_quantized("models/rag-chatbot-gptq-4bit")
# Deploy with vLLM on a single A10G (24GB, $0.75/hr on RunPod)
# vllm serve models/rag-chatbot-gptq-4bit \
# --host 0.0.0.0 --port 8000 \
# --max-model-len 4096 \
# --gpu-memory-utilization 0.90 \
# --quantization gptq
Training cost: $130 (RunPod) + $400 (GPT-4o data generation) = $530 total
Serving cost: 1× A10G on RunPod = $540/month (vs $120K/month GPT-4o API)
Quality: 91-94% of GPT-4o on domain-specific RAG tasks (measured by RAGAS faithfulness)
Latency: 180-350ms (vs 2-4s GPT-4o API)
Payback: <1 day
Use Case 2: Voice Agent — Distill for Real-Time Speech
Distill GPT-4o-mini → Phi-3.5-mini-instruct on Lambda Labs (1× A100 80GB)
You're building a voice agent (phone support, in-app voice assistant). The pipeline: Whisper STT → LLM reasoning → TTS output. The LLM must respond in <300ms to feel natural. GPT-4o-mini is 600-1200ms — too slow. You need a tiny model (<4B params) that runs on a single GPU with <200ms latency.
Architecture: Voice Agent Pipeline
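Caller audio → Whisper STT → distilled LLM (<300ms target) → TTS → audio out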
Server: Lambda Labs — 1× A100 80GB
| Component | Specification |
|---|---|
| Provider | Lambda Labs — 1× A100 80GB ($1.29/hr on-demand) |
| Training time | ~6 hours (small model, 30K examples) |
| Training cost | ~$8 GPU + ~$50 GPT-4o-mini data gen = $58 total |
| Serving GPU | 1× RTX 4090 24GB ($0.35/hr RunPod) — Phi-3.5-mini fits easily |
| Serving cost | $252/month (RTX 4090) — handles 50+ concurrent voice sessions |
Step 1: Generate Voice-Specific Training Data
# generate_voice_data.py — Optimized for voice: short, direct answers
import json, asyncio
from openai import AsyncOpenAI
client = AsyncOpenAI()
VOICE_SYSTEM = """You are a voice assistant for [CompanyName].
Rules for voice responses:
- Keep answers under 3 sentences (people are LISTENING, not reading)
- Use simple, spoken language (no markdown, no bullet points, no URLs)
- If you need to perform an action, output: ACTION: {action_name}({params})
- Confirm actions before executing: "I'll transfer you now, one moment"
- For complex questions, summarize and offer to send details via email"""
SCENARIOS = [
"check order status", "cancel subscription",
"billing question", "technical support",
"product recommendation", "appointment scheduling",
"complaint handling", "account update",
"transfer to human", "FAQ answers"
]
async def generate_conversation(scenario):
# Generate a multi-turn voice conversation
messages = [{"role": "system", "content": f"""Generate a realistic 4-6 turn phone
conversation for scenario: {scenario}. Format as JSON array of
{{"role": "user"/"assistant", "content": "..."}}. User messages should
sound like natural speech (not text). Assistant responses must be short
(1-3 sentences, voice-friendly)."""}]
    response = await client.chat.completions.create(
        model="gpt-4o-mini", messages=messages,
        temperature=0.8, max_tokens=1500,
        response_format={"type": "json_object"}  # keeps the output json.loads-able
    )
return json.loads(response.choices[0].message.content)
# Generate 30K conversations (3K per scenario)
# Cost: ~$50 with GPT-4o-mini
Step 2: Train Phi-3.5-mini with Full Fine-Tuning
# train_voice_agent.py — Full fine-tune (model is small enough)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
MODEL_ID = "microsoft/Phi-3.5-mini-instruct" # 3.8B params
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto"
)
training_args = TrainingArguments(
output_dir="checkpoints/voice-agent-phi35",
num_train_epochs=5, # more epochs for small dataset
per_device_train_batch_size=8,
gradient_accumulation_steps=2, # effective batch = 16
learning_rate=5e-5, # lower LR for full fine-tune
lr_scheduler_type="cosine",
warmup_ratio=0.05,
bf16=True,
gradient_checkpointing=True,
logging_steps=25,
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
report_to="wandb",
run_name="voice-agent-phi35-full-ft",
max_grad_norm=1.0
)
# train_dataset / val_dataset / format_phi35_chat are prepared as in Use Case 1
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    args=training_args,
    formatting_func=format_phi35_chat,
    max_seq_length=2048,  # voice conversations are short
    packing=True
)
trainer.train()
trainer.save_model("models/voice-agent-phi35-final")
# Training time: ~6 hours on 1× A100 80GB
# Peak VRAM: ~45GB (full fine-tune with gradient checkpointing)
Step 3: Deploy for Real-Time Voice
# Serve with vLLM on RTX 4090 (production inference server)
# RunPod: RTX 4090 24GB — $0.35/hr ($252/month)
# Quantize to AWQ 4-bit first for faster inference
pip install autoawq
python -c "
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model = AutoAWQForCausalLM.from_pretrained('models/voice-agent-phi35-final')
tokenizer = AutoTokenizer.from_pretrained('models/voice-agent-phi35-final')
model.quantize(tokenizer, quant_config={'zero_point': True, 'q_group_size': 128, 'w_bit': 4})
model.save_quantized('models/voice-agent-awq-4bit')
"
# Launch vLLM with streaming (critical for voice — first token fast)
vllm serve models/voice-agent-awq-4bit \
--host 0.0.0.0 --port 8000 \
--max-model-len 2048 \
--gpu-memory-utilization 0.85 \
--quantization awq \
--enable-prefix-caching \
--max-num-seqs 64 # 64 concurrent voice sessions
# Benchmark: measure time-to-first-token (TTFT)
# TTFT: ~40ms (vs 300-800ms GPT-4o-mini API)
# Full response (50 tokens): ~120ms total
# Throughput: 50+ concurrent voice sessions on 1× RTX 4090
TTFT: 40ms (vs 300-800ms API) — conversational feel achieved
Quality: 88% of GPT-4o-mini on voice-specific tasks (short answers, actions)
Cost: $252/month serving (vs ~$8K/month API at 100K calls/day)
Concurrent sessions: 50+ per single RTX 4090
Use Case 3: Customer Support Chatbot
Distill Claude Sonnet → Mistral-7B-Instruct on CoreWeave (2× A100 80GB)
E-commerce platform handling 200K customer support tickets/month. Currently using Claude Sonnet API at $45K/month. Need: multi-turn conversation, order lookup, return processing, FAQ. Must handle 500 concurrent chats with <1s response time.
Server & Cost Breakdown
| Phase | Server | GPU | Duration | Cost |
|---|---|---|---|---|
| Data generation | Local machine + API | None | ~6 hours | $800 (Claude API) |
| Training | CoreWeave 2× A100 80GB | 2× A100 SXM | ~18 hours | $168 ($9.36/hr) |
| Quantization | Same server | 1× A100 | ~1 hour | $5 |
| Serving (prod) | RunPod 2× L40S 48GB | 2× L40S | Monthly | $1,440/month |
| Total one-time training | | | | $973 |
| Monthly savings | | | | $43,560/month |
Step 1: Generate Multi-Turn Support Conversations
# generate_support_data.py
import anthropic, json
client = anthropic.Anthropic()
SUPPORT_SYSTEM = """You are a customer support agent for [EcommerceCo].
You can:
- Look up orders: TOOL_CALL: lookup_order(order_id)
- Process returns: TOOL_CALL: initiate_return(order_id, reason)
- Check inventory: TOOL_CALL: check_stock(product_id)
- Apply discount: TOOL_CALL: apply_coupon(order_id, code)
- Transfer to human: TOOL_CALL: escalate(reason)
Rules:
- Be empathetic but efficient
- Verify customer identity before account actions
- Offer alternatives before processing returns
- Escalate if customer is angry after 2 failed resolutions"""
# Generate 80K multi-turn conversations across 15 categories
CATEGORIES = {
"order_status": 12000, "returns": 10000,
"shipping_issues": 8000, "billing": 8000,
"product_questions": 7000, "account_issues": 6000,
"complaints": 6000, "promotions": 5000,
"size_exchange": 4000, "damaged_items": 4000,
"refunds": 3000, "loyalty_program": 3000,
"gift_cards": 2000, "international": 1000,
"escalation": 1000
}
def generate_conversation(category, count):
    response = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=2000,
        system=(f"Generate a realistic {category} support conversation with "
                "3-8 turns. Include tool calls where appropriate. Customer should have "
                "varying levels of frustration. Output as JSON array."),
        messages=[{"role": "user", "content": f"Generate conversation #{count}"}]
    )
    return json.loads(response.content[0].text)
Step 2: Train on CoreWeave
# CoreWeave setup: 2× A100 80GB SXM ($4.68/hr per GPU)
# Total: $9.36/hr × 18h = $168
# SSH into CoreWeave instance
ssh ubuntu@cw-a100-instance.coreweave.cloud
# Install env
pip install torch transformers datasets accelerate peft trl \
bitsandbytes flash-attn wandb
# Download model
huggingface-cli download mistralai/Mistral-7B-Instruct-v0.3
# Launch training (QLoRA — rank 128 for high capacity)
accelerate launch --num_processes 2 --mixed_precision bf16 \
train_support_chatbot.py \
--model_id mistralai/Mistral-7B-Instruct-v0.3 \
--dataset data/support_conversations \
--output_dir checkpoints/support-mistral-7b \
--lora_r 128 \
--lora_alpha 256 \
--epochs 3 \
--batch_size 4 \
--grad_accum 8 \
--lr 1e-4 \
--max_seq_length 4096 \
--warmup_ratio 0.03
# Expected: ~18 hours, final loss ~0.58
# Monitor: https://wandb.ai/your-team/support-chatbot
Step 3: Production Deployment with Tool Calling
# Deploy on RunPod: 2× L40S 48GB ($1.00/hr each)
# L40S has excellent int8 throughput for Mistral-7B
# Merge LoRA weights
python merge_lora.py \
--base mistralai/Mistral-7B-Instruct-v0.3 \
--lora checkpoints/support-mistral-7b/final \
--output models/support-chatbot-merged
# Quantize to AWQ 4-bit
python quantize_awq.py \
--model models/support-chatbot-merged \
--output models/support-chatbot-awq \
--bits 4 --group_size 128
# Serve with vLLM (supports tool/function calling)
vllm serve models/support-chatbot-awq \
--host 0.0.0.0 --port 8000 \
--tensor-parallel-size 2 \
--max-model-len 4096 \
--gpu-memory-utilization 0.90 \
--quantization awq \
--enable-auto-tool-choice \
--tool-call-parser mistral \
--max-num-seqs 256 # 256 concurrent chats
# Benchmark results:
# Throughput: 2,800 tokens/sec (handles 500+ concurrent chats)
# Latency p50: 180ms, p99: 450ms
# Tool call accuracy: 96.2% (tested on 5K tool-call examples)
Resolution rate: 78% automated (vs 82% with Claude Sonnet) — 95% quality retained
CSAT score: 4.1/5.0 (vs 4.3/5.0 with Claude) — customers barely notice the difference
Cost: $1,440/month (vs $45K/month) — 97% savings
Concurrent capacity: 500+ chats on 2× L40S
Use Case 4: Embedding Model Distillation for RAG
Distill text-embedding-3-large → all-MiniLM-L6-v2 on Vast.ai (1× A100 40GB)
Your RAG system embeds 2M documents + handles 500K queries/day. Using OpenAI's text-embedding-3-large API costs $18K/month and adds 50-100ms network latency per call. You want a self-hosted embedding model that's 10× faster and 90%+ as accurate on your domain.
Server & Cost
| Phase | Server | GPU | Duration | Cost |
|---|---|---|---|---|
| Teacher embedding generation | Local + OpenAI API | None | ~4 hours | $200 (API) |
| Training | Vast.ai 1× A100 40GB | 1× A100 40GB | ~8 hours | $20 (~$2.50/hr spot) |
| Serving (prod) | RunPod 1× RTX 4090 | 1× RTX 4090 24GB | Monthly | $252/month |
| Total one-time | | | | $220 |
| Monthly savings | | | | $17,748/month |
Step 1: Generate Teacher Embeddings
# generate_teacher_embeddings.py
import openai, json, numpy as np
from tqdm import tqdm
client = openai.OpenAI()
# Load your domain data: queries + documents
queries = json.load(open("data/production_queries.json")) # 100K queries
documents = json.load(open("data/document_chunks.json")) # 200K chunks
# Generate teacher embeddings in batches
def embed_batch(texts, model="text-embedding-3-large"):
response = client.embeddings.create(input=texts, model=model)
return [e.embedding for e in response.data]
# Embed all queries and docs
query_embeddings = []
for i in tqdm(range(0, len(queries), 100)):
batch = [q["text"] for q in queries[i:i+100]]
query_embeddings.extend(embed_batch(batch))
doc_embeddings = []
for i in tqdm(range(0, len(documents), 100)):
batch = [d["text"] for d in documents[i:i+100]]
doc_embeddings.extend(embed_batch(batch))
# Create training pairs: (query, positive_doc, hard_negatives)
# Use teacher scores to find hard negatives
doc_matrix = np.array(doc_embeddings)  # OpenAI embeddings are unit-norm, so dot = cosine
training_pairs = []
for i, q_emb in enumerate(query_embeddings):
    # Cosine similarity of this query to every doc
    scores = doc_matrix @ np.array(q_emb)
    # Positive: the labeled relevant doc
    pos_idx = queries[i]["relevant_doc_id"]
    # Hard negatives: top-scoring docs that are not the positive
    neg_indices = np.argsort(scores)[-20:]
    neg_indices = [n for n in neg_indices if n != pos_idx][:7]
training_pairs.append({
"query": queries[i]["text"],
"positive": documents[pos_idx]["text"],
"negatives": [documents[n]["text"] for n in neg_indices],
"teacher_score": float(scores[pos_idx])
})
json.dump(training_pairs, open("data/embedding_training_pairs.json", "w"))
print(f"Created {len(training_pairs)} training pairs")
Step 2: Train Student Embedding Model
# train_embedding.py — Contrastive distillation with sentence-transformers
# Run on Vast.ai: 1× A100 40GB ($2.50/hr spot)
# pip install sentence-transformers==3.0.0
from sentence_transformers import (
    SentenceTransformer, losses, InputExample, evaluation
)
from torch.utils.data import DataLoader
import json
# Load student model
student = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# 22M params, 384-dim embeddings, 80MB model size
# Load training pairs
pairs = json.load(open("data/embedding_training_pairs.json"))
# Create training examples
train_examples = []
for p in pairs:
# MultipleNegativesRankingLoss: (anchor, positive)
train_examples.append(
InputExample(texts=[p["query"], p["positive"]])
)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=256)
# Loss: MultipleNegativesRanking + knowledge distillation
train_loss = losses.MultipleNegativesRankingLoss(student)
# Evaluation
evaluator = evaluation.InformationRetrievalEvaluator(
queries={str(i): p["query"] for i, p in enumerate(pairs[:1000])},
corpus={str(i): p["positive"] for i, p in enumerate(pairs[:1000])},
relevant_docs={str(i): {str(i)} for i in range(1000)},
name="domain-eval"
)
# Train
student.fit(
train_objectives=[(train_dataloader, train_loss)],
evaluator=evaluator,
epochs=10,
evaluation_steps=500,
warmup_steps=1000,
output_path="models/domain-minilm-distilled",
show_progress_bar=True,
use_amp=True # mixed precision
)
# Training time: ~8 hours on 1× A100
# Expected recall@10: 91-93% (vs 95% teacher)
Step 3: Deploy with FastAPI + ONNX
# Optional: serve via the ONNX backend for extra inference speed
# (supported in sentence-transformers >= 3.2, which auto-exports an ONNX graph)
# model = SentenceTransformer("models/domain-minilm-distilled", backend="onnx")
# serve_embeddings.py — FastAPI server
from fastapi import FastAPI
from sentence_transformers import SentenceTransformer
import numpy as np
app = FastAPI()
model = SentenceTransformer(
"models/domain-minilm-distilled",
device="cuda"
)
@app.post("/embed")
async def embed(texts: list[str]):
embeddings = model.encode(texts, batch_size=256,
normalize_embeddings=True)
return {"embeddings": embeddings.tolist()}
# Run: uvicorn serve_embeddings:app --host 0.0.0.0 --port 8001 --workers 1
# Benchmark: 3,000 embeddings/sec on RTX 4090
# Latency: 2ms per query (vs 80ms OpenAI API)
Recall@10: 92.1% (vs 95.3% teacher) — 96.6% quality retained
Latency: 2ms per query (vs 80ms API) — 40× faster
Throughput: 3,000 queries/sec on 1× RTX 4090
Model size: 80MB (vs API dependency)
Cost: $252/month serving (vs $18K/month API) — 98.6% savings
GPU Server Selection Guide
Which server for which distillation job — March 2026 pricing
Training Server Comparison
| Provider | GPU | VRAM | $/hr (On-Demand) | $/hr (Spot) | Best For |
|---|---|---|---|---|---|
| RunPod | 1× A100 80GB | 80GB | $2.57 | $1.64 | Short training runs (<24h) |
| RunPod | 4× A100 80GB | 320GB | $10.28 | $6.57 | Large model distillation (8B+) |
| Lambda Labs | 1× A100 80GB | 80GB | $1.29 | — | Cheapest A100 on-demand |
| Lambda Labs | 8× A100 80GB | 640GB | $10.32 | — | 70B+ model training |
| Vast.ai | 1× A100 40GB | 40GB | $2.80 | $1.50 | Budget embedding training |
| Vast.ai | 1× RTX 4090 | 24GB | $0.45 | $0.25 | Small model training (<3B) |
| CoreWeave | 1× A100 80GB | 80GB | $4.68 | $1.87 | Long runs (reserved pricing) |
| CoreWeave | 1× H100 80GB | 80GB | $4.76 | $1.90 | 2× faster than A100 for same price |
| AWS p4d.24xlarge | 8× A100 40GB | 320GB | $32.77 | $12.45 | Enterprise, always available |
| AWS p5.48xlarge | 8× H100 80GB | 640GB | $98.32 | $37.84 | Frontier model distillation |
| GCP a3-highgpu-8g | 8× H100 80GB | 640GB | $101.36 | $30.41 | GCP ecosystem integration |
Production Serving Server Comparison
| Provider | GPU | VRAM | $/month | Max Model Size (INT4) | Throughput |
|---|---|---|---|---|---|
| RunPod | 1× RTX 4090 | 24GB | $252 | ~14B params | ~1,500 tok/s |
| RunPod | 1× A10G | 24GB | $540 | ~14B params | ~800 tok/s |
| RunPod | 1× L40S | 48GB | $720 | ~30B params | ~2,000 tok/s |
| RunPod | 1× A100 80GB | 80GB | $1,850 | ~70B params | ~3,500 tok/s |
| AWS g5.xlarge | 1× A10G | 24GB | $730 | ~14B params | ~800 tok/s |
| AWS g6.xlarge | 1× L4 | 24GB | $530 | ~14B params | ~600 tok/s |
| Together.ai | Serverless | — | Usage-based | Any supported | High |
| Fireworks.ai | Serverless | — | Usage-based | Any supported | Very high |
Quick Decision Matrix
RAG Chatbot (8B model)
Train: RunPod 4× A100 80GB, ~$130, ~13h
Serve: RunPod 1× A10G 24GB, $540/month
Latency: 200-400ms, 100+ concurrent
Voice Agent (3-4B model)
Train: Lambda 1× A100 80GB, ~$8, ~6h
Serve: RunPod 1× RTX 4090, $252/month
TTFT: 40ms, 50+ concurrent sessions
Support Chatbot (7B model)
Train: CoreWeave 2× A100 80GB, ~$168, ~18h
Serve: RunPod 2× L40S, $1,440/month
Latency: 180ms p50, 500+ concurrent
Embedding Model (22M model)
Train: Vast.ai 1× A100 40GB, ~$20, ~8h
Serve: RunPod 1× RTX 4090, $252/month
Latency: 2ms, 3,000 queries/sec
1. Always start with spot/interruptible instances for training; save 40-60%. Use checkpointing to resume if preempted.
2. Use on-demand for serving (uptime matters).
3. RunPod and Vast.ai are cheapest for GPU rental. Lambda Labs is the cheapest A100 on-demand. CoreWeave for reserved long-term.
4. For serving, RTX 4090 has the best price/performance for models <14B. L40S for 14-30B. A100 for 30-70B.
5. Consider Together.ai or Fireworks.ai serverless if your traffic is bursty; you only pay per token, with no idle GPU cost.
Production Checklist
De-Risk Your Distillation Rollout
Data Quality
- ☐ Collected 100K+ training examples
- ☐ Verified label distribution matches production
- ☐ Removed outliers/ambiguous examples
- ☐ Split: 80/10/10 train/val/test
- ☐ Validated teacher consistency
Training
- ☐ Ran hyperparameter sweep
- ☐ Logged learning curves
- ☐ Confirmed convergence on val set
- ☐ Saved checkpoints every 500 steps
- ☐ Tested inference latency
Evaluation
- ☐ Computed metrics on test set
- ☐ Verified >93% quality vs teacher
- ☐ Human evaluation (100 examples)
- ☐ Error analysis completed
- ☐ Documented regressions
Deployment
- ☐ Converted to ONNX/TensorRT
- ☐ Optimized model for target hardware
- ☐ Benchmarked P50/P95/P99
- ☐ Set up model serving (vLLM/Triton)
- ☐ Containerized application
A/B Testing
- ☐ Set up canary at 5%
- ☐ Monitoring dashboards ready
- ☐ Alerts configured (>2% regression)
- ☐ Rollback procedure documented
- ☐ Run for 24+ hours at each %
Monitoring
- ☐ Latency: P50/P95/P99 tracked
- ☐ Quality: Daily metric dashboard
- ☐ Error rate: <0.5% acceptable
- ☐ User feedback collected
- ☐ Weekly review of metrics
Cost Tracking
- ☐ Calculated baseline cost
- ☐ Projected monthly savings
- ☐ Break-even timeline confirmed
- ☐ ROI tracker set up
- ☐ Monthly cost report
Documentation
- ☐ Training recipe documented
- ☐ Hyperparameters recorded
- ☐ Reproducibility verified
- ☐ Inference pipeline documented
- ☐ Runbooks for support team
Go/No-Go criteria: Quality ≥93% of teacher + P95 latency <150ms + Error rate <0.5% + Cost savings >50% + All checklist items completed. If all met, proceed with 5% canary.
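These criteria are mechanical enough to encode as a release gate. A minimal sketch, with every threshold taken from the criteria above (the function itself is illustrative):
# Go/No-Go release gate; thresholds mirror the criteria above.
def go_no_go(quality_vs_teacher: float, p95_latency_ms: float,
             error_rate: float, cost_savings: float,
             checklist_complete: bool) -> bool:
    return (
        quality_vs_teacher >= 0.93   # quality >= 93% of teacher
        and p95_latency_ms < 150.0   # P95 latency < 150ms
        and error_rate < 0.005       # error rate < 0.5%
        and cost_savings > 0.50      # cost savings > 50%
        and checklist_complete       # all checklist items done
    )

# Example: this run meets every bar and may proceed to the 5% canary.
assert go_no_go(0.956, 92.0, 0.002, 0.80, True)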
Post-Launch (Day 1-30)
- ☐ Day 1: 5% traffic to distilled model. Monitor every 15 min.
- ☐ Day 2: If stable, increase to 10%. Check latency P99.
- ☐ Day 3: 25% traffic. Validate quality with sample review.
- ☐ Day 4-7: 50% traffic. Full monitoring active.
- ☐ Day 8-14: 75-90% traffic. Collect user feedback.
- ☐ Day 15-30: 100% traffic. Daily metrics review.
- ☐ Day 30: Finalize ROI report. Decommission teacher model if stable.
Safety & Alignment in Distillation
Risks, Inherited Behaviors & Mitigation Strategies
Distillation does not automatically make a model safe. It can inherit or amplify dangerous capabilities from the teacher. Every distillation pipeline must include safety evaluation as a first-class concern.
Key Safety Risks
Amplified Biases
The student will learn any biases or misalignments in the teacher's outputs. Without careful filtering, the student may become worse at filtering harmful content than the teacher was, since it lacks the teacher's broader context understanding.
Overconfidence & Hallucinations
Teachers often produce overconfident (spiky) output distributions. If unchecked, the student becomes brittle on unfamiliar inputs. Calibrated Uncertainty Distillation (CUD) reshapes teacher outputs so students learn structured uncertainty, improving out-of-distribution reliability.
Prompt Injection & Data Leakage
If the teacher is an API, student training may inadvertently copy sensitive information embedded in responses. Use response redaction and rate limiting when generating teacher data. Apply PII detection (Microsoft Presidio, NVIDIA UIMA) to both training data and student responses.
Excessive Agency
Even distilled models can attempt multi-step or out-of-scope actions. OWASP's LLM Top-10 lists "Excessive Agency" and "Insecure Output Handling" as key threats. Continuous red-teaming is essential to verify the student doesn't learn to override safety filters.
Distillation as an Alignment Tool
Recent work (Redwood Research, 2025) proposes distillation as an alignment tool: by selecting only "safe" behavior trajectories from a powerful model, one can distill a smaller model that retains capabilities but hopefully lacks certain misaligned tendencies.
Pipeline: Generate many teacher outputs → Filter out problematic/reward-hacking trajectories → Train the student on the curated safe subset. This underscores that data selection is a lever for safety.
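A minimal sketch of that filtering step; the toxicity classifier (unitary/toxic-bert) and the threshold are stand-ins for whatever trajectory filter your policy actually requires.
# Curate-then-train: keep only teacher trajectories that pass a safety filter.
from transformers import pipeline

safety_clf = pipeline("text-classification", model="unitary/toxic-bert")

def curate(teacher_outputs, max_score=0.01):
    safe = []
    for text in teacher_outputs:
        result = safety_clf(text, truncation=True)[0]  # fit the 512-token window
        if result["score"] > max_score:                # flagged as problematic -> drop
            continue
        safe.append(text)
    return safe  # train the student on this curated subset only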
Mitigation Strategies
| Risk | Mitigation | Tools / Techniques |
|---|---|---|
| Bias Amplification | Whitelist/blacklist training data; diversify teacher sources | Ensemble teachers, bias benchmarks (BBQ, WinoBias) |
| Overconfidence | Calibrated Uncertainty Distillation (CUD); temperature tuning | Reliability diagrams, Expected Calibration Error (ECE) |
| Data Leakage | PII redaction on teacher outputs; response sanitization | Microsoft Presidio, NVIDIA UIMA, regex-based filters |
| Safety Filter Loss | Train only on post-filter outputs; include explicit safety data | Red-teaming prompts, OWASP LLM eval suite |
| Excessive Agency | Inject adversarial prompts during training; constrain output schema | PromptFoo, OpenAI Evals, custom red-team tests |
Pre-training: Audit teacher outputs for bias and PII. During training: Monitor loss curves for unexpected behavior. Post-training: Run adversarial/red-team suite. Deployment: Canary with safety-specific monitoring. Ongoing: Weekly safety metric reviews.
Privacy & Regulatory Compliance
Data Protection, Federated Distillation & Governance
Privacy-Preserving Distillation
Data Privacy via KD
Training a student on a teacher's outputs can be privacy-preserving since it avoids exposing raw training data. However, if teacher outputs contain personal data, the student can memorize it. Always de-identify and scrub sensitive fields in outputs using NLP-based PII detection.
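A minimal scrubbing pass with Microsoft Presidio (named above); the entity list is illustrative and should be extended for your domain.
# PII scrubbing for teacher outputs with Microsoft Presidio
# (pip install presidio-analyzer presidio-anonymizer).
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine

analyzer = AnalyzerEngine()
anonymizer = AnonymizerEngine()

def scrub(text: str) -> str:
    findings = analyzer.analyze(text=text, language="en",
                                entities=["PERSON", "EMAIL_ADDRESS", "PHONE_NUMBER"])
    return anonymizer.anonymize(text=text, analyzer_results=findings).text

print(scrub("Contact Jane Doe at jane@example.com"))
# -> "Contact <PERSON> at <EMAIL_ADDRESS>"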
Federated / On-Device Distillation
DistilLock demonstrates distillation without revealing teacher or student to external parties. The teacher runs inside a Trusted Execution Environment (TEE) on the data owner's device, providing only black-box outputs to the student, preserving both data privacy and model IP.
Federated Distillation Architecture (diagram)
Regulatory & Governance Requirements
Audit Logs
Maintain complete audit trails of which teacher model version and dataset were used for each distillation run. Version-control every config and random seed.
Access Control
Allow only authorized personnel to initiate distillations, since the teacher model may be proprietary. Treat distillation pipelines as sensitive pipelines.
Risk Assessment
Perform risk assessments (e.g., via NIST AI RMF) for the student model. Healthcare and finance sectors require models not to leak confidential info.
Pipeline security: Treat distillation runs as sensitive — keep audit trails, restrict access, periodically review outputs against a policy suite. GDPR: Distillation is privacy-positive (avoids sending raw data to cloud). IP protection: DistilLock/TEE patterns protect both data and model weights.
Failure Modes & Mitigation
Common Pitfalls and How to Avoid Them
Catastrophic Forgetting
If fine-tuning datasets are narrow, the student may "forget" some general knowledge of the teacher.
Fix: Include a portion of original teacher data or use replay buffers during training.
Overfitting Teacher Biases
The student may overfit to teacher idiosyncrasies, especially if the teacher is itself biased on the training set.
Fix: Diversify teacher sources (ensemble distillation) or calibrate teacher outputs with CUD.
Loss of Safety Filters
If the teacher has a safety layer (filtering outputs), but distillation is done on raw teacher responses, the student may learn to ignore the teacher's content policy.
Fix: Use only "post-filter" outputs for training, or explicitly include safety training data.
Poor Calibration
Standard KD can produce overconfident students whose predicted probabilities don't match true likelihoods.
Fix: Apply Calibrated Uncertainty Distillation (CUD), temperature tuning, and mix of teacher/student losses. Always measure with reliability diagrams.
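ECE itself is only a few lines. A minimal sketch, assuming `confidences` (max predicted probability) and `accuracies` (0/1 correctness) have been collected from a validation set, with the common 10-bin layout:
# Expected Calibration Error: confidence-vs-accuracy gap, weighted per bin.
import numpy as np

def expected_calibration_error(confidences, accuracies, n_bins=10):
    confidences = np.asarray(confidences, dtype=float)
    accuracies = np.asarray(accuracies, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            # gap between mean confidence and observed accuracy in this bin,
            # weighted by the fraction of samples falling in it
            ece += mask.mean() * abs(confidences[mask].mean() - accuracies[mask].mean())
    return ece

# A well-calibrated student scores near 0; flag regressions above ~0.05.
print(expected_calibration_error([0.9, 0.8, 0.95], [1, 1, 0]))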
Resource Exhaustion
A naive attempt to distill a 70B model into a 7B student on a single GPU may OOM or run for months.
Fix: Use progressive/multi-stage approaches, DeepSpeed ZeRO, or scale-out training across multiple GPUs.
Data Leakage & IP Risks
If teacher outputs contain private data, the student can memorize it. The DistillGuard paper highlights IP leakage risks if a model's outputs are scraped.
Fix: Apply PII redaction on the output stream. Review the teacher model for memorization vulnerabilities.
Capacity Gap Problem
Recent research (Apple/Oxford) reveals an important scaling law for distillation: overly strong teachers can overwhelm a small student (the "capacity gap"). If the student is very small, a moderately-sized teacher may suffice. If the student is large, a top-notch teacher is needed. Matching capacities is key to avoiding training instabilities.
Solution: Use teacher cascades (multi-stage) — a chain where a large teacher first distills to an intermediate teacher, which then distills to the final small student. This gradually bridges the capacity gap.
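A toy sketch of the cascade control flow, using tiny linear stand-ins so it runs as-is; in practice each `distill_stage` call is a full distillation run with the combined loss described earlier.
# Two-stage teacher cascade: large -> intermediate -> small student.
import torch
import torch.nn.functional as F

def distill_stage(teacher, student, inputs, labels, T=4.0, alpha=0.7, steps=200):
    opt = torch.optim.AdamW(student.parameters(), lr=1e-3)
    for _ in range(steps):
        with torch.no_grad():
            t_logits = teacher(inputs)          # teacher stays frozen
        s_logits = student(inputs)
        kl = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                      F.softmax(t_logits / T, dim=-1), reduction="batchmean")
        loss = alpha * (T ** 2) * kl + (1 - alpha) * F.cross_entropy(s_logits, labels)
        opt.zero_grad(); loss.backward(); opt.step()
    return student

x, y = torch.randn(64, 32), torch.randint(0, 10, (64,))
big = torch.nn.Linear(32, 10)    # stands in for the large teacher
mid = torch.nn.Linear(32, 10)    # intermediate teacher bridges the gap
small = torch.nn.Linear(32, 10)  # final small student

mid = distill_stage(big, mid, x, y)      # stage 1: large -> intermediate
small = distill_stage(mid, small, x, y)  # stage 2: intermediate -> small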
DistillGuard: Defense Metrics
| Metric | Definition | Use Case |
|---|---|---|
| Distillation Effectiveness (DE) | How well defense techniques prevent unauthorized knowledge transfer | Protecting proprietary model IP |
| Distillation Cost (DC) | Computational cost imposed on adversarial distillation attempts | Making unauthorized distillation economically infeasible |
Treat distillation training runs as sensitive pipelines: maintain audit trails of teacher model versions and data used, restrict distillation access to authorized personnel, and periodically review distilled model outputs against a policy compliance suite.
Deployment Roadmap
Phased Plan for Production Distillation (6-9 Months)
Phase 1: Research & Pilot
Weeks 1-12
- ✓ Define project scope, select teacher model
- ✓ Assemble pilot dataset (real + synthetic)
- ✓ Initial distillation proof-of-concept
- ✓ Basic evaluation (accuracy, latency)
Deliverable: Prototype student model with baseline performance vs teacher
Phase 2: Implementation
Weeks 13-24
- ▶ Develop full distillation pipeline (training code)
- ▶ Add feature losses (attention/hidden state matching)
- ▶ Integrate quantization-aware training
- ▶ Build model validation suite + benchmarks
Deliverable: Full-scale distillation scripts, hyperparameter tuning report (optimal α, T, LR), quantized student
Phase 3: Security & Validation
Weeks 25-30
- ▶ Develop adversarial test scenarios (injection, data leak)
- ▶ Implement privacy filters (PII redaction)
- ▶ Perform red-team evaluation (internal)
- ▶ Iterate fixes (re-distill if needed)
Deliverable: Security report on exploit cases, sanitized training data, mitigation logs
Phase 4: CI/CD & Monitoring
Weeks 31-36
- ▶ Set up model registry and versioning
- ▶ Automate distillation pipeline (GitHub Actions / CI)
- ▶ Define drift and performance alerts (Prometheus)
- ▶ Conduct pilot A/B test with limited user traffic
Deliverable: Automated CI pipeline that retrains student on updated teacher/data; live monitoring dashboard
Phase 5: Scale-Out & Production
Weeks 37+
- ▶ Expand distillation to additional tasks/domains
- ▶ Train quantized 4-bit and 8-bit variants
- ▶ Final deployment (k8s serving / edge packaging)
Deliverable: Production-grade model deployed, with fallback/retry logic and documented rollback procedures
Each phase should end with a review covering accuracy, latency, and safety metrics, with sign-off before proceeding. Budget 6-9 months for the full pipeline from research to production scale-out.
Research References
Key Papers, Frameworks & Industry Resources (2023-2026)
Foundational Papers
- [1] ACL Findings 2023 — Feature/Representation distillation techniques for LLMs (aclanthology.org)
- [2] Iterative Layer-wise Distillation — Efficient compression of large language models via iterative KD (arXiv:2511.05085)
- [3] Compact Language Models via Pruning and KD — Minitron: pruning+distillation compressing 15B to 8B/4B (arXiv:2407.14679)
- [4] Survey on Knowledge Distillation for LLMs — Comprehensive survey of methods, evaluation, and applications (arXiv:2407.01885)
Safety & Privacy
- [5] Calibrated Uncertainty Distillation (CUD) — Trust the uncertain teacher: distilling dark knowledge via calibrated uncertainty (arXiv:2602.12687)
- [6] DistilLock — Safeguarding LLMs from unauthorized knowledge distillation on the edge via TEE (arXiv:2510.16716)
- [7] DistillGuard — Evaluating defenses against LLM knowledge distillation, introduces DE/DC metrics (arXiv:2603.07835)
- [8] AI Safety via Distillation — Redwood Research: leveraging distillation for alignment (blog.redwoodresearch.org)
Frameworks & Industry Guides
- [9] HuggingFace Knowledge Distillation Blog — Everything you need to know about KD, including practical guides (huggingface.co)
- [10] Nebius: Concept Behind Distilling an LLM — Practical walkthrough of distillation concepts and training (nebius.com)
- [11] Intel Neural Compressor — Distillation for quantization with integrated QAT support (intel.github.io)
- [12] HuggingFace Optimum — Optimized backends and quantization tools for distilled models (huggingface.co/docs)
Edge & Specialized
- [13] TinyLLM — A framework for training and deploying language models at the edge (tinyllm.org)
- [14] Self-Distillation in Deep Learning — Emergent Mind topic on self-distillation frameworks (emergentmind.com)
- [15] Non-Destructive Task Composition with KD — Adapter-based distillation (EMNLP '23) (arXiv:2312.16261)
- [16] Future of AI Models — Small LLMs, on-device AI, and edge deployment architectures (medium.com)
This guide also references the OWASP GenAI Top 10 (genai.owasp.org) for LLM security evaluation, and the NIST AI Risk Management Framework for regulatory compliance assessments. Both are recommended for any production distillation deployment.
HuggingFace Embedding Models
Teacher & Student Models for Embedding Distillation
Teacher Models (High Quality)
| Model | Params | Dims | MTEB Score | Best For |
|---|---|---|---|---|
| Qwen/Qwen3-Embedding-8B | 8B | 32-4096 (configurable) | 70.58 (MTEB #1) | Multilingual retrieval, highest quality teacher |
| BAAI/bge-m3 | 568M | 1024 | ~66 | Dense + sparse + multi-vector; 100+ languages |
| jinaai/jina-embeddings-v3 | 570M | 1024 | ~65 | Multi-task multilingual; most downloaded on HF |
| nvidia/NV-Embed-v2 | 7.8B (Mistral-7B based) | 4096 | ~69 | Multilingual understanding, high-accuracy teacher |
| BAAI/bge-large-en-v1.5 | 335M | 1024 | ~64 | English-only retrieval; popular teacher baseline |
Student / Distilled Models (Fast, Production-Ready)
| Model | Params | Dims | Latency | Best For |
|---|---|---|---|---|
| sentence-transformers/all-MiniLM-L6-v2 | 22M | 384 | ~3ms | Fastest quality option; ideal RAG student target |
| sentence-transformers/all-MiniLM-L12-v2 | 33M | 384 | ~5ms | Better quality than L6; still very fast on CPU |
| sentence-transformers/all-mpnet-base-v2 | 109M | 768 | ~8ms | Highest quality sentence-transformer; good student |
| BAAI/bge-small-en-v1.5 | 33M | 384 | ~3ms | BGE family small; great for English-only RAG |
| BAAI/bge-base-en-v1.5 | 109M | 768 | ~6ms | Balanced quality/speed in BGE family |
| Qwen/Qwen3-Embedding-0.6B | 600M | configurable | ~12ms | Smallest Qwen3 embedding; multilingual student |
| microsoft/Multilingual-MiniLM-L12-H384 | 117M | 384 | ~5ms | 100+ languages; lightweight multilingual student |
Recommended pipeline: Use Qwen3-Embedding-8B or bge-m3 as teacher → distill to all-MiniLM-L6-v2 or bge-small-en-v1.5 for 10-50x speedup with 95%+ quality retention. For multilingual: teacher Qwen3-Embedding-8B → student Multilingual-MiniLM-L12-H384.
# Load teacher and student for embedding distillation
from sentence_transformers import SentenceTransformer
# Teacher: high-quality embedding model
teacher = SentenceTransformer("BAAI/bge-m3")
# Student: fast, lightweight model
student = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Example: generate teacher embeddings for distillation
texts = ["What is machine learning?", "How does RAG work?"]
teacher_embeddings = teacher.encode(texts, normalize_embeddings=True)
student_embeddings = student.encode(texts, normalize_embeddings=True)
print(f"Teacher dims: {teacher_embeddings.shape[1]}") # 1024
print(f"Student dims: {student_embeddings.shape[1]}") # 384
HuggingFace Reranker Models
Cross-Encoder & Late Interaction Models for Distillation
Cross-Encoder Rerankers (Teachers)
| Model | Params | Context | BEIR NDCG | Best For |
|---|---|---|---|---|
| mixedbread-ai/mxbai-rerank-large-v2 | 1.5B (Qwen-2.5) | 8K tokens | SOTA | Highest quality; 100+ langs; RL-trained |
| BAAI/bge-reranker-v2-m3 | 568M | 8K tokens | ~0.62 | Multilingual reranking; strong BEIR scores |
| BAAI/bge-reranker-large | 560M | 512 tokens | ~0.60 | English reranking; well-tested in production |
| cross-encoder/ms-marco-MiniLM-L-12-v2 | 33M | 512 tokens | ~0.53 | Lightweight MS-MARCO reranker; fast inference |
Student / Distilled Rerankers (Fast)
| Model | Params | Latency (100 docs) | Best For |
|---|---|---|---|
| mixedbread-ai/mxbai-rerank-base-v2 | 500M | ~30ms | Compact SOTA reranker; great distillation target |
| BAAI/bge-reranker-base | 278M | ~25ms | Balanced quality/speed; popular production choice |
| cross-encoder/ms-marco-MiniLM-L-6-v2 | 22M | ~8ms | Ultra-fast reranker; 6-layer distilled MiniLM |
| colbert-ir/colbertv2.0 | 110M | ~12ms | Late interaction model; pre-compute doc tokens |
Use score distillation: have the teacher score query-doc pairs, then train the student to predict those scores via MSE loss (see the training sketch after the snippet below). Combine with a margin-based ranking loss for better pairwise ordering.
ColBERT pre-computes document token representations and only performs late interaction at query time, making it 10-100x faster than full cross-encoders while retaining 95%+ quality. Great distillation target.
# Reranker distillation: teacher scores → student training
from sentence_transformers import CrossEncoder
# Teacher: high-quality reranker
teacher = CrossEncoder("BAAI/bge-reranker-v2-m3")
# Student: fast, lightweight reranker
student = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# Generate teacher scores for distillation
pairs = [
("What is RAG?", "RAG combines retrieval with generation..."),
("What is RAG?", "The weather is nice today..."),
]
teacher_scores = teacher.predict(pairs)
print(f"Teacher scores: {teacher_scores}") # [0.95, 0.02]
HuggingFace Generator Models
Teacher & Student LLMs for Generation Distillation
Teacher Models (Large, High Quality)
| Model | Params | Context | MMLU | Best For |
|---|---|---|---|---|
| Qwen/Qwen3-30B-A3B | 30B (3B active MoE) | 262K | ~82 | Best quality/cost ratio; 30B quality at 3B speed |
| meta-llama/Llama-3.1-70B-Instruct | 70B | 128K | ~86 | Strong teacher for general RAG generation |
| Qwen/Qwen3-32B | 32B | 128K | ~83 | Excellent multilingual teacher; strong reasoning |
| mistralai/Mistral-Large-2 | 123B | 128K | ~84 | Top-tier open teacher; strong code + reasoning |
Student / Distilled Models (Small, Fast)
| Model | Params | MMLU | HumanEval | Best For |
|---|---|---|---|---|
| HuggingFaceTB/SmolLM3-3B | 3B | ~67 | ~58 | Best-in-class 3B; beats Llama-3.2-3B, Qwen2.5-3B |
| meta-llama/Llama-3.1-8B-Instruct | 8B | 73.0 | 72.6 | Best all-around 8B; excellent RAG student |
| Qwen/Qwen3-8B | 8B | ~72 | ~75 | Strong code generation; multilingual strength |
| microsoft/Phi-4-mini-instruct | 3.8B | ~70 | ~66 | Microsoft's compact reasoning model; edge-ready |
| google/gemma-2-9b-it | 9B | ~71 | ~64 | On-device/edge deployment; good instruction following |
| mistralai/Mistral-7B-Instruct-v0.3 | 7B | ~63 | ~40 | Sliding window attention; fast inference |
| Qwen/Qwen3-0.6B | 0.6B | ~47 | ~30 | Ultra-small; IoT/mobile deployment |
Serving Frameworks
vLLM
Paged attention, continuous batching. Best for high-throughput multi-GPU serving.
pip install vllm
HF TGI
HuggingFace Text Generation Inference. Optimized Docker-based serving with flash attention.
docker pull ghcr.io/huggingface/text-generation-inference:latest
llama.cpp / Ollama
CPU/edge inference with GGUF quantized models. Best for on-device deployment.
ollama run llama3.3
RAG generation: Distill Llama-3.1-70B → Llama-3.1-8B or SmolLM3-3B using output distillation + LoRA (see the sketch below). Code generation: Qwen3-32B → Qwen3-8B. Edge/mobile: Any teacher → Phi-4-mini or Qwen3-0.6B with aggressive quantization (Q4). Multilingual: Qwen3-30B-A3B → Qwen3-8B. Check MTEB & Open LLM Leaderboard for latest rankings.
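A sketch of that RAG-generation recipe (output distillation into a LoRA student). The model repos come from the tables above; the prompt, LoRA targets, and hyperparameters are illustrative, and the two models are assumed to share a vocabulary (true within the Llama 3.1 family). In a real run this KD term is combined with cross-entropy on labels, exactly as in the combined loss shown earlier.
# One step of logit distillation into a LoRA student (teacher frozen).
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

teacher = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-70B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
student = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
student = get_peft_model(student, LoraConfig(
    r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"]))

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
batch = tok(["Answer using only the provided context: ..."],
            return_tensors="pt").to(student.device)

T = 4.0
with torch.no_grad():
    t_logits = teacher(**batch).logits
s_logits = student(**batch).logits

# Token-level KL against the teacher's softened next-token distribution;
# only the LoRA adapter weights receive gradients.
kd_loss = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                   F.softmax(t_logits.to(s_logits.device) / T, dim=-1),
                   reduction="batchmean") * T ** 2
kd_loss.backward()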
Glossary of Distillation Terms
18 key technical terms used throughout this guide, organized alphabetically.
C
| Term | Definition |
|---|---|
| Calibrated Uncertainty Distillation (CUD) | A distillation technique that reshapes teacher output distributions to have higher entropy on difficult examples, so the student learns structured uncertainty rather than overconfident predictions. |
| Capacity Gap | The problem where an overly large teacher overwhelms a small student during distillation. Bridged by teacher cascades (multi-stage distillation) or matching teacher/student sizes. |
| ColBERT | Contextualized Late Interaction over BERT — a retrieval model that pre-computes document token embeddings and performs late interaction at query time. 10-100× faster than full cross-encoders. Popular distillation target for rerankers. |
| Contrastive Distillation | Training the student to preserve pairwise similarities between data points in the teacher's embedding space via a contrastive loss. Useful for embedding model distillation. |
D
| Term | Definition |
|---|---|
| Dark Knowledge | The information contained in the teacher's soft probability distribution over all classes — not just the top prediction. Soft targets reveal relationships between classes that hard labels cannot. |
| DistilBERT | A 6-layer distilled version of BERT that retains 97% of BERT's accuracy with 40% fewer parameters and 60% faster inference. A landmark distillation success story. |
| DistillGuard | A framework for evaluating defenses against unauthorized knowledge distillation. Introduces Distillation Effectiveness (DE) and Distillation Cost (DC) metrics. |
| DistilLock | A privacy-preserving distillation technique using Trusted Execution Environments (TEE). The teacher runs in a secure enclave, providing only black-box outputs to the student. |
E
| Term | Definition |
|---|---|
| Ensemble Distillation | Using multiple teacher models (or ensemble outputs) to supervise a single student. The student inherits diverse knowledge and can outperform any single teacher. |
F
| Term | Definition |
|---|---|
| Feature Distillation | Matching internal layer representations (hidden states, attention maps) between teacher and student via MSE loss. Transfers richer structural information than logit-only distillation. |
K
| Term | Definition |
|---|---|
| Knowledge Distillation (KD) | The process of training a smaller student model to mimic a larger teacher model by learning from the teacher's soft probability outputs, internal features, or attention patterns. |
L
| Term | Definition |
|---|---|
| Logit Distillation | The classic KD approach: training the student to match the teacher's softened output probability distribution using KL divergence with temperature scaling. Simple and architecture-agnostic. |
M
| Term | Definition |
|---|---|
| MiniLM | A family of compact Transformer models (22M-33M params) distilled from larger models. all-MiniLM-L6-v2 is one of the most popular embedding models for production RAG systems. |
| Minitron | NVIDIA's pruning+distillation approach that compressed Nemotron 15B to 8B and 4B parameter models, matching or exceeding other 7-8B models on benchmarks. |
P
| Term | Definition |
|---|---|
| Progressive Distillation | Multi-stage distillation that iteratively compresses the model through multiple rounds, gradually removing layers or reducing dimensions. Avoids the capacity gap problem. |
S
| Term | Definition |
|---|---|
| Self-Distillation | A technique where the model acts as both teacher and student — typically using earlier training checkpoints or ensemble of its own outputs to improve calibration and generalization. |
| Soft Targets | The teacher's probability distribution over the vocabulary, softened by temperature scaling. Higher temperature produces softer (more uniform) distributions that reveal more inter-class relationships. |
T
| Term | Definition |
|---|---|
| Temperature (Distillation) | A scalar T applied to logits before softmax: softmax(z/T). Higher T produces softer distributions with more information. T=1 is standard inference; T=2-20 for distillation training. |