Когда стандартное внимание стало слишком дорогим
К 2026 году все устали от компромиссов. Multi-Head Attention (MHA) жрет память как не в себя, Grouped Query Attention (GQA) экономит, но теряет качество, а Multi-Query Attention (MQA) вообще превращает модель в калеку при длинных контекстах. DeepSeek посмотрела на эту ситуацию и сказала: "Хватит это терпеть".
Multi-Head Latent Attention (MLA) появилась не из воздуха. Это ответ на конкретную боль: как сохранить выразительность MHA при стоимости MQA? Ответ оказался на поверхности, но до него нужно было додуматься.
На 25.01.2026 MLA используется в DeepSeek-V3.2 и более новых моделях. Архитектура доказала свою эффективность в production-среде с контекстом до 1М токенов.
Суть трюка: латентные ключи вместо компромиссов
Представьте, что у вас есть 32 головы внимания. В MHA каждая голова получает свои уникальные K и V. В GQA вы группируете их по 4, в MQA вообще оставляете одну на всех. MLA предлагает третий путь: создать небольшой латентный пул ключей и значений, из которого все головы будут черпать информацию.
Математически это выглядит так:
# Вместо стандартного внимания:
# Q_i, K_i, V_i для каждой головы i
# MLA делает:
K_latent = проекция(вход) # [batch, seq_len, latent_dim]
V_latent = проекция(вход) # [batch, seq_len, latent_dim]
# Каждая голова получает свои Q, но делит латентные K/V
attention_i = softmax(Q_i @ K_latent.T / sqrt(d_k)) @ V_latent
Звучит просто? Так и есть. Но простота обманчива. Латентый пул обычно в 4-8 раз меньше полного набора ключей, что сразу экономит 75-87% памяти KV-cache.
Почему это работает лучше GQA и MQA?
GQA делит головы на группы, и каждая группа получает свои K/V. Проблема в том, что информация между группами не смешивается. Если важный паттерн попал в "неправильную" группу, другие головы его не увидят.
MQA еще хуже: одна копия K/V на все головы. Для коротких контекстов работает, но при 100К+ токенах модель начинает "путаться" - разные аспекты входных данных конкурируют за одно и то же представление.
MLA решает обе проблемы. Латентый пул - это не просто сжатие, а переработка информации. Все головы видят все латентные ключи, но могут "фокусироваться" на разных аспектах через свои запросы. Это как если бы у вас был общий словарь, но каждый исследователь искал в нем свои термины.
| Архитектура | Память KV-cache | Качество (long-context) | Использование в 2026 |
|---|---|---|---|
| MHA (базовая) | 100% (референс) | Отличное | Только в маленьких моделях |
| GQA | 25-50% от MHA | Хорошее | Llama 4, Claude 4 |
| MQA | ~3% от MHA | Среднее (деградация на 100К+) | Устаревает |
| MLA (DeepSeek) | 12-25% от MHA | Близко к MHA | DeepSeek-V3.2+, новые китайские LLM |
Полная реализация на PyTorch (актуально на 25.01.2026)
Вот как выглядит MLA в коде. Заметьте: я использую PyTorch 2.4+ и новые оптимизации - старые туториалы 2024 года уже неактуальны.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
class MultiHeadLatentAttention(nn.Module):
"""
MLA реализация для PyTorch 2.4+
Поддерживает KV-cache и flash attention 3
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
latent_ratio: float = 0.125, # 1/8 от полного размера
dropout: float = 0.0,
bias: bool = True,
device=None,
dtype=None
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Латентый размер - ключевой параметр
self.latent_dim = int(embed_dim * latent_ratio)
self.latent_heads = max(1, int(num_heads * latent_ratio))
# Проекции
self.q_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self.k_latent_proj = nn.Linear(
embed_dim, self.latent_dim, bias=bias, device=device, dtype=dtype
)
self.v_latent_proj = nn.Linear(
embed_dim, self.latent_dim, bias=bias, device=device, dtype=dtype
)
self.out_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self.dropout = dropout
# Для KV-cache
self.kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
"""
x: [batch, seq_len, embed_dim]
Возвращает: output, (k_cache, v_cache)
"""
batch_size, seq_len, _ = x.shape
# 1. Проекции
Q = self.q_proj(x) # [batch, seq_len, embed_dim]
K_latent = self.k_latent_proj(x) # [batch, seq_len, latent_dim]
V_latent = self.v_latent_proj(x) # [batch, seq_len, latent_dim]
# 2. Reshape для многоголового внимания
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
Q = Q.transpose(1, 2) # [batch, heads, seq_len, head_dim]
# Латентные K/V не делим по головам
K_latent = K_latent.transpose(1, 2) # [batch, latent_dim, seq_len]
V_latent = V_latent.transpose(1, 2) # [batch, latent_dim, seq_len]
# 3. Объединение с прошлыми KV (инкрементальный decoding)
if past_kv is not None:
past_k, past_v = past_kv
K_latent = torch.cat([past_k, K_latent], dim=-1)
V_latent = torch.cat([past_v, V_latent], dim=-1)
cache_seq_len = K_latent.size(-1)
else:
cache_seq_len = seq_len
# 4. Внимание (используем torch.nn.functional.scaled_dot_product_attention)
# В PyTorch 2.4+ это оптимальнее ручной реализации
attn_output = F.scaled_dot_product_attention(
Q,
K_latent.unsqueeze(1).expand(-1, self.num_heads, -1, -1),
V_latent.unsqueeze(1).expand(-1, self.num_heads, -1, -1),
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=attention_mask is None
)
# 5. Собираем обратно
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)
output = self.out_proj(attn_output)
# 6. KV-cache для следующего токена
next_kv = None
if use_cache:
next_kv = (K_latent, V_latent)
return output, next_kv
def reset_cache(self):
"""Сброс KV-cache между запросами"""
self.kv_cache = None
Важно: не путайте latent_ratio с group_size в GQA. В MLA все головы видят все латентные ключи, а в GQA каждая группа видит только свои. Это фундаментальное различие в архитектуре.
Оптимизации KV-cache: где реальная экономия
Вот что происходит с памятью при контексте в 100К токенов для модели с 32 головами и размерностью 4096:
- MHA: 100К × 32 × 2 × 128 = 819.2 МБ (убийственно)
- GQA (groups=4): 100К × 8 × 2 × 128 = 204.8 МБ (лучше, но все равно много)
- MLA (latent_ratio=0.125): 100К × 4 × 2 × 128 = 102.4 МБ (в 8 раз меньше MHA!)
Но экономия памяти - не единственный плюс. Меньший KV-cache означает:
- Быстрее загрузка на GPU (особенно важно для масштабирования на несколько карт)
- Меньше трафика между CPU/GPU при streaming inference
- Возможность держать больше concurrent сессий в памяти
Где MLA дает максимальный выигрыш?
Не везде MLA одинаково хороша. Из тестов DeepSeek-V3.2 видно:
Конкретные сценарии:
- RAG-системы с большими базами документов - экономия памяти позволяет обрабатывать больше источников
- Мультимодальные модели - латентные ключи хорошо работают с fused представлениями текста и изображений
- Streaming inference - меньший KV-cache = меньше latency между токенами
- Fine-tuning на длинных контекстах - можно использовать большие batch sizes
Подводные камни и ограничения
MLA не серебряная пуля. Вот что нужно учитывать:
1. Требуется калибровка latent_ratio. Слишком маленький латентный пул (0.0625) ухудшает качество на сложных reasoning задачах. Слишком большой (0.25) сводит на нет преимущества.
2. Несовместимость со старыми оптимизациями. Некоторые kernel'ы для flash attention требуют адаптации под MLA. Если вы используете кастомные CUDA ядра - придется переписывать.
3. Сложнее отлаживать. Когда все головы используют общий пул, труднее понять, какая именно "сломалась". Инструменты вроде TraceML становятся обязательными.
Стоит ли переходить на MLA в 2026?
Если вы начинаете новый проект с нуля - однозначно да. Архитектура проверена в бою на DeepSeek-V3.2, и результаты говорят сами за себя. Но если у вас уже работающая модель на GQA/MQA - считайте ROI.
Переписывание attention слоев - это:
- Изменение архитектуры и checkpoint'ов
- Переобучение или хотя бы дообучение
- Адаптация оптимизаций inference (vLLM, TensorRT-LLM)
- Тестирование на всех edge cases
Для экспериментов можно взять готовую реализацию из репозиториев DeepSeek и адаптировать под свои нужды.
Что будет дальше с архитектурами внимания?
MLA - не конечная точка. Уже видны следующие шаги:
1. Динамический latent_ratio - модель сама решает, сколько латентных ключей ей нужно для текущего контекста
2. Специализированные латентные пулы - отдельные пулы для разных типов информации (факты, reasoning, стиль)
3. Гибридные подходы - комбинация MLA с другими оптимизациями вроде Tuneable Attention
К 2027 году мы, скорее всего, увидим attention-механизмы, которые вообще не хранят полный KV-cache, а динамически регенерируют его по мере необходимости. Но пока MLA - лучший баланс между качеством и эффективностью.
Главный урок MLA прост: иногда чтобы решить проблему, нужно не выбирать между вариантами, а придумать третий путь. DeepSeek это сделала. Теперь очередь за остальными.