Supervised Fine-tuning#
Overview#
Supervised Fine-Tuning (SFT) adapts a pretrained large language model (LLM) to follow instructions, match a desired style, or specialize in a domain. This tutorial covers:
What SFT is optimizing (theory)
How SFT is implemented in practice (concatenation + causal masking + loss masking)
What catastrophic forgetting is and why it happens
Practical mitigation strategies
1. What problem is SFT solving?#
A pretrained autoregressive LLM is typically trained with next-token prediction on large-scale text:
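\[
\max_{\theta} \; \sum_{t} \log p_{\theta}(w_t \mid w_{<t}),
\]
where \(w_1, \dots, w_n\) are the tokens of a corpus sequence and \(\theta\) are the model parameters.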
Pretraining gives broad language competence, but it does not guarantee reliable instruction-following.
SFT uses supervised examples of prompt-response pairs:
Prompt \(x\): instruction / user message / context
Response \(y\): the ideal assistant output
The SFT objective is to maximize the conditional likelihood of responses given prompts:
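\[
\max_{\theta} \; \mathbb{E}_{(x,y)\sim D_{\text{SFT}}}\left[\log p_{\theta}(y \mid x)\right],
\]
where \(D_{\text{SFT}}\) is the fine-tuning dataset of prompt-response pairs.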
Equivalently, minimize negative log-likelihood:
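\[
\mathcal{L}_{\text{SFT}}(\theta) = -\,\mathbb{E}_{(x,y)\sim D_{\text{SFT}}}\left[\log p_{\theta}(y \mid x)\right]
\]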
2. Token-level SFT loss (teacher forcing)#
Let the response tokens be \(y=(y_1,\dots,y_T)\). For an autoregressive model:
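\[
p_{\theta}(y \mid x) = \prod_{t=1}^{T} p_{\theta}(y_t \mid x,\, y_{<t})
\]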
So the standard cross-entropy / negative log-likelihood loss is:
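\[
\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log p_{\theta}(y_t \mid x,\, y_{<t})
\]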
In training, we feed the ground-truth previous response tokens \(y_{<t}\) rather than the model’s sampled tokens. This is called teacher forcing.
3. How SFT is implemented for decoder-only Transformers#
Most modern LLMs are decoder-only Transformers (causal language models). SFT is commonly implemented by concatenating prompt and response into a single token sequence.
3.1 Sequence construction#
Given a prompt \(x\) and response \(y\), we build a single sequence:
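\[
s = [x;\, y] = (x_1, \dots, x_m,\; y_1, \dots, y_T),
\]
where \(m\) is the number of prompt tokens.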
In tokens:
<BOS> prompt_tokens response_tokens <EOS>
For chat models, \(x\) and \(y\) are usually wrapped with a chat template (e.g., System/User/Assistant markers). The mechanics remain the same.
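As a minimal sketch of this construction, assuming a Hugging Face-style tokenizer that defines BOS/EOS ids (the function name `build_example` and the 2048-token cap are illustrative; a real chat model would apply its chat template first):

```python
# Build one SFT training example from a (prompt, response) pair.
def build_example(tokenizer, prompt: str, response: str, max_len: int = 2048):
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    response_ids = tokenizer(response, add_special_tokens=False)["input_ids"]

    input_ids = (
        [tokenizer.bos_token_id]
        + prompt_ids
        + response_ids
        + [tokenizer.eos_token_id]
    )
    # Mark which positions belong to the response (plus <EOS>); loss is computed only there.
    loss_mask = [0] * (1 + len(prompt_ids)) + [1] * (len(response_ids) + 1)
    return {"input_ids": input_ids[:max_len], "loss_mask": loss_mask[:max_len]}
```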
3.2 Causal attention mask#
Even though the full sequence is provided, the model uses a causal mask so position \(t\) can only attend to positions \(\le t\). That keeps the factorization:
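\[
p_{\theta}(s) = \prod_{t=1}^{|s|} p_{\theta}(s_t \mid s_{<t}),
\]
so each response token is still predicted only from the prompt and the earlier response tokens.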
4. Loss masking: compute loss only on the response#
A key practical detail: we usually compute loss only for tokens in the response.
4.1 Why not compute loss on prompt tokens?#
At inference time, the prompt is given as input context; we do not want the model to learn to “generate” the prompt. So prompt tokens are treated as conditioning context.
4.2 Masked loss#
Let \(m_t \in \{0,1\}\) indicate whether position \(t\) is part of the response. Then:
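\[
\mathcal{L}(\theta) = -\sum_{t} m_t \,\log p_{\theta}(s_t \mid s_{<t}),
\]
often averaged over the number of response tokens \(\sum_t m_t\).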
This means:
prompt tokens: \(m_t=0\) (no gradient contribution)
response tokens: \(m_t=1\) (contribute to the loss)
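A minimal PyTorch sketch of this masked loss, assuming `logits` come from a causal LM forward pass and `input_ids`/`loss_mask` are batched tensors built from examples like the sketch above (`masked_sft_loss` is an illustrative name):

```python
import torch.nn.functional as F

def masked_sft_loss(logits, input_ids, loss_mask):
    # logits: (batch, seq_len, vocab); input_ids, loss_mask: (batch, seq_len)
    # Shift by one so that the logits at position t are scored against token t+1.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    shift_mask = loss_mask[:, 1:].float()

    per_token = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="none",
    ).reshape(shift_labels.shape)

    # Average only over positions whose label is a response token.
    return (per_token * shift_mask).sum() / shift_mask.sum().clamp(min=1)
```

In Hugging Face training code the same effect is usually achieved by setting prompt positions in `labels` to `-100`, which the cross-entropy loss ignores by default.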
5. Why training is parallel but inference is sequential#
5.1 Training (SFT)#
During training, the full prompt+response sequence is known. The model can compute logits for all positions in a single forward pass (batched matrix ops). Even though dependencies are left-to-right, computation is parallelizable across positions.
5.2 Inference (decoding)#
During inference, future response tokens are unknown. So generation must proceed token-by-token:
Start with prompt tokens
Predict next token
Append it
Repeat
This is inherently sequential.
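A minimal greedy-decoding sketch, assuming a Hugging Face causal LM and tokenizer (real decoding would add sampling, a KV cache, and proper stopping criteria; `greedy_generate` is an illustrative name):

```python
import torch

@torch.no_grad()
def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 128):
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits           # (1, seq_len, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1)  # most likely next token
        input_ids = torch.cat([input_ids, next_id.unsqueeze(-1)], dim=-1)
        if tokenizer.eos_token_id is not None and next_id.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
```

Note that each step here re-encodes the whole prefix; production decoders reuse a KV cache so only the newest token is processed per step, but generation remains token-by-token.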
Catastrophic Forgetting in SFT#
6. What is catastrophic forgetting?#
Catastrophic forgetting means that after fine-tuning on a new dataset, the model improves on the new distribution but degrades on previously learned capabilities. In LLMs, “previous capabilities” can include:
general instruction following
reasoning (math/logic)
coding
factual knowledge and breadth
conversational naturalness
safety/refusal behaviors (if applicable)
7. Why does forgetting happen?#
SFT optimizes only the new dataset objective:
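\[
\min_{\theta} \; \mathbb{E}_{(x,y)\sim D_{\text{SFT}}}\left[-\log p_{\theta}(y \mid x)\right]
\]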
If \(D_{\text{SFT}}\) is narrow or distribution-shifted relative to the model’s pretraining/instruction distribution, gradients can “pull” parameters toward the new behavior and overwrite representations that supported old behaviors.
Common triggers:
Narrow single-domain dataset
Highly templated outputs / style bias
Too many training steps or too large a learning rate
Small dataset leading to overfitting + style collapse
8. Forgetting vs classic overfitting#
They are related but not identical:
Overfitting: memorization / spurious correlations; training performance improves while test performance degrades on the same task distribution.
Forgetting: new task improves but other tasks/capabilities degrade due to interference/overwriting.
In practice, narrow SFT can cause both at once.
Mitigating Catastrophic Forgetting#
9. Data replay / mixture training (most practical)#
Mix your new dataset with a replay dataset representing capabilities you want to preserve:
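\[
D_{\text{mix}} = \alpha\, D_{\text{new}} + (1-\alpha)\, D_{\text{replay}},
\]
where \(\alpha\) is the fraction of training examples drawn from the new dataset.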
If the new dataset is narrow, a common starting point is \(\alpha \in [0.3, 0.7]\); a minimal sampling sketch follows the replay sources below.
Replay sources:
a subset of high-quality instruction tuning data
a curated “keep skills” set (math/code/general QA)
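A minimal sampling-based mixing sketch (`new_data`, `replay_data`, and `sample_mixed_batch` are placeholder names; in practice this is often done with dataset interleaving utilities):

```python
import random

def sample_mixed_batch(new_data, replay_data, batch_size: int, alpha: float = 0.5):
    # With probability alpha draw from the new SFT data, otherwise from the replay set.
    batch = []
    for _ in range(batch_size):
        source = new_data if random.random() < alpha else replay_data
        batch.append(random.choice(source))
    return batch
```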
10. Reduce update aggressiveness#
If forgetting is severe, your fine-tuning is likely too aggressive. Common fixes:
lower learning rate
fewer epochs / fewer steps
early stopping
mild weight decay
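As an illustrative sketch with Hugging Face `TrainingArguments` (the values are placeholders to tune, not recommendations; early stopping additionally needs an eval set and, for example, `EarlyStoppingCallback`):

```python
from transformers import TrainingArguments

# "Gentle" fine-tuning settings; all values are illustrative.
args = TrainingArguments(
    output_dir="sft-out",
    learning_rate=1e-5,      # lower learning rate
    num_train_epochs=1,      # fewer passes over the data
    weight_decay=0.01,       # mild weight decay
    warmup_ratio=0.03,
)
```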
11. Parameter-efficient fine-tuning (LoRA / QLoRA)#
LoRA-style methods adapt behavior by learning low-rank updates while keeping most base weights fixed. This often preserves general capabilities better than full fine-tuning.
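A minimal sketch with the `peft` library (the model name is a placeholder, and the `target_modules` names are typical for LLaMA-style attention layers but vary by architecture):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("your-base-model")  # placeholder

lora_cfg = LoraConfig(
    r=16,                                 # rank of the low-rank update
    lora_alpha=32,                        # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # which linear layers get adapters
    task_type="CAUSAL_LM",
)

model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()        # only the adapter weights are trainable
```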
12. Regularization toward the base model#
12.1 L2 regularization to initial weights#
Let \(\theta_0\) be pretrained weights. Add:
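\[
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{SFT}} + \lambda \,\lVert \theta - \theta_0 \rVert_2^2,
\]
where \(\lambda\) controls how strongly the weights are pulled back toward the initialization.

A rough PyTorch-style sketch, assuming a detached snapshot of the pretrained parameters is kept in memory (`ref_params` and `l2_to_init_penalty` are illustrative names):

```python
def l2_to_init_penalty(model, ref_params):
    # ref_params: detached copies of the pretrained weights theta_0,
    # in the same order as model.parameters().
    return sum(((p - p0) ** 2).sum() for p, p0 in zip(model.parameters(), ref_params))

# Inside the training step (lam plays the role of lambda):
#   loss = sft_loss + lam * l2_to_init_penalty(model, ref_params)
```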
12.2 KL regularization on outputs#
Penalize deviation in output distributions from the base model:
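\[
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{SFT}} + \beta \;\mathbb{E}_{x}\!\left[\mathrm{KL}\!\left(p_{\theta_0}(\cdot \mid x)\,\big\|\; p_{\theta}(\cdot \mid x)\right)\right],
\]
where \(\beta\) sets the strength of the penalty and \(p_{\theta_0}\) is the frozen base model's output distribution.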
13. Multi-task balancing and batching#
If you fine-tune across multiple tasks, use explicit multi-task sampling:
tag samples by task
balance proportions
ensure batches contain diverse tasks
This reduces specialization collapse.
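A minimal sketch of proportion-balanced sampling by task tag (`datasets_by_task`, the uniform default weights, and the function name are placeholders):

```python
import random

def sample_task_balanced_batch(datasets_by_task, batch_size: int, weights=None):
    # datasets_by_task: dict mapping task tag -> list of examples.
    tasks = list(datasets_by_task)
    weights = weights or [1.0] * len(tasks)  # uniform over tasks by default
    batch = []
    for _ in range(batch_size):
        task = random.choices(tasks, weights=weights, k=1)[0]
        batch.append(random.choice(datasets_by_task[task]))
    return batch
```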
14. Freeze parts of the model#
Freezing lower layers (embeddings, early transformer blocks) can help preserve general language representations. Fine-tune only higher layers to reduce overwriting.
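A rough sketch assuming a GPT-2-style module layout (`transformer.wte` for the embeddings, `transformer.h` for the block list); attribute names differ across architectures:

```python
def freeze_lower_layers(model, num_frozen_blocks: int):
    # Freeze token embeddings and the first num_frozen_blocks transformer blocks.
    for p in model.transformer.wte.parameters():
        p.requires_grad = False
    for block in model.transformer.h[:num_frozen_blocks]:
        for p in block.parameters():
            p.requires_grad = False
```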
15. Improve dataset diversity and quality#
Many forgetting issues are data-driven. To reduce drift:
increase prompt diversity
avoid overly repetitive templates
include multiple styles and lengths
remove low-quality or inconsistent labels
Summary#
SFT trains an LLM to maximize \(\log p_{\theta}(y\mid x)\) using teacher forcing on concatenated prompt+response sequences, while computing loss only on response tokens. Because SFT focuses on a potentially narrow distribution, it can cause catastrophic forgetting. The most practical mitigations are replay/mixed training, gentler optimization (LR/steps), and parameter-efficient tuning (LoRA/QLoRA).