Top-down подход
Если читатель был достаточно внимателен, то, наверное, заметил, что в предыдущей заметке я обошел стороной непосредственно блок механизма внимания, точнее сказать, описание было дано методом черного ящика: вот тут такие-то входы, там такие-то выходы. Теперь, внимание, вопрос знатокам: Что лежит в черном ящике? В действительности, крайне важно понимать, что там внутри и логично посвятить данной теме отдельный текст. Понимание механизма внимания определяет ход дальнейших размышлений вплоть до самых передовых архитектур ИИ и поэтому сложно переоценить важность этой темы.
Обычно в технической литературе принято сначала говорить об элементах по отдельности и далее уже давать композиционное представление сочетания таких элементов, это так называемых bottom-up (снизу вверх) подход. Консультантам обычно ближе подход top-down (сверху вниз), что позволяет сначала поймать идею на уровне композиции и далее погружаться в детали. Именно второму варианту я отдаю предпочтение: я начал с архитектуры Pointer Network и теперь хочу разобрать элемент механизма внимания в деталях.
На Хабре есть замечательная статья-перевод Transformer в картинках, где подробно и наглядно по шагам разбираются вычисления для механизма внимания и архитектуры трансформеров, но вот, что касается основных элементов механизма внимания там дословно написано следующее:
…Что представляют собой векторы «запроса», «ключа» и «значения»?
Это абстракции, которые оказываются весьма полезны для понимания и вычисления внимания. Прочитав ниже, как вычисляется внимание, вы будете знать практически все, что вам нужно о роли этих векторов в рассматриваемой нами модели…
Объяснение в стиле: вот формула, считает так, она работает, слово пацана. Лично я в этом вижу очень много недосказанности:
Что за ключи, значения, запросы? Там внутри база данных? Откуда она взялась?
Почему формула с
softmax
? Как она выведена?Есть ли за механизмом внимания какая-то интуиция?
Еще раз подчеркиваю, статья бодрая и очень полезная, обязательна к прочтению, но принцип работы механизма внимания там раскрыт очень поверхностно.
Непараметрическая регрессия Надарая-Ватсона
В свое время редактор Стивена Хоккинга справедливо заметил, что каждая формула в книге отнимает ноль в цифре его аудитории и поэтому в итоговой версии Краткой истории времени было решено оставить лишь одну формулу . Лично я стараюсь также избегать формул, но в данном случае придется все-таки изменить привычке и добавить чуть больше математического контекста, хотя и спрятанного под спойлер.
Многие слышали, что параметрические методы статистики имеют достаточно много ограничений, перечень которых даже запомнить трудно. Конечно, тот же эксель не запрещает строить линейную регрессию на плохих данных, но вот оценки такой модели будут сомнительны. Собственно поэтому придумали непараметрические методы и сопутствующий математический аппарат регрессионной оценки, который был положен в основу механизма внимания.
Немного ядерной математики
Допустим, мы хотим оценить некую регрессионную функцию непараметрическим методом для некоторой независимой переменной в отношении зависимой . Достаточно сложно оценить полностью условную плотность распределения вероятности, но можно попробовать оценить условное среднее для каждого значения , то есть формально:
Это выражение следует из формулы среднего значения случайной величины и определения условной плотности вероятности. Иными словами, регрессионная функция может быть оценена из совместного и маржинального распределений.
Теперь предлагаю вспомнить формулу ядерной оценки плотности распределения:
,
где – ширина окна (bandwidth), которая также определяет ширину столбцов при построении гистограммы. Собственно все непараметрические методы танцуют математику вокруг некой агрегации значений на интервале . Аналогично формула для совместного распределения двух переменных:
подставляем формулы ядерной оценки плотности распределения в первую формулу расчета среднего:
далее выносим все постоянные, независящие от за знак интеграла:
теперь, выражение под интегралом представляет собой формулу взвешенного по ядерной функции среднего для и следовательно итоговое выражение принимает вид:
Формула Надарая-Вотсона для оценки среднего с использованием ядерной регрессии
,
где
Имеем оценку взвешенного среднего целевой переменной по весам , которые меняются в зависимости от значений . Другими словами, оценка Надарая-Вотсона представляет собой локальное среднее в окрестности значений при этом оценка близости осуществляется с помощью ядерной функции . Действительно, если нам нужно сделать оценку целевой переменной для заданного , то стоит поискать ответ где-то в , которые будут близкими к по значениям .
Интуиция и реализация
Немного ершистой математики и вот уже получилось протянуть мостик от функции взвешенного среднего к миру ИИ, однако ощущений на кончиках пальцев пока еще не появилось. Для того чтобы усилить интуицию вокруг механизма внимания и модели Надарая-Вотсона нужно презентовать немного кода на R.
В первую очередь необходимо вспомнить и определить ядерные функции:
library(torch) # тот самый Torch, нужный для тензорных операций
library(ggplot2) # рисовалка графиков
library(patchwork) # для композиций графиков
# Определяем ядерные функции
gaussian <- function(x) torch_exp(-x**2 / 2)
boxcar <- function(x) torch_abs(x) < 1
constant <- function(x) 1 + 0 * x
epanechikov <- function(x) torch_max(1 - torch_abs(x), other = torch_zeros_like(x))
kernel <- list(`gaussian`, `boxcar`, `constant`, `epanechikov`)
kernel_names <- c("gaussian", "boxcar", "constant", "epanechikov")
# Функция чтобы рисовать ядерные функции
plts <- mapply(kernel, kernel_names, SIMPLIFY = F, FUN=function(f, names){
ggplot() +
xlim(-2, 2) +
geom_function(fun = \(x)f(x) |> torch_tensor(dtype = torch_float()) |> as_array()) +
labs(title = names)
})
Reduce(`+`, plts) + # графики ядерных функций
plot_annotation(title = "Графики распространенных ядерных функций")
Далее необходимо определить какую-то тренировочную задачу. В данном случае выбрана тригонометрическая функция синусоиды и к ней будет подмешан случайны шум. Определение этой тестовой задачи упрощается введением уже знакомой терминологии:
Ключами (keys) будут называться известные (фактические)
Значениями (values) будут называться известные (фактические)
Запросами (queries) будут называться новые xx для которых нужно найти оценку
Вот оно как, Михалыч, получается, все эти ключи, значения и запросы заехали к нам в ИИ из мира непараметрической статистики из модели регрессии на локальное среднее ?
Если воспользоваться формулой Надарая-Вотсона для оценки значений можно получить:
# Генерируем задание для оценки
f <- \(x) 2 * torch_sin(x) + x
n = 40
# Ответ на вопрос: откуда берутся ключи, значения, запросы
x_train = torch_sort(torch_rand(n) * 5)[[1]] # keys - фактические x независимой переменной + шум
y_train = f(x_train) + torch_randn(n) # values - фактические y зависимой переменной + шум
x_val = torch_arange(0, 5, 0.1) # queries - новые x для запроса
y_val = f(x_val) # truth - истинные y зависимой переменной для сравнения оценки
# Функция для оценки на базе формулы Надарая-Вотсона
nadaraya_watson <- function(x_train, y_train, x_val, kernel){
# Определение дистанции для каждого сочетания пар ключей и запросов
dists = x_train$reshape(c(-1, 1)) - x_val$reshape(c(1, -1)) # [keys, values] = [40, 51]
# Ядерная оценка: каждая колонка/строка соответствует запросу/ключу
k = kernel(dists) |> torch_tensor(dtype = torch_float()) # расчет значения ядерной функции
# Нормализация по ключам для каждого запроса
attention_w = k/k$sum(1) # матрица локальных весов W
y_hat = torch_matmul(y_train, attention_w) # Матричное умножения на y (values)
list(y_hat, attention_w)
}
# Табличка с истинными значениями
truth_df <- data.frame(x_train=as_array(x_train), y_train=as_array(y_train))
# Функция для отрисовки модельных и истинных значений
fit_plts <- mapply(kernel, kernel_names, SIMPLIFY = F, FUN=function(f, names){
data.frame(x_val=as_array(x_val), y_val = as_array(y_val), y_train=nadaraya_watson(x_train, y_train, x_val, kernel =f)[[1]] |> as_array()) |>
ggplot() +
geom_line(aes(x_val, y_train, col = "model"), lty = 5) +
geom_line(aes(x_val, y_val, col = "truth")) +
geom_point(data = truth_df, aes(x_train, y_train), alpha = .3) +
scale_color_manual(values = c("model" = "firebrick", "truth" = "white")) +
labs(title = names, col="")
})
# Непосредственно картинки
Reduce(`+`, fit_plts) + plot_layout(guides = "collect") +
plot_annotation(title = "Сравнения оценки модели Надарая-Вотсона и истинных значений")
Получились местами неплохие оценки, впрочем можно усложнить модель включением в ядерные функции параметра – ширина окна (bandwidth), и улучшить результат подстройкой этого параметра. В этом случае нужно правда помнить про Bias–variance tradeoff.
В заключении интересно посмотреть как работают веса , которые теперь можно называть весами внимания (attention weights)
# Функция для отрисовки матрицы внимания
attn_plts <- mapply(kernel, kernel_names, SIMPLIFY = F, FUN=function(f, names){
nadaraya_watson(x_train, y_train, x_val, kernel =f)[[2]] |>
as_array() |>
as.data.frame(row.names = T) |>
dplyr::mutate(keys = seq_len(n)) |>
tidyr::pivot_longer(!keys, names_to = "queries") |>
dplyr::mutate(queries=gsub("V", "", queries) |> as.numeric()) |>
ggplot() +
geom_raster(aes(queries, keys, fill = value)) +
scale_fill_viridis_c() +
scale_y_reverse() +
labs(title = names)
})
# Непосредственно картинки
Reduce(`+`, attn_plts) +
plot_annotation(title = "Матрицы внимания для разных ядерных функций")
Получаем визуально зависимость ключей от значений и наоборот. Для всех ядерных функций кроме постоянной имеем: чем ближе запрос к ключу, тем выше плотность распределения вероятности, что собственно и ожидалось с учетом сортировки значений от малого к большему.
Механизм внимания
Обозначим матрицу , и будем ее называть функцией внимания. Рассмотрим случай для ядерной функции нормального распределения:
В первую очередь можно обратить внимание на последний член, который зависит только от , то есть он одинаковый для всех пар . При нормализация весов внимания к единице (смотри ниже softmax
), эта часть выражения полностью исчезает. Во-вторых, в случае использования нормализации (batch или layer) нормы становятся хорошо ограниченными и постоянными. Таким образом, этот член можно также выбросить из рассмотрения.
Наконец, для избежания проблем с расчетом градиентов хорошо бы контролировать порядок величины аргументов в экспоненциальной функции. Допустим, все элементы запроса qq и ключа являются независимыми и одинаково распределенными случайными с нулевым средним и единичной дисперсией. Скалярное произведение таких векторов имеет нулевое среднее и дисперсию, равную . Чтобы дисперсия скалярного произведения оставалась равной одному независимо от длины вектора, необходимо дополнительно масштабировать скалярное произведение, то есть добавить в знаменатель . Таким образом, получается выражение для функции внимания для скалярного произведения, которая кстати и используется в трансформерах.
Однако веса функции внимания αα всё ещё требуют нормализации, что можно сделать с помощью операции softmax
, которая, напомню, является обобщением логистической функции для многомерного случая. Выбор данной функции обусловлен эмпирическими соображениями: ее много используют и она хорошо работает благодаря своей стабильности, предсказуемости и определенности на интервале от 0 до 1. Итоговая формула функции внимания для скалярного произведения (Dot Product Attention) выглядит следующим образом:
В итоге
Итак, в этой заметке удалось пояснить суть задачи непараметрической регрессии, которая позволяет делать оценку неизвестной функции, но с имеющемся сведениями (знаниями) о независимых переменных (ключей) и зависимых (значений) для новых запросов независимой переменной. Таким образом, вводится терминология для механизма внимания, поясняющая суть входов и выходов этого важного архитектурного элемента.
Если рассмотреть оценку непараметрической регрессии Надарая-Вотсона для ядерной функции нормального распределения и немного доработать метод нормализации матрицы внимания, то на выходе получится формула функции внимания, которая используется в современных архитектурах нейронных сетей в том числе в трансформерах.
Полезные ссылки
Оригинал заметки
ihouser
Яснее не стало. Почему внимание выделяет важное слово? Откуда оно знает что для человека важно? Математики же, не просто так взяли хитрую формулу, а по какой то важной причине. Как они поняли что это сработает и сделает ИИ?
welcome2hype Автор
Сожалею, что заметка не помогла вам лучше разобраться в этой теме. Вы правильно заметили, что мой текст не отвечает на вопрос "Почему внимание выделяет важное слово?". Я рекомендую внимательнее ознакомиться с другой пукбликацией на эту тему: Трансформер в картинках
Там, на мой взгляд, содержится ответ на ваш вопрос.