opncrafter

Distillation: Shrinking the Giant

Dec 30, 2025 • 20 min read

You cannot deploy a 405B Llama-3.1 model on a mobile phone. Inference on a 70B model costs $0.90/1M tokens. But you can train a small model to mimic the behavior of a large model — and with modern distillation techniques, the student model can achieve 90%+ of the teacher's quality in your domain at 1-5% of the inference cost. This is knowledge distillation, and it's how OpenAI trained GPT-4o-mini (small, fast, cheap) to match GPT-4 on targeted tasks.

1. Why Distillation Works Better Than Scratch Training

Training a 7B model from scratch requires trillions of tokens of internet text and hundreds of GPU-days. Distillation skips this: instead of learning what the correct answer is from human labels, the student learns the complete probability distribution of the teacher. This "dark knowledge" — knowing that the teacher assigns 15% probability to "dog" and 3% to "feline" in addition to 75% to "cat" — carries richer supervisory signal than just the label "cat":

# Classic cross-entropy loss: trains on hard labels
# "The answer is cat" → student learns P(cat) should be 1.0
loss_ce = cross_entropy(student_logits, hard_labels)

# KD loss: trains on teacher's soft probability distribution
# "Teacher assigns: cat=0.75, dog=0.15, feline=0.03..." 
# Temperature T > 1 amplifies low-probability predictions (dark knowledge)
T = 4.0  # Temperature — higher = softer distribution = more dark knowledge
loss_kd = KLDivLoss(
    F.log_softmax(student_logits / T, dim=-1),   # Student distribution
    F.softmax(teacher_logits / T, dim=-1),        # Teacher distribution
) * T**2  # T^2 normalizes the loss scale

# Interpolated loss (combine both signals):
# alpha=0.7: mostly KD loss, 30% supervised loss
alpha = 0.7
loss = alpha * loss_kd + (1 - alpha) * loss_ce

2. Full Implementation with HuggingFace Trainer

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

# Load teacher (large, frozen — only used for logit generation)
teacher_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
teacher_model.eval()  # Teacher is frozen — never update its weights
for param in teacher_model.parameters():
    param.requires_grad = False

# Load student (small, will be fine-tuned)
student_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.bfloat16,
)

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher=None, temperature=4.0, alpha=0.7, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher
        self.temperature = temperature
        self.alpha = alpha     # Weight of KD loss vs supervised loss
    
    def compute_loss(self, model, inputs, return_outputs=False):
        # Student forward pass
        outputs_student = model(**inputs)
        student_logits = outputs_student.logits  # [batch, seq_len, vocab_size]
        
        # Teacher forward pass (no gradient computation)
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)
        teacher_logits = outputs_teacher.logits.to(student_logits.device)
        
        # KL Divergence loss (dark knowledge transfer)
        T = self.temperature
        loss_kd = nn.KLDivLoss(reduction="batchmean")(
            F.log_softmax(student_logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1),
        ) * T * T
        
        # Standard cross-entropy loss on ground truth labels
        labels = inputs.get("labels")
        loss_ce = outputs_student.loss  # Built-in CE loss in HuggingFace
        
        # Combined loss
        loss = self.alpha * loss_kd + (1 - self.alpha) * loss_ce
        
        return (loss, outputs_student) if return_outputs else loss

# Training configuration
args = TrainingArguments(
    output_dir="./distilled-llama3-8b",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    fp16=True,
    logging_steps=50,
    save_steps=500,
)

trainer = DistillationTrainer(
    model=student_model,
    teacher=teacher_model,
    temperature=4.0,          # Higher T = more dark knowledge shared
    alpha=0.7,                # 70% KD, 30% supervised loss
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()

3. Chain-of-Thought Distillation (Modern Approach)

Models like DeepSeek, Orca, and Phi-2 use a more powerful technique: they distill the teacher's reasoning process, not just its final answer:

from openai import OpenAI
import json

client = OpenAI()

# Step 1: Generate Chain-of-Thought explanations from teacher (GPT-4o)
def generate_cot_training_data(problems: list[str]) -> list[dict]:
    """Use GPT-4o to generate step-by-step reasoning for each problem."""
    training_examples = []
    
    for problem in problems:
        # Teacher generates detailed reasoning + final answer
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{
                "role": "system",
                "content": """Generate a detailed step-by-step solution. 
Format: <thinking>reasoning steps</thinking><answer>final answer</answer>"""
            }, {
                "role": "user", 
                "content": problem
            }],
            temperature=0.1,
        )
        
        cot_response = response.choices[0].message.content
        training_examples.append({
            "instruction": problem,
            "output": cot_response,  # Student learns to produce CoT reasoning
        })
    
    return training_examples

