Можно ли реализовать механизм внутреннего внимания, потребляющий гораздо меньше ресурсов, чем обычно?

Говорят, что механизм внимания плохо переносит работу с последовательностями большой длины. Это — идея, которая встречалась любому, кто потратил достаточно много времени, занимаясь трансформерами и механизмом внутреннего внимания. Это, одновременно, и так, и не так. С одной стороны — с этим сталкивался каждый, кто пытался увеличить размеры контекста своей модели, натыкаясь при этом на то, что модель начинала работать с сильным скрипом. С другой стороны — возникает такое ощущение, что практически каждую неделю выходит новая эталонная модель, которая характеризуется новыми размерами контекста, бьющими все рекорды. (Контекстное окно Gemini составляет 2 миллиона токенов!)

Есть много хитроумных методов, вроде RingAttention, которые позволяют обучать модели с очень большими размерами контекста на мощных распределённых системах. Но сегодня меня интересует всего один простой вопрос: «Как далеко можно зайти, применяя лишь механизм линейного внимания?».

Разбираемся с математикой

Этот разбор будет представлять собой нечто вроде поездки «галопом по европам», но прошу вас это прочесть, так как тут будут затронуты некоторые важные вещи, которые пригодятся нам, когда мы доберёмся до результатов моего исследования.

Традиционный механизм внимания можно представить себе в виде конструкции, в которой можно выделить два ключевых момента:

  • Первый — это обычное для механизмов внимания softmax‑выражение, которое принимает произведение матриц Q (query, запрос) и K (key, ключ) и нормализует значения ради стабильности. Далее — к данным применяют функцию softmax (построчно), получая таким образом оценки внимания между каждой парой элементов последовательности.

  • Второй — это оценка временной сложности алгоритма, которая определяется точечным произведением N2. Эта операция, выполняемая внутри softmax, является ограничивающим фактором модели. Именно здесь вычисляются оценки внимания.

Это выражается в виде такой вот традиционной формулы:

https://miro.medium.com/v2/resize:fit:700/1*Lbknm6U_c_gSGBdRv5_iIw.png
Традиционное описание механизма внимания, реализуемого с использованием softmax

Если обратиться к нашим друзьям-математикам, то окажется, что всё это можно рассматривать с несколько иной точки зрения. Функцию softmax можно считать одним из многих способов описания распределения вероятностей, связывающих токены друг с другом. Мы можем использовать любую интересную нам меру сходства токенов (точечное произведение — это один из самых простых таких показателей), и до тех пор, пока мы будем применять нормализацию, всё у нас будет хорошо.

https://miro.medium.com/v2/resize:fit:700/1*hvn8dSdpFBmPnZWQXmZvJw.png
Общее выражение, описывающее механизм внимания, в котором применяется любая функция сходства

Немного некорректно будет говорить, что это и есть механизм внимания. Ведь, на самом деле, в единственном механизме внимания, который мы знаем и любим, функция сходства представляет собой функцию exp, принимающую точечное произведение матриц Q и K (показана ниже), как это можно видеть в softmax.

https://miro.medium.com/v2/resize:fit:700/1*dS2g18i1FgOmOOM8kwmvpA.png
Приближение функции сходства из механизма внутреннего внимания с двумя картами признаков

Можно предположить, что тут имеется некая карта признаков «phi», которая даёт нам результат, практически точно такой же, который получают, применяя функцию exp к точечному произведению. И, что очень важно, подобное выражение позволяет нам экспериментировать с порядком операций умножения матриц.

В этой публикации для формирования карты признаков предложена функция Exponential Linear Unit (ELU):

https://miro.medium.com/v2/resize:fit:540/1*LkzH80BopvEaa7M-0K6MsQ.png
Функция Exponential Linear Unit (ELU)

Она обладает многими полезными свойствами:

  1. Для значения больше 0 ELU(x) выдаёт линейный результат, который, хотя и не является тем же самым, что и результат функции exp, сохраняет относительное упорядочение оценок.

  2. Для значений, которые меньше или равны 0, экспоненциальное выражение сохраняет непрерывную природу функции и обеспечивает такое обращение с градиентами, при котором они не могут просто исчезнуть.

