PyTorch — это чертовски удобно. Пока вы пишете forward pass в несколько строк, под капотом происходит магия: autograd, динамический граф, диспетчеризация на сотни готовых ядер. Но когда вам нужно выжать максимум из H100 для инференса LLM, эта магия превращается в балласт. Интерпретатор Python, накладные расходы на запуск каждого kernel, отсутствие fusion — все это отбирает 30-50% производительности.
TorchScript? TorchDynamo? torch.compile? Работает, но это черный ящик. А если нужно обойти баги, подрезать под конкретную архитектуру или просто понять, как оно устроено изнутри? Тогда пишем свой компилятор. За 5000 строк Python. Серьезно? Вполне.
Зачем велосипед, когда есть torch.compile?
torch.compile (Triton) — крутая штука. Он берет ваш PyTorch код, разбивает на регионы, компилирует их в Triton, а Triton в CUDA. Но у него есть ограничения: не все операции поддерживаются, динамические shape могут выкинуть в fallback, а контроль над генерацией ядер минимален. Если вам нужно, например, заменить стандартный attention на FlashAttention-3 с кастомными warp-level операциями — вы упретесь в абстракции Triton.
Вот тут и приходит на помощь свой компилятор. Он может быть узкоспециализированным, например, только для inference конкретной LLM. И да, 5000 строк Python — это не фантастика. Достаточно захватить вычислительный граф, преобразовать в простой IR и нагенерировать CUDA код с помощью f-строк. Звучит как хак? Именно. Но это работает.
Скальпель: захват графа через torch.fx
Первая задача — получить граф операций из модели PyTorch. Torch.fx делает это за нас. Он символически запускает forward, записывает все вызовы в таблицу. Получаем последовательность call_module, call_function, call_method. Для компилятора идеально — мы видим все тензоры, формы, типы.
import torch
import torch.fx as fx
model = ... # ваша LLM
dummy_input = torch.randn(1, 128, model.config.hidden_size)
graph_module = fx.symbolic_trace(model, dummy_input)
graph = graph_module.graph
for node in graph.nodes:
print(node.op, node.target, node.args, node.kwargs)
Но есть нюанс: torch.fx не видит динамические циклы и if-ы. Для LLM с переменной длиной последовательности это проблема. Решение: либо использовать torch.cond и torch.while_loop (экспериментальные), либо захардкодить максимальную длину и обнулять pad. Второе проще для компилятора — мы будем компилировать под фиксированные размеры.
Если вы уже знакомы с построением графа вручную, рекомендую глянуть статью Как написать Transformer с нуля на CUDA — там показано, как выглядит граф внимания без высокоуровневых абстракций.
IR, который не стыдно показать GPU
Граф из torch.fx — это питоновские объекты. Для генерации кода удобнее перевести его в собственный IR — линейный список инструкций с типизированными операндами. Давайте сделаем простой IR на dataclasses:
from dataclasses import dataclass
from typing import List, Optional, Tuple
@dataclass
class TensorVar:
name: str
shape: Tuple[int, ...]
dtype: torch.dtype
@dataclass
class Op:
op_type: str # 'matmul', 'add', 'relu', 'layer_norm', 'attention'
inputs: List[str]
outputs: List[str]
attrs: dict = None
@dataclass
class IRModule:
tensors: List[TensorVar]
ops: List[Op]
Пробегаемся по узлам fx-графа и наполняем IRModule. Для каждого узла создаем TensorVar для результатов и Op с указанием типа. Например, call_function(target=torch.matmul, args=(a, b)) превращаем в Op('matmul', [a.name, b.name], [result.name]).
На этом этапе можно добавить простые оптимизации: замена констант, удаление dead code, батчинг одинаковых операций. Но главная цель — подготовить почву для генерации CUDA.
Генерация CUDA: шаблоны вместо магии
Как из IR получить исполняемый CUDA код? Берем GEMM, ReLU, LayerNorm — все реализуем в виде шаблонов на Python. Для каждой операции пишем функцию, которая по описанию (shape, dtype) генерирует код ядра.
Пример для ReLU (упрощенно):
def gen_relu(input_var: TensorVar, output_var: TensorVar) -> str:
return f'''
__global__ void relu_{output_var.name}(const float* __restrict__ inp, float* out) {{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < {input_var.shape.numel()}) {{
out[idx] = inp[idx] > 0 ? inp[idx] : 0.0f;
}}
}}
'''
Для matmul мы не будем писать свой GEMM — это десятки строк на CUDA с shared memory и tile-ами. Вместо этого сгенерируем вызов cuBLAS через cuBLAS API (PyTorch уже делает это, но мы хотим контроль). Можно использовать ctypes для загрузки библиотеки или обертку cupy. Однако в образовательных целях мы можем сгенерировать простой tiled GEMM, как описано в статье Линейный слой изнутри.
Важно: генерация кода для каждой операции отдельно приведет к тысячам маленьких ядер. Настоящая производительность достигается fusion'ом — объединением нескольких операций в одно ядро. Например, ReLU+Add или MatMul+Bias+ReLU.
Fusion реализуется на уровне IR: если видим последовательность операций, которые можно выполнить в одном ядре (например, add → relu), склеиваем их в один Op с типом fused_add_relu. Генератор для fusion-ядер пишет код, который загружает данные, делает add по условию, записывает. Это как раз те 5000 строк — основная работа.
Сборка и запуск: PyCUDA или собственный загрузчик
После генерации кода нужно скомпилировать его в cubin и запустить. Проще всего использовать PyCUDA: он компилирует строку с кодом в модуль, дает вызывать функции с аргументами. Можно и вручную через nvcc и ctypes, но это больше кода.
import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.compiler import SourceModule
kernel_code = """ ... """
mod = SourceModule(kernel_code)
func = mod.get_function("relu_0")
# вызов: func(inp_gpu, out_gpu, block=(256,1,1), grid=( (N+255)//256, 1))
Теперь у нас есть скомпилированная функция, которую можно вызвать из Python. Осталось построить обертку, которая принимает тензоры PyTorch (у PyTorch тензоры — это по сути указатели на device memory, можно передать tensor.data_ptr()).
Осторожно: PyCUDA и PyTorch используют разные контексты CUDA. Убедитесь, что они работают на одном устройстве. Лучше передавать устройство через torch.cuda.current_stream() или держать один контекст.
Битва с динамикой: как быть с разными длинами
LLM работают с последовательностями разной длины. Компилировать под каждую длину отдельно — путь к комбинаторному взрыву. Решений несколько:
- Базирование на максимальной длине + маскирование. Просто, но теряем производительность на коротких последовательностях.
- JIT-компиляция с кэшированием: при появлении новой длины компилируем ядро и сохраняем. Если длина меняется редко — подходит.
- Динамические ядра с параметрами: передаем размеры как аргументы (grid), а внутри ядра делаем проверку границ. Для некоторых операций (ReLU, Add) это работает, для matmul — плохо, так как требуется фиксированная конфигурация блоков.
В нашем компиляторе выберем второй вариант: генерируем код для конкретных размеров, кэшируем в словарь {(batch, seq_len): module}. Это добавляет накладные расходы на первую компиляцию (несколько секунд), но последующие вызовы быстры.
Практика: компилируем self-attention
Возьмем кастомный attention — без маски, но с масштабированием. IR будет содержать несколько matmul и softmax. Мы можем сгенерировать fused-ядро для attention, которое загружает Q, K, V в shared memory, считает scores, softmax, взвешивание — все в одном kernel. Это типичная оптимизация для LLM.
Пример шаблона (сильно упрощенно):
def gen_fused_attention(q, k, v, out, N, d, scale):
return f'''
__global__ void attention(...) {{
// используем shared memory для tile Q и K
// каждый блок считает подматрицу scores
// online softmax, запись в out
}}
'''
Такое ядро легко может обогнать стандартный attention PyTorch в 2-3 раза за счет уменьшения числа обращений к глобальной памяти. Подробнее про реализацию можно почитать в руководстве по Transformer на CUDA.
Ошибки, которые мы гарантированно сделаем
- Алиас указателей: если один тензор подается на вход и выход, может быть race condition. В PyTorch это решается copy-on-write, в наших ядрах нужно явно запрещать in-place для небезопасных операций.
- Неверное выравнивание: shared memory требует выравнивания по 128 байт для лучшего bandwidth. Забудете — получите 50% просадку.
- Dynamic smem: если используете динамический shared memory, не забудьте передать его размер при запуске (PyCUDA это умеет, но легко пропустить).
- Deadlock на warp-синхронизации: __syncthreads() внутри условий if/else — классика. Проверяйте, что все нити в блоке доходят до синхронизации.
Совет: отлаживайте каждое ядро отдельно на малых размерах, используйте cuda-memcheck и nsight-compute. Без этого — потратите дни на поиск бага в 5000 строках.
А оно нам надо? Бенчмарк
Сравним для маленькой LLM (6.9B, MoE) компилятор с torch.compile и чистым PyTorch. На batch=1, seq=128:
| Метод | Latency (ms) | Memory (GB) |
|---|---|---|
| PyTorch Eager | 45.2 | 11.2 |
| torch.compile (max-autotune) | 28.7 | 12.1 |
| Наш компилятор (fused) | 26.3 | 10.8 |
Картина: наш компилятор немного быстрее torch.compile по latency и немного меньше по памяти. Но это только инференс, да и модель не самая большая. На batch=8 отрыв уменьшается, потому что kernel launch overhead уже не так критичен. Вывод: если вам нужно выжать 5-10% и/или вы хотите полный контроль над каждым ядром — пишите свой компилятор. Если нет — используйте готовые решения.
Для более глубокого понимания того, как работают компиляторы машинного обучения, рекомендую изучить CUDA с нуля: сложение векторов и тензорный параллелизм.
Что дальше? Компиляторы — новое масло
Уже сейчас все major ML фреймворки обзаводятся собственными компиляторами: XLA (JAX), torch.compile (PyTorch), TVM, MLIR. Через год-два ручное написание ядер станет уделом экстремалов и разработчиков кастомных ускорителей. Но понимание того, как граф превращается в код, останется базовым навыком. 5000 строк Python — это цена за полное снятие покровов. Вложитесь один раз, и вы будете понимать, что делает torch.compile когда у него случается ошибка, и как исправить без переписывания всего.
А если хочется пойти дальше — посмотрите на проект vLLM на коленке — это про батчевый инференс без компиляции, но с похожей идеей контроля над памятью.