Всем привет. 

Характерной тенденцией последних нескольких лет в глубоком обучении является проникновение трансформера в различные сферы деятельности, где только можно и нельзя (но если очень хочется, то можно) применить нейронные сети. Универсальность архитектуры позволяет работать с самыми разнообразными данными, предварительно превращая их в последовательность токенов, будь то текст, картинки, аудио, видео или даже состояние среды.

Из недавней статьи про Gato
Из недавней статьи про Gato

Но за невероятную мощь и гибкость архитектуры приходится платить значительной вычислительной сложностью и расходом памяти, ибо сие многоголовое чудище ненасытно в отношении памяти, особенно для длинных последовательностей, что ограничивает применимость моделей на практике. Да и даже при наличии серьезных вычислительных ресурсов  обучение моделей на серьезных задачах - дело отнюдь не быстрое.

В недалеком прошлом вышла целая плеяда работ посвященных удешевлению дорогой операции внимания посредством построения различных приближений, сводящих квадратичную по длине последовательности вычислительную сложность и расход памяти к субквадратичной за счет приближения матрицами более низкого ранга, хэшированием, разреженного внимания, локального внимания, комбинированного и вагон и маленькая тележка других идей. Многие подходы показали себя довольно неплохо, давая небольшую потерю в качестве относительно исходного vanilla attention, но все-таки внимание в его первозданном виде было и остается наиболее распространенным.

Различные вариации разреженного внимания и их смесь - BigBird
Различные вариации разреженного внимания и их смесь - BigBird

И на днях вышла работа Flash Attention, где был предложен способ существенно ускорить вычисление attention на GPU, причем никак не меняя конечный результат. То есть делается то же самое, что и при стандартном вычислении внимания, но по-другому.

Ключевая идея

Ключевым ингредиентом успеха Flash Attention является наблюдение того, что основное время уходит не сколько на сами вычисления, ибо современные карточки с легионом CUDA ядер и всякими наворотами выдают тучу терафлопс, а на обращения к памяти, подрузку матриц и тензоров.

(Слева) Иерархия памяти при обучении на GPU.
 (Справа) Иллюстрация поблочной процедуры вычисления внимания.
(Слева) Иерархия памяти при обучении на GPU. (Справа) Иллюстрация поблочной процедуры вычисления внимания.

Память видеокарты имеет следующую иерархию - есть быстрая кэш память, которой немного, всего пара десятков мегабайт, и относительно медленная, но которой может быть много HBM (high bandwidth memory). Казалось бы, что величина пропускной способности порядка 1 Тб/с это довольно много, но в случае обучения или инференса модели трансформера, HBM память не поспевает за вычислениями (память преследовала его, но он оказался быстрее).

Как вычисляется старый добрый attention?

В стандартной реализации матрицы Q(query), K(key), V (value) загружаются поблочно из HBM памяти и промежуточные результаты вычислений - матрицы S = Q K^Tи \mathrm{P = softmax} (Q K^T / \sqrt{d}) - загружаются из HBM памяти в кэш и обратно. Традиционный подход, таким образом требует порядка O(N d + N^2)обращений к памяти.

Стандартный способ вычисления attention. Из оригинальной статьи.
Стандартный способ вычисления attention. Из оригинальной статьи.

 А что делает Flash Attention?

Вычисление маленькими блоками

Входные данные - матрицы Q, K, V- нарезаются на блоки некоторого размера, такого, чтобы все влезало в кэш.

Затем перемножение матриц Q, Kпроводится поблочно (довольно древняя тема). И \mathrm{softmax}, оказывается, тоже можно вычислять блок за блоком, причем достаточно дополнительно хранить сумму экспонент (нормализационный фактор) от входных данных для каждого блока и максимальное значение входов блока.

Как вычисляется softmax
Как вычисляется softmax
Слияние softmax из разных блоков
Слияние softmax из разных блоков

Пересчет промежуточных матриц

В стандартной реализации приходится хранить довольно увесистые матрицы Sи Pразмера O(N^2) в памяти для того чтобы вычислить градиенты весов при обратном проходе. Но зная нормализационные факторы в \mathrm{softmax} и выходные градиенты слоя, можно легко вычислить градиенты по ключам (как показано в приложении к статье). Такой подход несколько увеличивает количество вычислений, но так как основное время все равно уходило на обращения к памяти, имеем выигрыш по времени выполнения.

Слияние ядер

Кроме того, поблочная процедура вычисления позволяет выполнять все операции (перемножения матриц, \mathrm{softmax}, dropout) разом в одном СUDA kernel, в отличие от стандартной имплементации, где приходилось бы на каждую из операций вызывать отдельное ядро, что приводило к дополнительным накладным расходам.

Attention в PyTorch и предложенный Flash Attention
Attention в PyTorch и предложенный Flash Attention

Алгоритм целиком

Алгоритм Flash Attention
Алгоритм Flash Attention