https://miro.medium.com/v2/resize:fit:700/1*bWwrVXG12wMnrxZFdFnedQ.png
Функция Exponential Linear Unit (ELU)

Здесь мы больше об этом особо говорить не будем, отмечу лишь, что имеются очень хорошие эмпирические доказательства того, что эта конструкция представляет собой адекватное приближение функции softmax.

Важно то, что применение ELU позволяет нам менять порядок операций. Мы можем сначала получить произведение карты признаков с K и V (value, значение) для формирования блока KV, а затем — получить произведение с Q. В результате речь будет идти о работе с данными, размеры которых соответствуют размерности модели, а не длине входной последовательности.

Если свести всё это воедино, получится выражение для вычисления линейного внимания:

https://miro.medium.com/v2/resize:fit:700/1*iKD7-nt-_nEQqEa3n1Fnvg.png
Вычисление линейного внимания с использованием карт признаков для аппроксимации показателя сходства softmax

Здесь всего лишь нужно, один раз на строку Q, вычислить выражения, находящиеся в скобках.

(Если вы хотите разобраться с тем, как сюда вписывается использование методики causal masking, и как вычисляются градиенты, почитайте вышеупомянутую публикацию, или ждите моих новых материалов).

Какой прирост в скорости даёт механизм линейного внимания?

Математика чётко указывает на то, что применение линейного внимания даёт хороший прирост в скорости, но лично я, до тех пор, пока не увижу результаты реальных измерений, отношусь к подобным вещам с некоторым подозрением.

Начнём с рассмотрения следующего фрагмента кода, в котором описывается всё то, о чём мы говорили. Механизм внимания, при реализации которого применяется softmax, выглядит очень привычно, здесь мы не делаем ничего экстраординарного.

