Перестаньте гадать, начните верифицировать
Помните то чувство, когда вы запускаете RL-тренировку LLM, а модель начинает выдавать "2+2=5, потому что я так чувствую"? Обычный reinforcement learning с человеческим фидбеком (RLHF) часто страдает от шумных наград — люди непоследовательны, а reward model учится предвзятостям. Но есть способ чище: Verifiable Rewards (RLVR) + GRPO. Никаких оценок "по ощущениям", только железобетонные факты: правильный ответ — да, неправильный — нет.
В этой статье я покажу, как заставить LLM решать математические задачи (датасет GSM8K) с помощью GRPO на AWS SageMaker. И да, я наступил на все грабли, чтобы вы — нет.
Кому это нужно? Инженерам, которые уже попробовали SFT и DPO, но хотят выжать максимум из модели без дорогого человеческого фидбека. Если вы работаете с задачами, где ответ можно проверить автоматически (математика, код, фактология) — это ваш метод.
GRPO без лишнего жира
Классический PPO требует отдельную модель критика (value function), что удваивает память и время. Group Relative Policy Optimization (GRPO) — это "PPO на минималках": вместо критика он использует группу сэмплов, вычисляя преимущество (advantage) как нормализованную награду внутри группы. Относительно исходной статьи DeepSeekMath (разбор здесь) мы добавим верифицируемые награды: например, совпадение с правильным ответом из GSM8K.
Основной движок — SageMaker Training с PyTorch и Hugging Face Transformers (последняя версия 4.50+). Для ускорения используем Flash Attention 2 (доступен на инстансах A100).
Три шага, которые превратят хаос в контролируемое обучение
1 Подготовка данных и верификатора
GSM8K — 8500 задач с пошаговыми решениями. Но нас интересует только финальный ответ (число). Парсим датасет, извлекаем ответы после "####".
from datasets import load_dataset
import re
ds = load_dataset("gsm8k", "main", split="train")
def extract_answer(text):
match = re.search(r"####\s*([+-]?[\d.,]+)", text)
return match.group(1).replace(",", "") if match else None
verified = [(q, extract_answer(a)) for q, a in zip(ds["question"], ds["answer"])]
verified = [(q, a) for q, a in verified if a is not None]
Верификатор — простая функция: совпало число — reward=1, иначе 0. Без плавающих шкал, без soft labels. Жёстко, как арбитр на футболе.
Ошибка новичка: не учитывать формат ответа. Модель может написать "Ответ: 42" вместо просто "42". Добавьте приведение к числовому типу и обработку пробелов. Иначе получите reward=0 за правильный ответ.
2 Реализация GRPO-тренировки
Берём базовую модель, например, Mistral-7B-v0.3 (на май 2026 — всё ещё актуальна для экспериментов). Используем библиотеку TRL (версия 0.16) с кастомным коллбеком для RLVR.
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from reward_funcs import verify_answer # наша функция
model_name = "mistralai/Mistral-7B-v0.3"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
def reward_fn(prompts, responses, **kwargs):
# responses — список сгенерированных текстов
rewards = []
for prompt, response in zip(prompts, responses):
target = extract_answer_ground_truth(prompt) # из датасета
pred = extract_answer_from_response(response)
rewards.append(1.0 if pred == target else 0.0)
return rewards
training_args = GRPOConfig(
output_dir="./grpo-gsm8k",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
group_size=8, # размер группы для GRPO
learning_rate=1e-6,
logging_steps=10,
save_steps=500,
)
trainer = GRPOTrainer(
model=model,
reward_funcs=[reward_fn],
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
Ключевые моменты:
- Group size — сколько сэмплов генерировать для одного промпта. 8 — золотая середина. Меньше — высокий шум, больше — память плачет.
- Reward normalisation — GRPO внутри normaliseet rewards по группе. Не надо делать это вручную.
- KL penalty — чтобы модель не забыла язык. TRL по умолчанию добавляет KL-дивергенцию к policy loss. Крутим коэффициент через
kl_coef.
ml.g5.12xlarge (4x A10G) можно уместить Mistral-7B с batch=2 и group=8, используя gradient checkpointing. Пример — CodeFu-7B на veRL.3 Запуск на SageMaker с гиперпараметрами
SageMaker ожидает скрипт входа. Создаём train.py с кодом выше и запускаем через SDK:
from sagemaker.pytorch import PyTorch
estimator = PyTorch(
entry_point="train.py",
source_dir="./src",
role=role,
instance_count=1,
instance_type="ml.g5.12xlarge",
framework_version="2.5.1",
py_version="py311",
hyperparameters={
"epochs": 1,
"batch_size": 2,
"group_size": 8,
"lr": 1e-6,
"kl_coef": 0.04,
},
environment={
"HF_TOKEN": "hf_yourtoken",
"SAGEMAKER_ENV": "1"
},
debugger_hook_config=False,
)
estimator.fit({"training": "s3://my-bucket/gsm8k-parquet"})
Не забываем загрузить датасет в Parquet — это ускорит чтение.
Ловушки, которые подставят подножку
Моё личное топ-3 геморроя:
- Reward hacking. Модель учится генерировать ответ, который случайно совпадает с числом, но игнорирует логику. Решение: few-shot примеры с требованием "Show reasoning, then answer". Добавьте в промпт "Let's think step by step".
- Out-of-memory. GRPO генерирует множество сэмплов — каждый хранится в памяти. Используйте
torch.compileи Flash Attention. Если не помогает — уменьшите group_size до 4. - Overfitting на верификаторе. После 3-4 эпох модель может запомнить ответы GSM8K. Лучше брать небольшую эпоху (1) и следить за accuracy на валидации.
Типичный баг: вы забыли установить trust_remote_code=True для модели с кастомными слоями. SageMaker воркер молча упадёт с непонятной ошибкой. Проверяйте логи CloudWatch.
А что, если прикрутить few-shot прямо в обучение?
Идея: в reward функцию добавить не только бинарную проверку, но и штраф за отсутствие рассуждений. Например, если модель не выводит цепочку \n\n, даём -0.5. Это направит генерацию в нужное русло. Пример из RLVR с GRPO: от теории к коду.
Кроме того, можно использовать logit-level verification: проверять не только финальный токен, а всю последовательность шагов через программу-верификатор (как в CodeReward). Но это уже оверкилл для математики.
Когда это не работает (и что делать)
Если через 500 шагов accuracy на GSM8K не выросла — скорее всего проблема в:
- Слишком низкий learning rate. Попробуйте 5e-6.
- Слабый reward сигнал. Верификатор возвращает только 0 или 1 — это жёстко. Добавьте частичный reward за правильные промежуточные шаги (если можете их извлечь).
- Модель сломала токенизатор. Проверьте, корректно ли токенизируются числа. Иногда модель пишет " 42" с пробелом, а верификатор ждёт "42".
Более хитрый кейс: отравление группы. Если в группе из 8 сэмплов один случайно дал правильный ответ, остальные получат отрицательное преимущество. Это может затормозить обучение. Решение: увеличить group_size до 16 или использовать clipped advantage.
Неочевидный совет под занавес
Не храните все сгенерированные сэмплы в памяти — используйте replay buffer с off-policy correction, как в IMPALA. SageMaker позволяет монтировать EBS-диски с большим IOPS, так что запись на диск не будет узким местом. И, ради всего святого, логируйте все награды и распределение преимуществ. Если среднее преимущество держится около нуля — вы зря жжёте GPU.
Альтернативный подход — запустить Ablation Study, как описано в GRPO с нуля: ablation studies на RTX 4090. Там же разобраны приёмы экономии памяти, которые работают и на SageMaker.
Последний секрет: используйте @torch.inference_mode() для генерации сэмплов в reward функции. Если случайно оставить градиенты — память улетит в космос. Я однажды потратил $200 за час из-за этой глупости.
Теперь у вас есть всё, чтобы запустить RLVR с GRPO на SageMaker. Не наступайте на те же грабли, что и я. И помните: верифицируемая награда — это не магия, а инженерия. Считайте числа, проверяйте логи, и модель ответит вам точностью.