Ваша нейросеть выдала странный ответ. Вы смотрите на слои, как на черный ящик, и гадаете: что там происходит? Представьте, если бы вы могли заглянуть внутрь, как нейробиологи смотрят на мозг с помощью fMRI. Это не магия, а техника probing с хуками PyTorch.
Сегодня вы построите интерактивный сканер для любой модели на трансформерах за 15 минут. Без сложных теорий, только код и результат.
Зачем это нужно? Черный ящик vs прозрачность
Модели типа BERT или GPT стали слишком сложными, чтобы доверять им вслепую. Ошибка в продакшне может стоить дорого. Например, в системах с LLM непредсказуемость — главный враг.
fMRI-стиль probing — это не просто визуализация. Вы меняете вход, смотрите, как "возбуждаются" нейроны в разных слоях, и находите причинно-следственные связи. Почему модель решила, что "банка" — это сосуд, а не финансовое учреждение? Ответ в активациях.
Не путайте с обычной визуализацией внимания. Attention maps показывают связи между токенами, но скрытые состояния — это смыслы, которые модель построила. Именно там живут представления о мире.
Инструменты: PyTorch, Gradio и один хитрый хук
Вам не нужны тонны кода. Основа — механизм хуков в PyTorch. Хук — это функция, которая цепляется к любому слою модели и забирает его выход (или вход) в момент прямого прохода.
Gradio превратит ваш скрипт в веб-интерфейс за три строки. Это как IDEAV для базы данных, но для нейросетей: интуитивно и сразу.
Что понадобится:
- PyTorch (любая версия)
- Gradio (
pip install gradio) - Библиотека transformers от Hugging Face (для готовых моделей)
- Немного любопытства
1 Установка и подготовка модели
Сначала загрузим модель и токенизатор. Возьмем BERT-base как пример — он есть у всех, и он достаточно мал для быстрых экспериментов.
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
model.eval() # переводим в режим оценкиФлаг output_hidden_states=True заставляет модель возвращать все скрытые состояния. Но мы не будем использовать этот встроенный механизм — он дает сразу все слои, а мы хотим контролировать процесс и, возможно, вмешиваться. Поэтому хуки.
2 Вешаем хуки на нужные слои
Допустим, мы хотим следить за активациями в промежуточных слоях. BERT имеет 12 слоев. Мы создадим словарь, куда будем складывать активации по мере прохода.
activations = {}
def get_activation(name):
def hook(model, input, output):
# output - это кортеж, но обычно нам нужен первый элемент
activations[name] = output.detach()
return hook
# Регистрируем хуки на каждый слой encoder'а
for i, layer in enumerate(model.encoder.layer):
layer.register_forward_hook(get_activation(f"layer_{i}"))Теперь при каждом вызове model(input_ids) в словаре activations появятся тензоры для каждого слоя.
register_forward_pre_hook. Но осторожно: это может сломать вычисления, если вы что-то измените.3 Функция прогона и сбора данных
Напишем функцию, которая принимает текст, токенизирует его, прогоняет через модель и возвращает активации в удобном формате.
def probe_text(text):
# Очищаем предыдущие активации
activations.clear()
# Токенизация
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
# Прямой проход без вычисления градиентов (экономия памяти)
with torch.no_grad():
outputs = model(**inputs)
# Преобразуем активации в список numpy массивов для визуализации
# Берем среднее по последовательности (убираем dimension токенов)
activation_list = []
for i in range(len(model.encoder.layer)):
layer_act = activations.get(f"layer_{i}")
if layer_act is not None:
# Усредняем по всем токенам (кроме [CLS] и [SEP]?)
# Для простоты берем среднее по всей последовательности
activation_list.append(layer_act.mean(dim=1).squeeze().cpu().numpy())
else:
activation_list.append(np.zeros(model.config.hidden_size))
return activation_listЗдесь мы усредняем активации по всем токенам последовательности, чтобы получить одно значение на нейрон для каждого слоя. Это упрощение, но для визуализации трендов по слоям — то что нужно.
4 Визуализация с Gradio: строим тепловую карту
Теперь самое интересное — интерфейс. Gradio позволяет создать веб-форму с полем ввода и графиком за несколько строк.
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
def visualize_activations(text):
acts = probe_text(text)
# acts - список из 12 массивов (по слоям), каждый размером hidden_size (768)
# Преобразуем в 2D массив: слои x нейроны
heatmap_data = np.vstack(acts) # shape: (12, 768)
# Создаем тепловую карту
fig, ax = plt.subplots(figsize=(10, 6))
cax = ax.imshow(heatmap_data, aspect="auto", cmap="viridis")
ax.set_xlabel("Нейроны")
ax.set_ylabel("Слои")
ax.set_title(f"Активации для: {text[:50]}...")
plt.colorbar(cax, ax=ax)
return fig
# Запускаем интерфейс
iface = gr.Interface(
fn=visualize_activations,
inputs=gr.Textbox(label="Введите текст"),
outputs=gr.Plot(label="Тепловая карта активаций"),
title="fMRI Probing для BERT",
description="Введите текст и наблюдайте, как активируются нейроны по слоям."
)
iface.launch(share=True) # share=True дает публичную ссылкуЗапустите скрипт. Откроется локальный сервер, а если у вас есть интернет, Gradio сгенерирует публичную ссылку (как в Google Colab). Теперь экспериментируйте.
Что может пойти не так? Ошибки и нюансы
Кажется просто, но подводные камни есть всегда.
| Ошибка | Почему возникает | Как исправить |
|---|---|---|
| Словарь activations не очищается | Хуки пишут в один и тот же словарь при каждом вызове, данные накапливаются | Всегда вызывать activations.clear() в начале probe_text |
| График пустой или все значения нули | Модель в режиме обучения (train), а не eval; или не вызван forward pass | Убедитесь, что model.eval() и torch.no_grad() |
| Память растет с каждым запросом | Тензоры остаются на GPU, сборщик мусора не срабатывает | Используйте .cpu() и del для промежуточных тензоров |
Еще один нюанс: усреднение по токенам стирает информацию. Для более тонкого анализа нужно смотреть на конкретные токены. Например, как активируется нейрон, отвечающий за определение предмета, на слове "яблоко" в предложении "Я съел яблоко".
А что дальше? От визуализации к интервенциям
Тепловая карта — это только начало. Настоящая сила probing в интервенциях: вы меняете активации и смотрите, как меняется выход модели.
Допустим, вы обнаружили нейрон, который активируется на упоминание еды. Вы можете занулить его и посмотреть, перестанет ли модель отвечать на вопросы про рецепты. Это как в обучении физике дефектов, где важно не просто предсказать, а понять механизм.
Код для интервенции через хук:
def intervention_hook(neuron_index, layer_index):
def hook(model, input, output):
# output shape: (batch_size, seq_len, hidden_size)
output[0, :, neuron_index] = 0.0 # зануляем конкретный нейрон
return output
return hook
# Вешаем хук на конкретный слой
handle = model.encoder.layer[5].register_forward_hook(intervention_hook(123, 5))
# Прогоняем модель
# ...
handle.remove() # не забываем убрать хукТеперь вы не просто наблюдатель, а экспериментатор. Можете проверять гипотезы о том, за что отвечают нейроны.
Вместо заключения: где это применить?
Этот метод — не академическая игрушка. В продакшне он помогает отлаживать модели, находить уязвимости (bias), объяснять решения заказчикам. Представьте, что вы делаете Virtual Try-On систему и нужно понять, почему модель путает цвета одежды. Probing покажет, в каком слое возникает ошибка.
Или вы создаете нейросетевой квест и хотите, чтобы модель генерировала последовательные сюжеты. Отслеживая активации, вы поймете, где теряется логика.
Совет напоследок: не ограничивайтесь тепловыми картами. Стройте графики активности конкретных нейронов по слоям, сравнивайте разные входы, ищите корреляции. Через 15 минут у вас будет работающий инструмент. А через час — первые инсайты.
Теперь у вас есть fMRI-сканер для нейросетей. Включите его и посмотрите, что скрывает ваша модель.