# Step 2: Fine-tune student on CoT data using SFT (simpler than logit distillation)
# The student learns to reproduce the teacher's reasoning chain
# → Much more powerful than just copying final answers

# Key insight from Orca/Phi research:
# Student trained on 5M CoT examples from GPT-4 outperforms 
# student trained on 500M examples of raw text supervised by human labels

4. Task-Specific Distillation (Domain Focus)

# Most production distillation is task-specific: 
# Train student to match teacher ONLY on your domain (e.g., code generation, SQL)

# Step 1: Collect task-specific prompts from production
# (Your actual user queries — the most valuable training signal)

# Step 2: Get teacher logits for each prompt
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")

def get_teacher_logits(prompt: str, teacher_model, max_new_tokens: int = 256):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    
    with torch.no_grad():
        # Generate with output_scores=True to get probability distributions
        outputs = teacher_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            output_scores=True,  # Returns logits at each generation step
            return_dict_in_generate=True,
            do_sample=False,     # Greedy decoding for consistent training signal
        )
    
    # outputs.scores: list of [batch, vocab_size] tensors (one per generated token)
    logits_stack = torch.stack(outputs.scores, dim=1)  # [batch, seq_len, vocab_size]
    return logits_stack.cpu(), tokenizer.decode(outputs.sequences[0])

# Step 3: Train student on (prompt, teacher_logits) pairs
# This creates a distilled student that is:
# - 10x cheaper to run
# - 5-10ms latency vs 200ms for the teacher
# - 90%+ quality within the task domain

5. Distillation vs Fine-Tuning vs Quantization

TechniqueSize ReductionQuality ImpactUse When
Quantization (INT4)4x smallerMinimal — 1-3% worseNeed same model, just faster; quick win
Pruning20-50% smallerModerate — depends on sparsity ratioReducing specific layers; structured pruning
Fine-tuning (SFT)No size changeBetter on domain, possibly worse generallyAdapting existing model to your domain
Knowledge Distillation10-50x smallerSignificant — 85-95% of teacher qualityNeed a much smaller/cheaper model for production
CoT Distillation10-50x smallerOften beats standard KD, near teacher qualityNeed reasoning capability in small model

Frequently Asked Questions

How much data do I need for effective distillation?

For task-specific distillation (e.g., a customer support bot), 10,000-50,000 domain-specific examples is typically sufficient. The Orca paper showed that quality matters far more than quantity: 5M carefully selected CoT examples from GPT-4 produced better results than 500M standard supervised examples. Generate CoT training data from your production query logs using GPT-4o as the teacher — your actual user queries are more valuable than synthetic data.

Can I distill proprietary models like GPT-4?

OpenAI's terms of service prohibit using GPT-4 outputs to train competing models. However, you can use Claude (Anthropic allows this with some restrictions), open-source models like Llama 3.1 405B as teachers, or generate training data through your own model outputs. Check the terms of service for any model you use as a teacher before commercial deployment.

Conclusion

Knowledge distillation is one of the most powerful techniques in applied ML: taking expensive frontier model intelligence and compressing it into a fast, cheap, deployable student model. CoT distillation (training students on reasoning chains from teachers) has become the dominant approach, enabling small 7-8B models to achieve near-70B quality in specific domains. For teams deploying AI at scale, distillation often represents the difference between an ML feature that works in a demo and one that's economically viable in production.

Continue Reading

👨‍💻
Written by

Vivek

AI Engineer

Full-stack AI engineer with 4+ years building LLM-powered products, autonomous agents, and RAG pipelines. I've shipped AI features to production for startups and worked hands-on with GPT-4o, LangChain, LlamaIndex, and the Vercel AI SDK. I started OpnCrafter to share everything I wish I had when learning — no fluff, just working code and real-world context.

GPT-4oLangChainNext.jsVector DBsRAGVercel AI SDK