Проблема: почему PPO такой сложный для выравнивания LLM?
Если вы когда-либо пытались выравнивать большие языковые модели с помощью Reinforcement Learning from Human Feedback (RLHF), то наверняка сталкивались с Proximal Policy Optimization (PPO). И знаете что? Это настоящая боль.
PPO требует:
- Отдельную модель-критика для оценки наград
- Сложную настройку гиперпараметров
- Нестабильный процесс обучения
- Огромные вычислительные ресурсы
- Проблемы с переобучением и "режимом коллапса"
Режим коллапса — это когда модель начинает генерировать однотипные, шаблонные ответы вместо разнообразных и осмысленных. В PPO это частая проблема из-за неправильной балансировки наград.
Решение: Direct Preference Optimization (DPO)
DPO (Direct Preference Optimization) — это элегантный математический трюк, который позволяет нам обойти все сложности PPO. Вместо того чтобы обучать отдельную модель-критика и балансировать сложные награды, DPO сводит всё к одной простой формуле.
1 Магия одной формулы
Вот она — та самая формула, которая меняет всё:
L_DPO = -log(sigma(beta * log(pi_theta(y_w|x)/pi_ref(y_w|x)) - beta * log(pi_theta(y_l|x)/pi_ref(y_l|x))))
Где:
pi_theta— текущая модель, которую мы обучаемpi_ref— референсная модель (обычно SFT-версия)y_w— предпочтительный (winning) ответy_l— непредпочтительный (losing) ответbeta— гиперпараметр температурыsigma— сигмоидная функция
Пошаговый план реализации DPO
1 Подготовка данных
Вам нужны данные в формате промпт + два ответа + предпочтение:
dataset = [
{
"prompt": "Объясни теорию относительности",
"chosen": "Теория относительности Эйнштейна утверждает...",
"rejected": "Относительность — это когда всё относительно..."
},
# ... больше примеров
]
2 Загрузка моделей
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Текущая модель для обучения
model = AutoModelForCausalLM.from_pretrained("your-base-model")
# Референсная модель (обычно SFT-версия)
ref_model = AutoModelForCausalLM.from_pretrained("your-sft-model")
tokenizer = AutoTokenizer.from_pretrained("your-base-model")
# Замораживаем референсную модель
ref_model.eval()
for param in ref_model.parameters():
param.requires_grad = False
3 Реализация функции потерь DPO
def dpo_loss(model, ref_model, batch, beta=0.1):
"""
Вычисляет DPO loss для батча данных.
Args:
model: обучаемая модель
ref_model: референсная модель
batch: батч данных с промптами и парами ответов
beta: гиперпараметр температуры
"""
# Токенизируем промпты и ответы
prompts = batch["prompt"]
chosen_responses = batch["chosen"]
rejected_responses = batch["rejected"]
# Конкатенируем промпт с ответами
chosen_texts = [p + r for p, r in zip(prompts, chosen_responses)]
rejected_texts = [p + r for p, r in zip(prompts, rejected_responses)]
# Токенизация
chosen_inputs = tokenizer(chosen_texts, return_tensors="pt", padding=True, truncation=True)
rejected_inputs = tokenizer(rejected_texts, return_tensors="pt", padding=True, truncation=True)
# Вычисляем логарифмы вероятностей
with torch.no_grad():
ref_chosen_logps = compute_logprobs(ref_model, chosen_inputs)
ref_rejected_logps = compute_logprobs(ref_model, rejected_inputs)
model_chosen_logps = compute_logprobs(model, chosen_inputs)
model_rejected_logps = compute_logprobs(model, rejected_inputs)
# Вычисляем логарифмы отношения вероятностей
pi_logratios = model_chosen_logps - model_rejected_logps
ref_logratios = ref_chosen_logps - ref_rejected_logps
# Формула DPO
losses = -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios))
return losses.mean()
def compute_logprobs(model, inputs):
"""Вычисляет логарифмы вероятностей для последовательностей."""
outputs = model(**inputs)
logits = outputs.logits
# Сдвигаем логитсы на один токен вперед
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = inputs["input_ids"][..., 1:].contiguous()
# Вычисляем потери
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
# Преобразуем потери в логарифмы вероятностей
logprobs = -loss.view(shift_labels.size())
# Суммируем по длине последовательности
return logprobs.sum(dim=-1)
4 Процесс обучения
from torch.utils.data import DataLoader
from transformers import AdamW
# Создаем DataLoader
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# Оптимизатор
optimizer = AdamW(model.parameters(), lr=1e-5)
# Цикл обучения
for epoch in range(3): # Обычно достаточно 1-3 эпох
model.train()
for batch in train_dataloader:
optimizer.zero_grad()
# Вычисляем DPO loss
loss = dpo_loss(model, ref_model, batch, beta=0.1)
# Обратное распространение
loss.backward()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
Нюансы и распространенные ошибки
| Проблема | Причина | Решение |
|---|---|---|
| Модель начинает генерировать пустые ответы | Слишком большой beta или переобучение | Уменьшите beta (0.01-0.1), добавьте регуляризацию KL |
| Потери не уменьшаются | Некачественные данные предпочтений | Пересмотрите аннотации, добавьте больше контрастных примеров |
| Всплески потерь | Слишком большой learning rate | Используйте learning rate 1e-6 до 1e-5, добавьте gradient clipping |
| Модель забывает базовые знания | Слишком сильное отклонение от референсной модели | Используйте меньший beta, добавьте SFT loss в смеси |
Практические советы
- Начните с малого: Протестируйте DPO на небольшой модели (например, 7B параметров) перед масштабированием
- Качество данных критично: Лучше 1000 качественных примеров предпочтений, чем 10000 сомнительных
- Используйте существующие инструменты: Библиотеки вроде TRL (Transformer Reinforcement Learning) от Hugging Face уже имеют готовую реализацию DPO
- Мониторьте KL-дивергенцию: Следите за тем, насколько ваша модель отклоняется от референсной
- Тестируйте промежуточные чекпоинты: Сохраняйте модель каждые несколько шагов и оценивайте качество генерации
Сравнение DPO vs PPO: что выбрать?
| Критерий | DPO | PPO |
|---|---|---|
| Сложность реализации | Низкая (1 формула) | Высокая (множество компонентов) |
| Вычислительные требования | Низкие | Высокие |
| Стабильность обучения | Высокая | Низкая (требует тонкой настройки) |
| Качество результатов | Сопоставимо или лучше | Зависит от настройки |
| Необходимость в reward модели | Нет | Да |
FAQ: часто задаваемые вопросы
1. Нужно ли мне предварительно обучать модель с помощью SFT перед DPO?
Да, обязательно. DPO предполагает, что у вас уже есть референсная модель, обученная с помощью Supervised Fine-Tuning (SFT). DPO лишь "настраивает" эту модель в соответствии с человеческими предпочтениями.
2. Сколько данных нужно для DPO?
Для небольших моделей (до 13B параметров) часто достаточно 1000-5000 пар предпочтений. Для больших моделей (70B+) может потребоваться 10,000+ примеров. Качество важнее количества!
3. Можно ли использовать DPO для многокритериальной оптимизации?
Да, но потребуется модификация. Один из подходов — использовать взвешенную комбинацию нескольких DPO loss функций для разных критериев (например, полезность, безопасность, креативность).
4. Как выбрать оптимальный beta?
Начните с beta=0.1 и экспериментируйте в диапазоне 0.01-0.5. Меньшие значения beta дают более консервативные обновления, большие — более агрессивные. Следите за KL-дивергенцией между текущей и референсной моделью.
5. Что делать, если у меня нет данных с человеческими предпочтениями?
Можно использовать:
- Синтетические данные: Генерировать предпочтения с помощью более мощной LLM (например, GPT-4)
- Правила: Создавать автоматические оценки на основе эвристик (длина ответа, наличие ключевых слов и т.д.)
- Публичные датасеты: Anthropic HH-RLHF, OpenAI WebGPT и другие
Важно: Качество синтетических данных обычно ниже человеческих аннотаций. Если вы планируете использовать модель в production, инвестируйте в сбор человеческих предпочтений.
Заключение
DPO — это революционный подход к выравниванию LLM, который превращает сложный, многоэтапный процесс RLHF в простую задачу оптимизации с одной формулой. Больше не нужно:
- Обучать отдельную reward модель
- Балансировать сложные гиперпараметры PPO
- Бороться с нестабильностью обучения
- Тратить огромные вычислительные ресурсы
С DPO вы можете выравнивать свои LLM быстрее, дешевле и стабильнее. Это особенно важно для локального запуска моделей, где ресурсы ограничены. Если вы хотите глубже погрузиться в работу с локальными LLM, рекомендую ознакомиться с нашими статьями про сравнение инструментов для локального запуска и лучшие open-source инструменты для работы с LLM.
Попробуйте DPO на своих проектах — вы удивитесь, насколько это проще, чем кажется. А если столкнетесь с проблемами интерпретационного дрейфа или других артефактов обучения — теперь у вас есть инструмент, который помогает контролировать поведение модели напрямую через человеческие предпочтения.