MLA от DeepSeek: архитектура, код PyTorch и оптимизация KV-cache 2026 | AiManual
AiManual Logo Ai / Manual.
25 Янв 2026 Инструмент

Multi-Head Latent Attention: как DeepSeek переизобрела механизм внимания и почему это работает быстрее

Полный разбор Multi-Head Latent Attention от DeepSeek с кодом на PyTorch. Оптимизации KV-cache, сравнение с MHA/GQA/MQA и практическая реализация.

Когда стандартное внимание стало слишком дорогим

К 2026 году все устали от компромиссов. Multi-Head Attention (MHA) жрет память как не в себя, Grouped Query Attention (GQA) экономит, но теряет качество, а Multi-Query Attention (MQA) вообще превращает модель в калеку при длинных контекстах. DeepSeek посмотрела на эту ситуацию и сказала: "Хватит это терпеть".

Multi-Head Latent Attention (MLA) появилась не из воздуха. Это ответ на конкретную боль: как сохранить выразительность MHA при стоимости MQA? Ответ оказался на поверхности, но до него нужно было додуматься.

На 25.01.2026 MLA используется в DeepSeek-V3.2 и более новых моделях. Архитектура доказала свою эффективность в production-среде с контекстом до 1М токенов.

Суть трюка: латентные ключи вместо компромиссов

Представьте, что у вас есть 32 головы внимания. В MHA каждая голова получает свои уникальные K и V. В GQA вы группируете их по 4, в MQA вообще оставляете одну на всех. MLA предлагает третий путь: создать небольшой латентный пул ключей и значений, из которого все головы будут черпать информацию.

Математически это выглядит так:

# Вместо стандартного внимания:
# Q_i, K_i, V_i для каждой головы i

# MLA делает:
K_latent = проекция(вход)  # [batch, seq_len, latent_dim]
V_latent = проекция(вход)  # [batch, seq_len, latent_dim]

# Каждая голова получает свои Q, но делит латентные K/V
attention_i = softmax(Q_i @ K_latent.T / sqrt(d_k)) @ V_latent

Звучит просто? Так и есть. Но простота обманчива. Латентый пул обычно в 4-8 раз меньше полного набора ключей, что сразу экономит 75-87% памяти KV-cache.

Почему это работает лучше GQA и MQA?

GQA делит головы на группы, и каждая группа получает свои K/V. Проблема в том, что информация между группами не смешивается. Если важный паттерн попал в "неправильную" группу, другие головы его не увидят.

MQA еще хуже: одна копия K/V на все головы. Для коротких контекстов работает, но при 100К+ токенах модель начинает "путаться" - разные аспекты входных данных конкурируют за одно и то же представление.

MLA решает обе проблемы. Латентый пул - это не просто сжатие, а переработка информации. Все головы видят все латентные ключи, но могут "фокусироваться" на разных аспектах через свои запросы. Это как если бы у вас был общий словарь, но каждый исследователь искал в нем свои термины.

Архитектура Память KV-cache Качество (long-context) Использование в 2026
MHA (базовая) 100% (референс) Отличное Только в маленьких моделях
GQA 25-50% от MHA Хорошее Llama 4, Claude 4
MQA ~3% от MHA Среднее (деградация на 100К+) Устаревает
MLA (DeepSeek) 12-25% от MHA Близко к MHA DeepSeek-V3.2+, новые китайские LLM

Полная реализация на PyTorch (актуально на 25.01.2026)

