Distillation: Shrinking the Giant
Dec 30, 2025 • 20 min read
You cannot deploy a 405B Llama-3.1 model on a mobile phone. Inference on a 70B model costs $0.90/1M tokens. But you can train a small model to mimic the behavior of a large model — and with modern distillation techniques, the student model can achieve 90%+ of the teacher's quality in your domain at 1-5% of the inference cost. This is knowledge distillation, and it's how OpenAI trained GPT-4o-mini (small, fast, cheap) to match GPT-4 on targeted tasks.
1. Why Distillation Works Better Than Scratch Training
Training a 7B model from scratch requires trillions of tokens of internet text and hundreds of GPU-days. Distillation skips this: instead of learning what the correct answer is from human labels, the student learns the complete probability distribution of the teacher. This "dark knowledge" — knowing that the teacher assigns 15% probability to "dog" and 3% to "feline" in addition to 75% to "cat" — carries richer supervisory signal than just the label "cat":
# Classic cross-entropy loss: trains on hard labels
# "The answer is cat" → student learns P(cat) should be 1.0
loss_ce = cross_entropy(student_logits, hard_labels)
# KD loss: trains on teacher's soft probability distribution
# "Teacher assigns: cat=0.75, dog=0.15, feline=0.03..."
# Temperature T > 1 amplifies low-probability predictions (dark knowledge)
T = 4.0 # Temperature — higher = softer distribution = more dark knowledge
loss_kd = KLDivLoss(
F.log_softmax(student_logits / T, dim=-1), # Student distribution
F.softmax(teacher_logits / T, dim=-1), # Teacher distribution
) * T**2 # T^2 normalizes the loss scale
# Interpolated loss (combine both signals):
# alpha=0.7: mostly KD loss, 30% supervised loss
alpha = 0.7
loss = alpha * loss_kd + (1 - alpha) * loss_ce2. Full Implementation with HuggingFace Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM
# Load teacher (large, frozen — only used for logit generation)
teacher_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-70B-Instruct",
torch_dtype=torch.bfloat16,
device_map="auto",
)
teacher_model.eval() # Teacher is frozen — never update its weights
for param in teacher_model.parameters():
param.requires_grad = False
# Load student (small, will be fine-tuned)
student_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B-Instruct",
torch_dtype=torch.bfloat16,
)
class DistillationTrainer(Trainer):
def __init__(self, *args, teacher=None, temperature=4.0, alpha=0.7, **kwargs):
super().__init__(*args, **kwargs)
self.teacher = teacher
self.temperature = temperature
self.alpha = alpha # Weight of KD loss vs supervised loss
def compute_loss(self, model, inputs, return_outputs=False):
# Student forward pass
outputs_student = model(**inputs)
student_logits = outputs_student.logits # [batch, seq_len, vocab_size]
# Teacher forward pass (no gradient computation)
with torch.no_grad():
outputs_teacher = self.teacher(**inputs)
teacher_logits = outputs_teacher.logits.to(student_logits.device)
# KL Divergence loss (dark knowledge transfer)
T = self.temperature
loss_kd = nn.KLDivLoss(reduction="batchmean")(
F.log_softmax(student_logits / T, dim=-1),
F.softmax(teacher_logits / T, dim=-1),
) * T * T
# Standard cross-entropy loss on ground truth labels
labels = inputs.get("labels")
loss_ce = outputs_student.loss # Built-in CE loss in HuggingFace
# Combined loss
loss = self.alpha * loss_kd + (1 - self.alpha) * loss_ce
return (loss, outputs_student) if return_outputs else loss
# Training configuration
args = TrainingArguments(
output_dir="./distilled-llama3-8b",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=8,
learning_rate=2e-5,
lr_scheduler_type="cosine",
warmup_ratio=0.05,
fp16=True,
logging_steps=50,
save_steps=500,
)
trainer = DistillationTrainer(
model=student_model,
teacher=teacher_model,
temperature=4.0, # Higher T = more dark knowledge shared
alpha=0.7, # 70% KD, 30% supervised loss
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()3. Chain-of-Thought Distillation (Modern Approach)
Models like DeepSeek, Orca, and Phi-2 use a more powerful technique: they distill the teacher's reasoning process, not just its final answer:
from openai import OpenAI
import json
client = OpenAI()
# Step 1: Generate Chain-of-Thought explanations from teacher (GPT-4o)
def generate_cot_training_data(problems: list[str]) -> list[dict]:
"""Use GPT-4o to generate step-by-step reasoning for each problem."""
training_examples = []
for problem in problems:
# Teacher generates detailed reasoning + final answer
response = client.chat.completions.create(
model="gpt-4o",
messages=[{
"role": "system",
"content": """Generate a detailed step-by-step solution.
Format: <thinking>reasoning steps</thinking><answer>final answer</answer>"""
}, {
"role": "user",
"content": problem
}],
temperature=0.1,
)
cot_response = response.choices[0].message.content
training_examples.append({
"instruction": problem,
"output": cot_response, # Student learns to produce CoT reasoning
})
return training_examples
# Step 2: Fine-tune student on CoT data using SFT (simpler than logit distillation)
# The student learns to reproduce the teacher's reasoning chain
# → Much more powerful than just copying final answers
# Key insight from Orca/Phi research:
# Student trained on 5M CoT examples from GPT-4 outperforms
# student trained on 500M examples of raw text supervised by human labels4. Task-Specific Distillation (Domain Focus)
# Most production distillation is task-specific:
# Train student to match teacher ONLY on your domain (e.g., code generation, SQL)
# Step 1: Collect task-specific prompts from production
# (Your actual user queries — the most valuable training signal)
# Step 2: Get teacher logits for each prompt
import torch
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")
def get_teacher_logits(prompt: str, teacher_model, max_new_tokens: int = 256):
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
# Generate with output_scores=True to get probability distributions
outputs = teacher_model.generate(
**inputs,
max_new_tokens=max_new_tokens,
output_scores=True, # Returns logits at each generation step
return_dict_in_generate=True,
do_sample=False, # Greedy decoding for consistent training signal
)
# outputs.scores: list of [batch, vocab_size] tensors (one per generated token)
logits_stack = torch.stack(outputs.scores, dim=1) # [batch, seq_len, vocab_size]
return logits_stack.cpu(), tokenizer.decode(outputs.sequences[0])
# Step 3: Train student on (prompt, teacher_logits) pairs
# This creates a distilled student that is:
# - 10x cheaper to run
# - 5-10ms latency vs 200ms for the teacher
# - 90%+ quality within the task domain5. Distillation vs Fine-Tuning vs Quantization
| Technique | Size Reduction | Quality Impact | Use When |
|---|---|---|---|
| Quantization (INT4) | 4x smaller | Minimal — 1-3% worse | Need same model, just faster; quick win |
| Pruning | 20-50% smaller | Moderate — depends on sparsity ratio | Reducing specific layers; structured pruning |
| Fine-tuning (SFT) | No size change | Better on domain, possibly worse generally | Adapting existing model to your domain |
| Knowledge Distillation | 10-50x smaller | Significant — 85-95% of teacher quality | Need a much smaller/cheaper model for production |
| CoT Distillation | 10-50x smaller | Often beats standard KD, near teacher quality | Need reasoning capability in small model |
Frequently Asked Questions
How much data do I need for effective distillation?
For task-specific distillation (e.g., a customer support bot), 10,000-50,000 domain-specific examples is typically sufficient. The Orca paper showed that quality matters far more than quantity: 5M carefully selected CoT examples from GPT-4 produced better results than 500M standard supervised examples. Generate CoT training data from your production query logs using GPT-4o as the teacher — your actual user queries are more valuable than synthetic data.
Can I distill proprietary models like GPT-4?
OpenAI's terms of service prohibit using GPT-4 outputs to train competing models. However, you can use Claude (Anthropic allows this with some restrictions), open-source models like Llama 3.1 405B as teachers, or generate training data through your own model outputs. Check the terms of service for any model you use as a teacher before commercial deployment.
Conclusion
Knowledge distillation is one of the most powerful techniques in applied ML: taking expensive frontier model intelligence and compressing it into a fast, cheap, deployable student model. CoT distillation (training students on reasoning chains from teachers) has become the dominant approach, enabling small 7-8B models to achieve near-70B quality in specific domains. For teams deploying AI at scale, distillation often represents the difference between an ML feature that works in a demo and one that's economically viable in production.
Continue Reading
Vivek
AI EngineerFull-stack AI engineer with 4+ years building LLM-powered products, autonomous agents, and RAG pipelines. I've shipped AI features to production for startups and worked hands-on with GPT-4o, LangChain, LlamaIndex, and the Vercel AI SDK. I started OpnCrafter to share everything I wish I had when learning — no fluff, just working code and real-world context.