class TraditionalAttention(nn.Module):
    def __init__(self, d_k):
        super(TraditionalAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        Z = torch.sqrt(torch.tensor(self.d_k, device=Q.device, dtype=torch.float32))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / Z
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output

А при написании кода для линейного внимания мы начинаем с получения матриц Query, Key и Value, затем формируем карты признаков Query и Key с помощью ELU(x). После этого, используя einsum, выполняем умножение.

class LinearAttention(nn.Module):
    def __init__(self):
        super(LinearAttention, self).__init__()
        self.eps = 1e-6

    def elu_feature_map(self, x):
        return F.elu(x) + 1

    def forward(self, Q, K, V):
        Q = self.elu_feature_map(Q)
        K = self.elu_feature_map(K)
        KV = torch.einsum("nsd,nsd->ns", K, V)
        # Вычисляем нормализатор
        Z = 1/(torch.einsum("nld,nd->nl", Q, K.sum(dim=1))+self.eps)
        # В итоге - вычисляем и возвращаем новые значения
        V = torch.einsum("nld,ns,nl->nd", Q, KV, Z)
        return V.contiguous()

Прекрасно видеть это всё в виде кода. А как эта конструкция покажет себя в эксперименте? О каком росте производительности мы тут говорим? Сложно может быть оценить то, о каком ускорении может идти речь при переходе от квадратичных к линейным ограничениям, поэтому я решил всё проверить на практике.

Мы собираемся взять один слой внимания с фиксированным измерением модели d_k, равным 64, и измерить время, необходимое для выполнения прямого прохода при использовании набора последовательностей с размером пакета 32. Единственная переменная, которая будет меняться — это длина последовательности. Она изменяется в диапазоне от 128 до 6000 (для справки — длина контекста GPT-3 составляет 2048). Прогон модели в каждом из состояний выполняется 100 раз, что позволяет вычислить показатели среднего значения и стандартного отклонения результатов. Эксперименты выполняются на GPU Nvidia T4.

Если учитывать то, о каком простом эксперименте идёт речь, можно сказать, что результаты получились довольно-таки впечатляющими.

https://miro.medium.com/v2/resize:fit:700/1*yx6qMZ8N9xvb-K88L-956A.png
Измерение времени, уходящего на итерацию для обработки одной последовательности с помощью традиционного механизма внимания (softmax) и механизма линейного внимания. Показатель для каждой длины последовательности представляет собой усреднённые данные и стандартное отклонение, полученные при выполнении более чем 100 итераций. Длина последовательности находится в диапазоне от 128 до 6000. Здесь же дан показатель соотношения времени, необходимого на выполнении итерации при использовании разных механизмов внимания (Ratio), что позволяет нагляднее представить рост скорости работы модели.

Видно, что даже на таком маленьком примере мы получили 60-кратный прирост скорости.

Анализ результатов

По результатам эксперимента можно сделать несколько очевидных выводов:

  1. Линейное внимание обладает неоспоримым преимуществом перед обычным — как в плане скорости выполнения, так и в плане потребления памяти. Если говорить о скорости, то повышение пропускной способности системы — это всегда хорошо. А экономия потребления памяти может оказаться огромным плюсом в системах, обладающих небольшим объёмом памяти.

  2. График Ratio характеризуется необычным изломом. Это наводит на мысли о некоей дополнительной низкоуровневой оптимизации, и о том, что реальное соотношение скорости работы систем с разными механизмами внимания может отличаться от теоретического. Поэтому не стоит слепо доверять этим результатам.

Для полноты картины отмечу, что не стоит впадать в заблуждение, считая, что «механизм линейного внимания в 60 раз быстрее механизма обычного внимания в маленьких моделях». В реальности слои прямого распространения трансформеров часто представляют собой сущности, описываемые довольно большими количествами параметров. А энкодер и декодер часто представляют собой структуры ограниченного размера. Но если говорить о нашей чётко очерченной задаче, можно сказать, что результаты мы получили очень и очень хорошие!

Вычислительная сложность

Если подумать о реальной временной сложности каждого из подходов — это позволит понять то, в чём именно заключается замеченная нами разница между ними.

Рассмотрим временную сложности традиционного механизма внимания, основанного на softmax. Первое выражение даёт временную сложность умножения Q на K — это будет показатель n2, характеризующий количество оценок внимания, умноженный на вектор размера d_k. Второе выражение описывает сложность применения softmax к оценкам внимания. Он тоже равняется n2. А третий элемент представляет собой точечное произведение матриц n2 и V.

Если, для простоты, предположить, что матрицы Q, K и V имеют одну и ту же размерность, то в итоге окажется, что основной вклад во временную сложность алгоритма вносит показатель n2. (Учитывая то, что размерность модели намного меньше длины входной последовательности.)

https://miro.medium.com/v2/resize:fit:700/1*gPwkm4eh9CXEgWNhMcWIzQ.png
Обычный механизм внимания, основанный на softmax, характеризуется временной сложностью n2, зависящей от длины последовательности, при том, что обычно размерность модели d_k намного меньше n

В случае с линейным механизмом внимания всё выглядит несколько иначе. Мы, как и прежде, глядя на выражение, описывающее временную сложность алгоритма, проанализируем каждую его часть.

Первое выражение характеризует временную сложность применения карты признаков к матрицам Q и K. Второе выражение — это произведение матриц Q и V, что даёт матрицу (d_k, d_v). Операция умножения K(QV) в третьем выражении имеет ту же временную сложность. В результате, учитывая то, что размерности модели для разных матриц будут одними и теми же, мы получаем итоговую сложность — линейную для длины последовательности, и квадратичную для размерности модели.

https://miro.medium.com/v2/resize:fit:700/1*T33iDSmsnv-XSUgk6gA9OA.png
Линейный механизм внутреннего внимания меняет правила игры. Он характеризуется линейной временной сложностью для n и квадратичной сложностью для размерности модели (если скрытые измерения dk и dv одинаковые, как, в нашем случае, сделано ради простоты). В результате в любом режиме, когда n значительно больше dk, мы получаем гораздо меньшую сложность.

Поэтому получается, что до тех пор, пока размерность модели меньше, чем длина входной последовательности, модель будет работать гораздо быстрее. У нас остался лишь один реальный вопрос: насколько хорошо наша конструкция способна воспроизвести то, что умеет классический механизм внимания?

За всё надо платить — можно ли обучить нашу модель?

Пожалуй, хватит уже о механизме линейного внимания — надеюсь, все уже поняли, что оно гораздо быстрее традиционного. Предлагаю проверить то, что у нас получилось, на реальной задаче. Можно ли обучать модели с линейным механизмом внимания? Будут ли модели с разными механизмами внимания работать похожим образом?

Те модели, которыми мы здесь пользуемся, очень малы (если аудитории будет интересна тема настройки простого окружения обучения подобных моделей — я вполне могу об этом когда-нибудь написать), мы пользуемся простым набором данных. Речь идёт о наборе данных Penn Treebank (он находится в свободном доступе, найти его можно в torchtext). Он содержит коллекцию коротких фрагментов текста, которые можно использовать для обучения и тестирования маленьких языковых моделей.

Можно ли научить реальную модель делать реальные прогнозы?

Ну — реальные прогнозы, если честно, в нашем случае — это немного притянуто за уши. Не будем забывать о числе параметров модели и о времени, которое планируется потратить на обучение. Главное, что меня здесь интересует — это то, похожа ли динамика обучения модели с линейным механизмом внимания на динамику обучения модели с классическим механизмом внимания. Мы взглянем на кривые потерь для авторегрессионного обучения «линейной» модели на простом наборе данных для языкового моделирования. Если эти кривые будут похожи на те, что характерны для «классических» моделей — это, по крайней мере, даст нам некоторую уверенность в том, что разные механизмы дают похожие результаты.

Учитывая природу данных, можно сказать, что выходы модели редко можно признать высококачественными, но они дают нам всё то, что мы ожидаем увидеть в том случае, если прогон обучения модели пройдёт так, как нужно.

Посмотрим на кривые обучения. Первый график, на котором различимы две кривые, показывает потери обучения и валидации для моделей с традиционным и линейным механизмами внимания. Можно видеть, проанализировав более 10 эпох, что эти два подхода практически неразличимы. Аналогично, если взглянуть на второй график, видно, что потери для традиционного softmax-подхода и для линейного механизма внимания, снова, показывают абсолютно идентичную динамику обучения.

https://miro.medium.com/v2/resize:fit:1200/1*kM3EWUb_JQNnp0yLTNwpSg.png
Потери обучения и валидации по эпохам для моделей с линейным и традиционным механизмами внимания
https://miro.medium.com/v2/resize:fit:1200/1*4WceavsoqH9wXIHjxfxxbA.png
Кривые потерь обучения для линейного и традиционного механизмов внимания

Итоги

Конечно, описанный здесь подход к построению моделей нельзя назвать совершенно универсальным, но мы тут и не собирались конкурировать с GPT-моделями. Однако то, о чём тут шла речь, позволяет весьма оптимистично смотреть в сторону уменьшения сложности механизма внимания, не теряя при этом возможности моделировать что-то полезное с помощью соответствующих конструкций.

О, а приходите к нам работать? ? ?

Мы в wunderfund.io занимаемся высокочастотной алготорговлей с 2014 года. Высокочастотная торговля — это непрерывное соревнование лучших программистов и математиков всего мира. Присоединившись к нам, вы станете частью этой увлекательной схватки.

Мы предлагаем интересные и сложные задачи по анализу данных и low latency разработке для увлеченных исследователей и программистов. Гибкий график и никакой бюрократии, решения быстро принимаются и воплощаются в жизнь.

Сейчас мы ищем плюсовиков, питонистов, дата-инженеров и мл-рисерчеров.

Присоединяйтесь к нашей команде

Комментарии (2)


  1. VPryadchenko
    08.07.2024 09:03

    Как, все, что нам действительно нужно, это умножение карт признаков на выходе блока внимания, а как они получаются внутри - уже дело десятое. Отсюда и множество вариантов, как обойти использование softmax


    1. VPryadchenko
      08.07.2024 09:03

      За статью плюс, безусловно