Всем привет.
Характерной тенденцией последних нескольких лет в глубоком обучении является проникновение трансформера в различные сферы деятельности, где только можно и нельзя (но если очень хочется, то можно) применить нейронные сети. Универсальность архитектуры позволяет работать с самыми разнообразными данными, предварительно превращая их в последовательность токенов, будь то текст, картинки, аудио, видео или даже состояние среды.
![Из недавней статьи про Gato Из недавней статьи про Gato](https://habrastorage.org/getpro/habr/upload_files/d24/72a/c95/d2472ac95ececaa2f41700f904f3cc23.png)
Но за невероятную мощь и гибкость архитектуры приходится платить значительной вычислительной сложностью и расходом памяти, ибо сие многоголовое чудище ненасытно в отношении памяти, особенно для длинных последовательностей, что ограничивает применимость моделей на практике. Да и даже при наличии серьезных вычислительных ресурсов обучение моделей на серьезных задачах - дело отнюдь не быстрое.
В недалеком прошлом вышла целая плеяда работ посвященных удешевлению дорогой операции внимания посредством построения различных приближений, сводящих квадратичную по длине последовательности вычислительную сложность и расход памяти к субквадратичной за счет приближения матрицами более низкого ранга, хэшированием, разреженного внимания, локального внимания, комбинированного и вагон и маленькая тележка других идей. Многие подходы показали себя довольно неплохо, давая небольшую потерю в качестве относительно исходного vanilla attention, но все-таки внимание в его первозданном виде было и остается наиболее распространенным.
![Различные вариации разреженного внимания и их смесь - BigBird Различные вариации разреженного внимания и их смесь - BigBird](https://habrastorage.org/getpro/habr/upload_files/3c9/f25/2a1/3c9f252a1376a23e1c0f0728442ed2b7.png)
И на днях вышла работа Flash Attention, где был предложен способ существенно ускорить вычисление attention на GPU, причем никак не меняя конечный результат. То есть делается то же самое, что и при стандартном вычислении внимания, но по-другому.
Ключевая идея
Ключевым ингредиентом успеха Flash Attention является наблюдение того, что основное время уходит не сколько на сами вычисления, ибо современные карточки с легионом CUDA ядер и всякими наворотами выдают тучу терафлопс, а на обращения к памяти, подрузку матриц и тензоров.
![(Слева) Иерархия памяти при обучении на GPU.
(Справа) Иллюстрация поблочной процедуры вычисления внимания. (Слева) Иерархия памяти при обучении на GPU.
(Справа) Иллюстрация поблочной процедуры вычисления внимания.](https://habrastorage.org/getpro/habr/upload_files/2cb/3ee/680/2cb3ee68094e96a04288008aefb8d588.png)
Память видеокарты имеет следующую иерархию - есть быстрая кэш память, которой немного, всего пара десятков мегабайт, и относительно медленная, но которой может быть много HBM (high bandwidth memory). Казалось бы, что величина пропускной способности порядка 1 Тб/с это довольно много, но в случае обучения или инференса модели трансформера, HBM память не поспевает за вычислениями (память преследовала его, но он оказался быстрее).
Как вычисляется старый добрый attention?
В стандартной реализации матрицы (query),
(key),
(value) загружаются поблочно из HBM памяти и промежуточные результаты вычислений - матрицы
и
- загружаются из HBM памяти в кэш и обратно. Традиционный подход, таким образом требует порядка
обращений к памяти.
![Стандартный способ вычисления attention. Из оригинальной статьи. Стандартный способ вычисления attention. Из оригинальной статьи.](https://habrastorage.org/getpro/habr/upload_files/84a/4e8/a54/84a4e8a5465fd227c8d3d234141d2745.png)
А что делает Flash Attention?
Вычисление маленькими блоками
Входные данные - матрицы - нарезаются на блоки некоторого размера, такого, чтобы все влезало в кэш.
Затем перемножение матриц проводится поблочно (довольно древняя тема). И
, оказывается, тоже можно вычислять блок за блоком, причем достаточно дополнительно хранить сумму экспонент (нормализационный фактор) от входных данных для каждого блока и максимальное значение входов блока.
![Как вычисляется softmax Как вычисляется softmax](https://habrastorage.org/getpro/habr/upload_files/dcd/fb0/482/dcdfb0482ba12353ae56e90d32fb5132.png)
![Слияние softmax из разных блоков Слияние softmax из разных блоков](https://habrastorage.org/getpro/habr/upload_files/0ff/671/3af/0ff6713afb81f93f4595d5ee84f1b52e.png)
Пересчет промежуточных матриц
В стандартной реализации приходится хранить довольно увесистые матрицы и
размера
в памяти для того чтобы вычислить градиенты весов при обратном проходе. Но зная нормализационные факторы в
и выходные градиенты слоя, можно легко вычислить градиенты по ключам (как показано в приложении к статье). Такой подход несколько увеличивает количество вычислений, но так как основное время все равно уходило на обращения к памяти, имеем выигрыш по времени выполнения.
Слияние ядер
Кроме того, поблочная процедура вычисления позволяет выполнять все операции (перемножения матриц, , dropout) разом в одном СUDA kernel, в отличие от стандартной имплементации, где приходилось бы на каждую из операций вызывать отдельное ядро, что приводило к дополнительным накладным расходам.
![Attention в PyTorch и предложенный Flash Attention Attention в PyTorch и предложенный Flash Attention](https://habrastorage.org/getpro/habr/upload_files/64a/bf1/69e/64abf169e8ef040367a17a5a74e94a35.png)
Алгоритм целиком
![Алгоритм Flash Attention Алгоритм Flash Attention](https://habrastorage.org/getpro/habr/upload_files/143/f6c/366/143f6c36635f8addf70aea519ebd6ca5.png)
Flash attention вычисляет операцию attention за
(вычислительная сложность обычного attention) с использованием
дополнительной памяти.
Блочно-разреженный Flash Attention
Дополнительного ускорения можно добиться, если при вычислении внимания отбросить некоторые из нарезанных блоков (то есть не вычислять). То есть в некотором смысле структурированный прунинг на уровне активаций. Подобрав удачным или правильным образом маску можно уменьшить число вычислений, существенно не теряя в качестве.
![Вычисление блочно-разреженного внимания. Матрица произведений между Q и K маскируется некоторой заданной заранее маской. Вычисление блочно-разреженного внимания. Матрица произведений между Q и K маскируется некоторой заданной заранее маской.](https://habrastorage.org/getpro/habr/upload_files/060/321/967/0603219674b5bd2c195d11b00343ef6f.png)
Эксперименты
Авторы протестировали свое детище на трех известных бенчмарках:
-
Обучение BERTа на Википедии
Время обучения BERT-large на 8 GPU A100 для достижения точности в 72% в задаче маскированного моделирования языка -
Обучение GPT-2 на OpenWebtext
Время обучения на OpenWebtext и ускорение по сравнению с бейзлайном от ОбнимающихсяЛиц Работа с длинными последовательностями на Long-range Arena
![Разные методы Attention на задачах из Long-range arena benchmark. Под чертой приведены различные модели облегченного/приближенного механизма внимания. Разные методы Attention на задачах из Long-range arena benchmark. Под чертой приведены различные модели облегченного/приближенного механизма внимания.](https://habrastorage.org/getpro/habr/upload_files/575/26c/95d/57526c95d1d9ef974f38141a9a4f7b02.png)
Как можно заметить, FlashAttention бьет даже довольно качественную и оптимизированную реализацию BERT от Nvidia. Еще более заметная разница между временем обучения для GPT-2 для двух стандартных реализаций и у Flash Attention. На Long-range arena Flash Attention выступает успешнее всех конкурентов с более дешевым вниманием, будучи при этом еще и быстрее. Блочно-разреженное внимание дает еще некоторый прирост в скорости, не теряя по сути в качестве.
Далее авторы сравнивают время прямого и обратного прохода по сети и расход памяти по сравнению с разными реализациями стандартного и "облегченного" внимания. При достаточно больших длинах последовательной Flash Attention оказывается быстрее и еще экономнее по памяти.
![(Слева) Время прямого и обратного прохода на задаче моделирования длинных последвоательности (Справа) Зависимость расхода памяти от длины последовательности. (Слева) Время прямого и обратного прохода на задаче моделирования длинных последвоательности (Справа) Зависимость расхода памяти от длины последовательности.](https://habrastorage.org/getpro/habr/upload_files/382/37f/53d/38237f53deda00b40dfe33c2ff7363bc.png)
И в качестве вишенки на торте авторы впервые в истории смогли с хоть каким-то успехом решить задачи Path-X (картинка 128x128), Path-256 (256x256) с качеством выше выдаваемого алгоритмом гадания на кофейной гуще. Собственно, задача заключается в следующем - найти, существует ли путь между двумя белыми точками на черно-белой картинке или нет. Казалось бы, что тут сложного, справится и маленький ребенок. Ребенок-то справится, а вот последовательность длины порядка 10000 не влезет даже в очень емкую карточку в традиционном подходе, да и из представления картинки в виде одномерной последовательности черно-белых пикселей извлечь глобальный контекст не так-то просто. Все предыдущие подходы либо падали по памяти, либо выдавали качество уровня случайного классификатора (т.е 50%). Тут же удалось выбить ~60%.
![Пример из Path-X Пример из Path-X](https://habrastorage.org/getpro/habr/upload_files/be5/bab/112/be5bab11221eac66846e36bc15d5b0b9.png)
![Все подходы до Flash Attention не справлялись с задачей Path-X/256 Все подходы до Flash Attention не справлялись с задачей Path-X/256](https://habrastorage.org/getpro/habr/upload_files/102/da7/228/102da7228fe8da64dbc8b09158727a9c.png)
Все эксперименты из статьи были запущены на одной машине с Нвидиевской A100.
Будущие направления
В последующей работе авторы предполагают обобщить метод на случай обучения на нескольких GPU. Кроме учета времени передачи данных от кэша к HBM появляется еще дополнительно время передачи данных между разными карточками и машинами, которое еще на порядок (порядки) продолжительнее. Кроме того, и поточечную нелинейность в трансформере (применение FFN (feed forward net) независимо к каждому токену) можно оптимизировать перейдя к блочным умножениям.
Итог
Кажется, что данный результат позволит расширить и оптимизировать применение трансформеров в различных областях и сократить время на обучение моделей, что влечет за собой экономию денег, элеткроэнергии и выбросов CO2 (куда без зеленой повестки). Интересно, насколько быстро данная идея будет воплощена в стандартных фрейворках глубокого обучения. Полученные результаты выглядят очень сильно, но в зависимости от архитектуры и задачи выигрыш будет более или менее заметен.