SFT для маленьких моделей: стабилизация обучения и борьба с коллапсом ответов | AiManual
AiManual Logo Ai / Manual.
19 Фев 2026 Гайд

Как стабилизировать SFT для маленьких трансформеров: практические советы против коллапса ответов

Практическое руководство по стабилизации Supervised Fine-Tuning для маленьких трансформеров. Методы против коллапса ответов, переобучения и деградации качества

Маленькая модель, большие проблемы

Ты взял базовую модель на 3-7 миллиардов параметров. Например, тот же Mistral-3B-Instruct или Qwen2.5-3B. Идея простая - дообучить под свою задачу. Начинаешь SFT (Supervised Fine-Tuning), а через 200 шагов модель превращается в попугая: на любой запрос выдает один и тот же шаблонный ответ. "Коллапс ответов" - когда разнообразие выходов падает до нуля.

Это не баг, это особенность маленьких трансформеров. У них меньше параметров, значит меньше "емкости" для запоминания разнообразных паттернов. При SFT они быстро переобучаются на самые частые шаблоны в датасете. Результат? Модель, которая вместо творческих ответов выдает одну заученную фразу.

Типичная ошибка: Начинаешь SFT с learning rate 5e-5, как для больших моделей. Через 500 шагов loss падает красиво, но модель уже мертва - отвечает только "Я понимаю ваш запрос. Вот информация по теме:" на все запросы.

Почему коллапс случается именно с маленькими моделями

Большие модели (70B+) имеют запас "емкости". Они могут запомнить тысячи паттернов, не перезаписывая базовые знания. Маленькие модели - другое дело. Каждый новый паттерн в SFT конкурирует за те же параметры, что хранят базовые знания языка.

Представь оперативную память в 4 ГБ против 64 ГБ. На маленькой памяти новые данные вытесняют старые. В трансформерах это называется "катастрофическое забывание", но в случае SFT это еще хуже - модель не просто забывает, она заменяет разнообразие на один доминирующий паттерн.

💡
Интересно, что похожие проблемы с переобучением есть и в других областях. Например, в статье про Step 3.5 Flash описывается, как быстрая модель начинает галлюцинировать вызовы инструментов - тоже форма коллапса, только в domain-specific контексте.

Пошаговый план: от катастрофы к стабильности

1 Начинай с микроскопического learning rate

Забудь про стандартные 5e-5 или 2e-5. Для маленьких моделей стартуй с 1e-6. Да, в десять раз меньше. Почему?

Большой LR заставляет модель быстро адаптироваться к доминирующим паттернам в датасете. Маленький LR позволяет медленно "встраивать" новые знания, не разрушая существующие.

# КАК НЕ НАДО
optimizer = AdamW(model.parameters(), lr=5e-5)

# КАК НАДО
optimizer = AdamW(model.parameters(), lr=1e-6)
# И через 1000 шагов увеличивай до 3e-6, потом до 1e-5

2 Используй cosine annealing с warmup, но не так, как все

Стандартный подход: warmup 10% от общего числа шагов, потом cosine decay. Для маленьких моделей это убийственно. Warmup должен быть дольше - 20-30% от общего обучения.

Зачем? Маленькой модели нужно больше времени "присмотреться" к данным, прежде чем начать активно обучаться.

from transformers import get_cosine_schedule_with_warmup

# Для 10,000 шагов обучения
num_warmup_steps = 3000  # 30%, а не 10%
total_steps = 10000

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps
)

3 Добавь dropout даже в inference mode

Звучит еретически, но работает. Современные маленькие трансформеры (на 2026 год) часто идут с отключенным dropout для эффективности. Включай его обратно для SFT.

Не просто включи, а настрой: 0.1 для embedding слоев, 0.2 для attention, 0.1 для FFN. Это создает "шум", который мешает модели заучить точные паттерны.

# Для модели на базе transformers
model.config.attention_probs_dropout_prob = 0.2
model.config.hidden_dropout_prob = 0.1

# Пересоздай модель с новыми параметрами
# или примени к уже загруженной
for module in model.modules():
    if isinstance(module, torch.nn.Dropout):
        module.p = 0.1  # или другое значение

Важный нюанс: После SFT dropout нужно отключить для production. Но во время обучения он критически важен для предотвращения коллапса.

4 Примени label smoothing в loss function

Cross-entropy loss предполагает, что правильный токен имеет вероятность 1.0, остальные - 0.0. Для маленьких моделей это слишком жестко. Они начинают "пережимать" вероятности, что ведет к коллапсу.

Label smoothing размазывает целевую вероятность: правильный токен получает 0.9, остальные - 0.1/(vocab_size-1).

import torch.nn.functional as F

def label_smoothed_nll_loss(logits, labels, epsilon=0.1):
    """
    logits: [batch_size, seq_len, vocab_size]
    labels: [batch_size, seq_len]
    """
    log_probs = F.log_softmax(logits, dim=-1)
    nll_loss = -log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    
    # Smoothing
    smooth_loss = -log_probs.mean(dim=-1)
    loss = (1 - epsilon) * nll_loss + epsilon * smooth_loss
    return loss.mean()