Flash attention вычисляет операцию attention O = \mathrm{softmax}(Q K^T/\sqrt{d}) Vза O(N^2d) (вычислительная сложность обычного attention) с использованием O(N)дополнительной памяти.

Блочно-разреженный Flash Attention

Дополнительного ускорения можно добиться, если при вычислении внимания отбросить некоторые из нарезанных блоков (то есть не вычислять). То есть в некотором смысле структурированный прунинг на уровне активаций. Подобрав удачным или правильным образом маску можно уменьшить число вычислений, существенно не теряя в качестве.

Вычисление блочно-разреженного внимания. Матрица произведений между Q и K маскируется некоторой заданной заранее маской.
Вычисление блочно-разреженного внимания. Матрица произведений между Q и K маскируется некоторой заданной заранее маской.

Эксперименты

Авторы протестировали свое детище на трех известных бенчмарках:

  1. Обучение BERTа на Википедии

    Время обучения BERT-large на 8 GPU A100 для достижения точности в 72% в задаче маскированного моделирования языка
    Время обучения BERT-large на 8 GPU A100 для достижения точности в 72% в задаче маскированного моделирования языка
  2. Обучение GPT-2 на OpenWebtext

    Время обучения на OpenWebtext и ускорение по сравнению с бейзлайном от ОбнимающихсяЛиц
    Время обучения на OpenWebtext и ускорение по сравнению с бейзлайном от ОбнимающихсяЛиц
  3. Работа с длинными последовательностями на Long-range Arena

Разные методы Attention на задачах из Long-range arena benchmark. Под чертой приведены различные модели облегченного/приближенного механизма внимания.
Разные методы Attention на задачах из Long-range arena benchmark. Под чертой приведены различные модели облегченного/приближенного механизма внимания.

Как можно заметить, FlashAttention бьет даже довольно качественную и оптимизированную реализацию BERT от Nvidia. Еще более заметная разница между временем обучения для GPT-2 для двух стандартных реализаций и у Flash Attention. На Long-range arena Flash Attention выступает успешнее всех конкурентов с более дешевым вниманием, будучи при этом еще и быстрее. Блочно-разреженное внимание дает еще некоторый прирост в скорости, не теряя по сути в качестве.

Далее авторы сравнивают время прямого и обратного прохода по сети и расход памяти по сравнению с разными реализациями стандартного и "облегченного" внимания. При достаточно больших длинах последовательной Flash Attention оказывается быстрее и еще экономнее по памяти.

(Слева) Время прямого и обратного прохода на задаче моделирования длинных последвоательности (Справа) Зависимость расхода памяти от длины последовательности.
(Слева) Время прямого и обратного прохода на задаче моделирования длинных последвоательности (Справа) Зависимость расхода памяти от длины последовательности.

И в качестве вишенки на торте авторы впервые в истории смогли с хоть каким-то успехом решить задачи Path-X (картинка 128x128), Path-256 (256x256) с качеством выше выдаваемого алгоритмом гадания на кофейной гуще. Собственно, задача заключается в следующем - найти, существует ли путь между двумя белыми точками на черно-белой картинке или нет. Казалось бы, что тут сложного, справится и маленький ребенок. Ребенок-то справится, а вот последовательность длины порядка 10000 не влезет даже в очень емкую карточку в традиционном подходе, да и из представления картинки в виде одномерной последовательности черно-белых пикселей извлечь глобальный контекст не так-то просто. Все предыдущие подходы либо падали по памяти, либо выдавали качество уровня случайного классификатора (т.е 50%). Тут же удалось выбить ~60%.

Пример из Path-X
Пример из Path-X
Все подходы до Flash Attention не справлялись с задачей Path-X/256
Все подходы до Flash Attention не справлялись с задачей Path-X/256

Все эксперименты из статьи были запущены на одной машине с Нвидиевской A100.

Будущие направления

В последующей работе авторы предполагают обобщить метод на случай обучения на нескольких GPU. Кроме учета времени передачи данных от кэша к HBM появляется еще дополнительно время передачи данных между разными карточками и машинами, которое еще на порядок (порядки) продолжительнее. Кроме того, и поточечную нелинейность в трансформере (применение FFN (feed forward net) независимо к каждому токену) можно оптимизировать перейдя к блочным умножениям.

Итог

Кажется, что данный результат позволит расширить и оптимизировать применение трансформеров в различных областях и сократить время на обучение моделей, что влечет за собой экономию денег, элеткроэнергии и выбросов CO2 (куда без зеленой повестки). Интересно, насколько быстро данная идея будет воплощена в стандартных фрейворках глубокого обучения. Полученные результаты выглядят очень сильно, но в зависимости от архитектуры и задачи выигрыш будет более или менее заметен.

Ссылки и литература

Сама статья

Гитхаб-репозиторий

Блог на твиттере от создателей

Комментарии (0)