Вот как выглядит MLA в коде. Заметьте: я использую PyTorch 2.4+ и новые оптимизации - старые туториалы 2024 года уже неактуальны.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class MultiHeadLatentAttention(nn.Module):
    """
    MLA реализация для PyTorch 2.4+
    Поддерживает KV-cache и flash attention 3
    """
    
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        latent_ratio: float = 0.125,  # 1/8 от полного размера
        dropout: float = 0.0,
        bias: bool = True,
        device=None,
        dtype=None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Латентый размер - ключевой параметр
        self.latent_dim = int(embed_dim * latent_ratio)
        self.latent_heads = max(1, int(num_heads * latent_ratio))
        
        # Проекции
        self.q_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.k_latent_proj = nn.Linear(
            embed_dim, self.latent_dim, bias=bias, device=device, dtype=dtype
        )
        self.v_latent_proj = nn.Linear(
            embed_dim, self.latent_dim, bias=bias, device=device, dtype=dtype
        )
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        
        self.dropout = dropout
        
        # Для KV-cache
        self.kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        x: [batch, seq_len, embed_dim]
        Возвращает: output, (k_cache, v_cache)
        """
        batch_size, seq_len, _ = x.shape
        
        # 1. Проекции
        Q = self.q_proj(x)  # [batch, seq_len, embed_dim]
        K_latent = self.k_latent_proj(x)  # [batch, seq_len, latent_dim]
        V_latent = self.v_latent_proj(x)  # [batch, seq_len, latent_dim]
        
        # 2. Reshape для многоголового внимания
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        Q = Q.transpose(1, 2)  # [batch, heads, seq_len, head_dim]
        
        # Латентные K/V не делим по головам
        K_latent = K_latent.transpose(1, 2)  # [batch, latent_dim, seq_len]
        V_latent = V_latent.transpose(1, 2)  # [batch, latent_dim, seq_len]
        
        # 3. Объединение с прошлыми KV (инкрементальный decoding)
        if past_kv is not None:
            past_k, past_v = past_kv
            K_latent = torch.cat([past_k, K_latent], dim=-1)
            V_latent = torch.cat([past_v, V_latent], dim=-1)
            cache_seq_len = K_latent.size(-1)
        else:
            cache_seq_len = seq_len
        
        # 4. Внимание (используем torch.nn.functional.scaled_dot_product_attention)
        # В PyTorch 2.4+ это оптимальнее ручной реализации
        attn_output = F.scaled_dot_product_attention(
            Q, 
            K_latent.unsqueeze(1).expand(-1, self.num_heads, -1, -1),
            V_latent.unsqueeze(1).expand(-1, self.num_heads, -1, -1),
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=attention_mask is None
        )
        
        # 5. Собираем обратно
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)
        
        output = self.out_proj(attn_output)
        
        # 6. KV-cache для следующего токена
        next_kv = None
        if use_cache:
            next_kv = (K_latent, V_latent)
        
        return output, next_kv
    
    def reset_cache(self):
        """Сброс KV-cache между запросами"""
        self.kv_cache = None

Важно: не путайте latent_ratio с group_size в GQA. В MLA все головы видят все латентные ключи, а в GQA каждая группа видит только свои. Это фундаментальное различие в архитектуре.

Оптимизации KV-cache: где реальная экономия

Вот что происходит с памятью при контексте в 100К токенов для модели с 32 головами и размерностью 4096:

  • MHA: 100К × 32 × 2 × 128 = 819.2 МБ (убийственно)
  • GQA (groups=4): 100К × 8 × 2 × 128 = 204.8 МБ (лучше, но все равно много)
  • MLA (latent_ratio=0.125): 100К × 4 × 2 × 128 = 102.4 МБ (в 8 раз меньше MHA!)

Но экономия памяти - не единственный плюс. Меньший KV-cache означает:

  1. Быстрее загрузка на GPU (особенно важно для масштабирования на несколько карт)
  2. Меньше трафика между CPU/GPU при streaming inference
  3. Возможность держать больше concurrent сессий в памяти

Где MLA дает максимальный выигрыш?

Не везде MLA одинаково хороша. Из тестов DeepSeek-V3.2 видно:

💡
MLA особенно эффективна в задачах с длинным контекстом (100К+ токенов) и multi-turn диалогах. Для коротких промптов (до 4К токенов) разница с GQA минимальна.

Конкретные сценарии:

  • RAG-системы с большими базами документов - экономия памяти позволяет обрабатывать больше источников
  • Мультимодальные модели - латентные ключи хорошо работают с fused представлениями текста и изображений
  • Streaming inference - меньший KV-cache = меньше latency между токенами
  • Fine-tuning на длинных контекстах - можно использовать большие batch sizes

Подводные камни и ограничения

MLA не серебряная пуля. Вот что нужно учитывать:

1. Требуется калибровка latent_ratio. Слишком маленький латентный пул (0.0625) ухудшает качество на сложных reasoning задачах. Слишком большой (0.25) сводит на нет преимущества.

2. Несовместимость со старыми оптимизациями. Некоторые kernel'ы для flash attention требуют адаптации под MLA. Если вы используете кастомные CUDA ядра - придется переписывать.

3. Сложнее отлаживать. Когда все головы используют общий пул, труднее понять, какая именно "сломалась". Инструменты вроде TraceML становятся обязательными.

Стоит ли переходить на MLA в 2026?

Если вы начинаете новый проект с нуля - однозначно да. Архитектура проверена в бою на DeepSeek-V3.2, и результаты говорят сами за себя. Но если у вас уже работающая модель на GQA/MQA - считайте ROI.

Переписывание attention слоев - это:

  • Изменение архитектуры и checkpoint'ов
  • Переобучение или хотя бы дообучение
  • Адаптация оптимизаций inference (vLLM, TensorRT-LLM)
  • Тестирование на всех edge cases

Для экспериментов можно взять готовую реализацию из репозиториев DeepSeek и адаптировать под свои нужды.

Что будет дальше с архитектурами внимания?

MLA - не конечная точка. Уже видны следующие шаги:

1. Динамический latent_ratio - модель сама решает, сколько латентных ключей ей нужно для текущего контекста

2. Специализированные латентные пулы - отдельные пулы для разных типов информации (факты, reasoning, стиль)

3. Гибридные подходы - комбинация MLA с другими оптимизациями вроде Tuneable Attention

К 2027 году мы, скорее всего, увидим attention-механизмы, которые вообще не хранят полный KV-cache, а динамически регенерируют его по мере необходимости. Но пока MLA - лучший баланс между качеством и эффективностью.

Главный урок MLA прост: иногда чтобы решить проблему, нужно не выбирать между вариантами, а придумать третий путь. DeepSeek это сделала. Теперь очередь за остальными.