5 Смешивай датасеты - добавь 10% случайных данных

Если тренируешь на domain-specific данных (медицина, код, юридические тексты), добавь 10% общих диалогов или случайных текстов. Это предотвращает "залипание" в одном стиле.

Технически: каждый батч должен содержать 90% целевых данных и 10% случайных. Случайные данные берутся из того же источника, что и претренинг модели.

💡
Этот прием похож на технику из статьи про Abliteration, где смешивание стилей помогает убрать "цветистость" ответов. Тот же принцип, но примененный на этапе обучения.

Мониторинг: как понять, что коллапс близко

Loss падает - это хорошо. Но это не показатель отсутствия коллапса. Нужны метрики разнообразия:

  • Unique n-grams ratio: сколько уникальных 3-грамм генерирует модель на тестовых промптах. Падает ниже 0.3? Беда.
  • Self-BLEU: сравниваешь сгенерированные ответы между собой. Выше 0.7? Модель повторяется.
  • Entropy выходного распределения: средняя энтропия по всем позициям в последовательности. Резкое падение = коллапс.
# Простой мониторинг unique trigrams
def unique_trigrams_ratio(texts):
    """texts - список сгенерированных ответов"""
    total_trigrams = 0
    unique_trigrams = set()
    
    for text in texts:
        words = text.split()
        for i in range(len(words) - 2):
            trigram = tuple(words[i:i+3])
            unique_trigrams.add(trigram)
            total_trigrams += 1
    
    return len(unique_trigrams) / total_trigrams if total_trigrams > 0 else 0

Экстренные меры: если коллапс уже случился

Модель уже выдает одно и то же? Не все потеряно. Попробуй:

  1. Резко уменьши LR в 10 раз и продолжи обучение еще на 500 шагов
  2. Добавь шум в embeddings: model.input_embeddings.weight.data += torch.randn_like(embeddings) * 0.01
  3. Поменяй порядок данных - иногда помогает просто перемешать датасет по-другому
  4. Временно увеличь температуру генерации во время обучения (но это хак, не решение)

Предупреждение: Если коллапс глубокий (модель выдает ровно одну фразу на все запросы уже 1000 шагов), часто проще начать заново с чекпоинта до коллапса. Сохраняй чекпоинты каждые 500 шагов!

Глубинная проблема: почему это происходит именно сейчас

К 2026 году маленькие модели стали слишком хороши в имитации больших. Архитектурные улучшения (например, лучшее attention, эффективные FFN) дали им способность быстро обучаться, но эта же способность делает их уязвимыми к коллапсу.

Парадокс: чем лучше архитектура маленькой модели, тем осторожнее нужно быть с SFT. Старые модели на 3B параметров 2023 года были более устойчивы просто потому, что обучались медленнее.

💡
Интересный параллельный феномен описан в статье про трансформеры против State-Space моделей. Оказывается, архитектурный выбор влияет не только на качество, но и на устойчивость к разным видам сбоев, включая коллапс при SFT.

Конфигурация, которая работает (проверено на Qwen2.5-3B)

Параметр Значение Комментарий
Learning rate 1e-6 → 3e-6 → 1e-5 Постепенное увеличение каждые 1000 шагов
Warmup steps 30% от total steps Не 10%, как обычно
Dropout attention 0.2 Включать только на время SFT
Label smoothing 0.1 Обязательно, даже если в документации не советуют
Batch size 16-32 Меньше = стабильнее, но медленнее
Gradient accumulation 2-4 Для стабильности, если batch size маленький

Что будет, если проигнорировать проблему

Получишь модель, которая:

  • На любой запрос отвечает шаблонной фразой
  • Имеет красивый низкий loss на тренировочных данных
  • Бесполезна в production
  • Требует полного переобучения с нуля

Потратишь неделю на сбор датасета, день на обучение, и получишь цифрового попугая. Знакомо? Именно поэтому 80% кастомных маленьких моделей 2024-2025 годов были бракованными - их создатели не знали об этой проблеме.

Будущее: будут ли маленькие модели стабильнее?

К 2026 году появляются новые архитектурные подходы. Например, модели с "встроенным" регуляризационным механизмом, который автоматически предотвращает коллапс. Но они еще сырые.

Мой прогноз: к 2027 году проблема решится на уровне фреймворков. Hugging Face Transformers или его аналоги добавят специальный флаг `prevent_collapse=True` в Trainer. Но пока этого нет - используй методы выше.

Последний совет: если делаешь SFT для production, всегда выделяй 10% GPU времени на генерацию тестовых ответов и подсчет метрик разнообразия. Лучше потратить лишний час на мониторинг, чем неделю на переобучение.

💡
Кстати, если интересно, как подобные проблемы решаются в других контекстах, посмотри статью про проблему 3-го хода в RLHF. Там тоже речь про "аттракторы", которые засасывают модель в неоптимальные состояния, только на этапе reinforcement learning.