Предисловие

Начнем повествование с приевшейся, шаблонной, клишированной фразы, мотивирующей сжатие нейронных сетей:

За последние несколько лет нейронные сети достигли значительных успехов в разнообразных приложениях и сферах человеческой (и нечеловеческой) деятельности, превосходя даже человека на ряде задач. Но мощь и гибкость, способность фитировать сложные зависимости, требуют значительных вычислительных ресурсов как на этапе обучения, так и на инференсе, что ограничивает зачастую применение нейронных сетей на мобильных устройствах и при наличии ограниченных вычислительных мощностей.

Поэтому по мере бурного прогресса и развития новых архитектур параллельно идет активная разработка разнообразных подходов по сжатию и повышению эффективности нейронных сетей

Какие-то из перечисленных подходов более эффективны в прикладных приложениях - типа квантования и дистилляции, теоретический выигрыш других (в количестве параметров и вычислений) же (неструктурированный прунинг) сложнее переносится на реальные устройства и вычислители. Разные методы можно комбинировать друг с другом.

Прунинг токенов

Архитектура Transformer, основанная на механизме внимания, благодаря своей универсальности была успешно применена во всех (или почти всех) сферах глубокого обучения. Входные данные, будь то текст, картинки, звук, или состояние среды, преобразуются в последовательность токенов, к которой последовательно применяются слои attention, учитывающие глобальный контекст и взаимосвязи токенов друг с другом, и feedforward слои, применяющие нелинейное преобразование к каждому токену по отдельности.

(Слева) Схематиченое изображение операции Attention
(Центр) Аггрегация информации из различных голов трансформера
(Справа) Блок энкодера/декодера в трансформере
(Слева) Схематиченое изображение операции Attention (Центр) Аггрегация информации из различных голов трансформера (Справа) Блок энкодера/декодера в трансформере

Далее в тексте входная последовательность будет обозначаться как \mathbf{x}, i-ый токен - \mathbf{x}_i. Q, K, V соотвествуют последовательностям query, key, value. N- длина последовательности токенов.

Токен может быть словом (или частью слова) в предложении, или патчем на картинке (размера 8x8, 16x16, 32x32). Если для решения задачи важен только глобальный контекст, как в случае задачи классификации изображений или sentiment analysis, то предсказание модели часто определяется небольшой долей токенов. Например, для того, чтобы понять, какую эмоциональную окраску имеет отзыв, достаточно посмотреть на ключевые слова, и если нас интересует какое животное изображено на картинке - наличие или отсутствие дерева или куста вряд ли существенно повлияет на наш ответ.

(Слева) Пример отзыва на фильм (ключевое слово подсвечено зеленым).  
(Справа) Патчи на которых расположена собака выделены красным.
(Слева) Пример отзыва на фильм (ключевое слово подсвечено зеленым). (Справа) Патчи на которых расположена собака выделены красным.

Кроме того, после обработки данных несколькими слоями трансформера часть токенов может содержать достаточное количество информации для решения поставленной задачи, и даже более того, часть из них дублировать друг друга.

Следовательно, естественным вариантом повышения эффективности трансформера является отбрасывание наименее важных токенов. Причем выигрыш от отбрасывания довольно существенный, так как применение линейных слоев W_Q, W_K, W_V для генерации query, key, value, проекция W_Pпосле операции внимания и последующие feedforward слои линейно в плане вычислений по длине последовательности N, а операция attention \mathrm{softmax}\left(QK/ \sqrt{d}\right)квадратична по N.

Далее будет рассмотрена 1 работа из NLP по прореживанию токенов, и 4 из компьютерного зрения.

Learned Token Pruning for Transformers

Применительно к задачам обработки естественного языка (NLP) способ прореживания токенов был предложен в работе Learned Token Pruning for Transformers.

(Слева) Схема прореживания последовательности и отбрасывания малоинформативных токенов для задачи Sentiment Analysis. 
(Справа) Attention для токенов внутри посоедовательности. Более информативные токены выделены более темным цветом.
(Слева) Схема прореживания последовательности и отбрасывания малоинформативных токенов для задачи Sentiment Analysis. (Справа) Attention для токенов внутри посоедовательности. Более информативные токены выделены более темным цветом.

В качестве критерия важности токена используется средний attention со всеми токенами по всем головам трансформера:

s^{(l)}(\mathbf{x}_i) = \frac{1}{n N_h} \sum_{h=1}^{N_h} \sum_{j=1}^{n} \mathbf{A}^{(h, l)} (\mathbf{x}_i, \mathbf{x}_j)

Выше \mathbf{x}_i- вектор признаков, отвечающий i-му токену, \mathbf{A}^{(h, l)}- матрица внимания для h-ой головы и l-го блока трансформера, n- длина последовательности токенов, и N_h- число голов в multi-head attention.

Определение того, какие токены оставить, а какие отправить в Вальгаллу, осуществляется следующим нехитрым способом: для каждого блока lзадается порог \theta^{(l)}, такой что, если значение s^{(l)}(\mathbf{x}_i)больше порога, то токен остается, а иначе покидает этот мир. Формулируя иначе, на последовательность накладывается бинарная маска следующего вида:

M^{(l)}(\mathbf{x}_i)  = \begin{cases}1 & \mathrm{if} \  s^{(l)}(\mathbf{x}_i) > \theta^{(l)} \\ 0 & \mathrm{else} \end{cases}

Если токен в слое l запрунен, то дальше с ним не проводится никаких вычислений как в слое l, так и последующих за ним.

Обучаемый порог

Однако, сразу же возникает: а каким образом определить правильный порог  \theta^{(l)}? Он зависит от блока l, от поданного на вход примера, и от самой задачи. Поэтому предлагается сделать этот порог обучаемым. Но обучать порог, явным образом отбрасывая токены, не представляется возможным в связи с тем, что операция отбрасывания по порогу недифференцируема. Поэтому напрашивается замена hard pruning на soft pruning, когда вместо бинарной маски, принимающей значения только \{0, 1\}, допустимо любое значение от 0 до 1. Естественной параметризацией для такой маски выступает сигмоида с некоторой температурой T:

\tilde{M}^{(l)}(\mathbf{x}_i)  = \sigma \left(\frac{s^{(l)}(\mathbf{x}_i) - \theta^{(l)}}{T} \right)

Величина T определяет, насколько резко происходит переход от 0 к 1. Таким образом, на этапе обучения менее информативные патчи будут подавлены по сравнению с более информативными, но их вклад не обращается точно в ноль на этапе обучения.

Обучение модели с процедурой прунинга токенов проходит в 3 этапа:

  1. Модель обучается с soft маской \tilde{M}^{(l)}(\mathbf{x}_i)и оптимизируются параметры сети и значения порогов.

  2. Значения порогов  \theta^{(l)}фиксируются и маска становится бинарной.

  3. Параметры самой модели дообучаются при фиксированных значениях порогов и hard прунинге токенов.

Чтобы модели не было выгодно просто брать все токены к исходной функции потерь (скажем, кросс-энтропии для задачи классификации) добавляется регуляризационный член, чтобы уменьшать значение soft маски для как можно большего количества токенов:

\mathcal{L}_{new} = \mathcal{L} + \lambda \mathcal{L} _{reg} \quad \mathcal{L} _{reg} = \frac{1}{L} \sum_{l}^{L} \Vert \tilde{M}^{(l)}(\mathbf{x}_i)  \Vert_{1}

Величина \lambda определяет trade-off между желанием решить хорошо исходную задачу и выигрышем в производительности.

Эксперименты

Предложенный подход проверятся на ряде задач классификации последовательностей - MNLI, MNLI-m, QQP, QNLI, SST-2, SST-B, MRPC, RTE. В качестве модели используется RoBERTa-Base.

Валидация LTP на нескольких датасетах по классификации последовательностей. Под Speedup подразумевается отношение количества FLOPs для исходной модели и с прореживанием токенов.
Валидация LTP на нескольких датасетах по классификации последовательностей. Под Speedup подразумевается отношение количества FLOPs для исходной модели и с прореживанием токенов.

При небольшой просадке в качестве модель удается ускорить примерно в 2 раза с точки зрения теоретических FLOPs. Однако, выигрыш в теоретических FLOPs зачастую трудно перенести на реальные устройства, поэтому более интересны цифры реального ускорения инференса на СPU/GPU. И как показывают графики ниже, за счет некоторой просадки в качестве можно добиться реального ускорения работы модели (интересно было бы сравнить с моделью той же архитектуры, но с меньшим числом каналов или структурированным прунингом).

Относительное ускорение модели (отношение времен обработки одного батча) для задач QNLP, QQP на CPU (Intel Haswell) и GPU (Nvidia V100).
Относительное ускорение модели (отношение времен обработки одного батча) для задач QNLP, QQP на CPU (Intel Haswell) и GPU (Nvidia V100).
  • К достоинствам предложенного подхода можно отнести простоту и интуитивность подхода, наличие реального выигрыша в производительности на практике.

  • К недостаткам же можно отнести невозможность задать автоматически желаемую степень ускорения модели. В зависимости от значения \lambda ускорение будет большим или меньшим, но априори понять каким оно будет не представляется возможным. Поэтому если, скажем, стоит задача ускорить модель в 2 раза потребуется провести несколько экспериментов, чтобы определить нужное значение \lambda. Кроме того, число прореживаемых токенов зависит от примера, от того насколько равномерно или локализовано на конкретных токенах распределен attention. То есть speedup варьируется в зависимости от того, что подается на вход.

Dynamic-ViT

В работе Dynamic-ViT (уже применительно к задачам компьютерного зрения) отбор информативных и неинформативных токенов проводится с помощью обучаемого классификатора. Каждый токен \mathbf{x}_i \in \mathbb{R}^{D}на вход небольшой полносвязной сети \mathrm{MLP}_1(\mathrm{x})выдающей некоторый промежуточный эмбеддинг \mathbf{z}_i^{(\mathrm{local})} \in \mathbb{R}^{D'}меньшей размерности, а затем вычисляется средний \mathbf{z}_iпо всей последовальности - глобальный эмбеддинг \mathbf{z}^{(\mathrm{global})}, и вероятность сохранения токена \pi вычисляется с помощью еще одной полносвязной сети \mathrm{MLP}_2(\mathrm{x})принимающей на вход конкатенацию \mathbf{z}_i^{(\mathrm{local})} и \mathbf{z}^{(\mathrm{global})}, к выходу которой применяется \mathrm{softmax} (хотя можно было и просто сигмоиду):

\mathbf{z}^{(\mathrm{local})} = \mathrm{MLP}_1(\mathrm{\mathbf{x}}) \\ \mathbf{z}^{(\mathrm{global})} = \mathrm{AvgPool}(\mathbf{z}^{(\mathrm{local})}) \\ \mathbf{\pi} = \mathrm{MLP}_2([\mathbf{z}^{(\mathrm{local})}, \mathbf{z}^{(\mathrm{global})}])

Затем часть токенов отбрасывается или сохраняется с вероятностями \pi_{0}и \pi_{1}, cоотвественно. Но каким образом тогда обучать данную конструкцию, если отбрасывание - операция недифференцируемая? И авторы предлагают один из популярных способов сэпмлирования из категориального распределения с возможностью проброса градиента через операцию сэмплирования - GumbelSoftmax.

Функции потерь

Для того, чтобы модель с прунингом токенов эффективнее обучалась и для фиксации заданного уровня прореживания, авторы добавляют 3 дополнительных члена в исходную функцию потерь (кросс-энтропию):

  • Дистилляция с незапруненными токенами исходной модели.

    Ниже \hat{\mathbf{D}}_{i}^{b, S}означает бинарную маску для i-го токена в b-ом блоке модели. Оставшиеся токены по замыслу должны хранить ту же самую информацию, что и при прогонке через слои исходной модели.

  • Дистилляция с логитами исходной модели (в привычном понимании).

  • MSE Loss, устремляющий долю сохраняемых патчей к целевому показателю \rho^{(s)}:

Итоговый лосс является взвешенной суммой всех четырех (включая исходный) лоссов:

На инференсе в s-ом по порядку блоке отбирается [N \rho^{(s)}]входных пачтей с наибольшим значением вероятности сохранения \pi_1.

Вся конструкция обучается в течение 30 эпох (1/10 от времени обучения DeiT в работе Training data-efficient image transformers & distillation through attention).

Кривые Парето и выигрыш в производительности

Всякий уважающий себя метод построения эффективной модели должен быть изображен на кривой Парето священной злобой возвышаясь над дышащими в спину конкурентами. Ниже представлено качество на ImageNet против числа операций с плавающей точкой (теоретических FLOPs).

В разделе 4.1 утверждается, что предложенный подход дает выигрыш в throughput (числе обрабатывамых примеров в единицу времени) в 43-54% на инференсе при замерах на Nvidia RTX 3090 с батчом размера 32, однако для полноты картины немного не хватает кривых зависимости accuracy от throughtput для различных параметров прореживания и моделей.

Занятной и приятной особенностью данного подхода является то, что Dynamic-ViT действительно отбирает токены, содержащие сам интересующий объект.

Иллюстрация прореживания патчей. 
1-ая стадия расположена перед 4-ым блоком, 2-ая перед 7-ым, 3-ая перед 10-ым.
Иллюстрация прореживания патчей. 1-ая стадия расположена перед 4-ым блоком, 2-ая перед 7-ым, 3-ая перед 10-ым.

Ablation study

Далее авторы анализируют важность отдельных компонент и ингредиентов для достижения хорошего качества. Оказывается, что дистилляция как логитов исходной модели, так и токенов, несущественно сказывается на конечном результате.

А вот сам обучаемый критерий отбора довольно важен и работает заметно лучше, чем average pooling в некотором слое (при том же самом числе FLOPs) - Structural Sparsification, и при прореживании токенов с некоторой вероятностью без обучаемого модуля - Static Sparsification (таблица (a)). Кроме того, обучаемый модуль лучше случайного и отсортированного по attention score критерия отбора (таблица (b)).

Имеет место некоторая свобода в плане выбора блоков трансформера, в которых проводится прореживание. Можно прореживать понемного во всех 12 блоках Vision трансформера (в стандартных моделях ViT/DeiT по 12 блоков энкодера и разные модели отличаются размерностью эмбеддингов), можно только в одном, или в некоторых из 12. В данной работе основной результат был получен при прореживании в трех блоках из 12, а именно -4-го, 7-го, 10-го одной и той же доли токенов. При прореживании только в одном блоке при фиксированном числе FLOPs качество заметно проседает, но уже при двух просадка становится незначительной, а прореживание более чем в трех местах практически не улучшает результат.

Сравнение одностадийного и многостадийного прореживания при заданном количестве FLOPs.
Сравнение одностадийного и многостадийного прореживания при заданном количестве FLOPs.

S-ViTE

В работе Chasing Sparsity in Vision Transformers: An End-to-End Exploration авторы одновременно прореживают модель (веса матриц) с помощью неструктурированного и структурированного прунинга и отбирают наиболее релеватные токены.

Для прунинга используется алгоритм RigL разреженного обучения, периодически убирающий веса с наименьшей абсолютной величиной и восстанавливающий веса с наибольшим значением градиента по ним.

Отбор токенов использует примерно тот же подход, что и Dynamic-ViT, но предсказание вероятности сохранения патча осуществляется только за счет самого токена (без конкатенации с \mathbf{z}^{(\mathrm{global})}, как в Dynamic-ViT). Отбор токенов осуществляется перед первым слоем, сразу после превращения патчей картинки в эмбеддинги.

Сам по себе результат прореживания относильно скромен, так как без существенной просадки в качестве удалось, по всей видимости, проредить модели только до уровня 40-60%, где теоретический выигрыш в FLOPs проблематично воспроизвести на практике.

Ниже

  • \mathrm{SViTE} - неструктрурированный прунинг модели;

  • \mathrm{S^2ViTE}- структурированный прунинг модели;

  • \mathrm{S(S^2)+ViTE}- прунинг модели и токенов.

Неструктурированный прунинг DeiT-Tiny(Small)
Неструктурированный прунинг DeiT-Tiny(Small)
Неструктурированный прунинг DeiT-Base и прунинг DeiT-Small c отбором патчей
Неструктурированный прунинг DeiT-Base и прунинг DeiT-Small c отбором патчей

При небольшом прореживании токенов (5-10%) и умеренно прореженной модели авторам удается примерно сохранить исходное качество (а для неструктурированного будто бы немного даже и улучшить). Правда, эта доля и связанный с ней выигрыш невелики. Отбор патчей выглядит довольно случайным.

К другим недостаткам следует еще отнести дороговизну экспериментов (в обьеме вычислений). Полный цикл обучения разреженной модели использует в 2 раза больше эпох (аж 600 штук), чем само обучение с нуля исходных моделей в DeiT, что требует вычислительных ресурсов значительно превосходящих возможности google.colab. Но тем не менее, работа интересна как первая попытка обьединить model и data sparsity применительно к Vision Transformer.

IA-RED^2

В работе Interpretability-Aware Redundancy Reduction for Vision Transformers авторы предложили использовать аж целого RL-агента для отбора патчей.

Более конкретно, происходит следующее : перед каждой группой блоков трансформера (в данной работе авторы делят 12 блоков на 3 группы по 4 блока) вставляется блок отбора патчей. В этом блоке к обучаемому policy токену и токенам патчей применяются линейные слои - по типу key для policy токена и query для патчей, а затем считается скалярное произведение между полученными векторами и к этому скалярному произведению применяют сигмоиду \phi для получения вероятности, оставить или отбросить данный токен.

RL-aгент с policy из двух действий - сохранить токен u_i=1или отбросить токен u_i=0 - принимает решение на основе распределения Бернулли с определенной выше вероятностью I_{ij}. За принятие удачного решения агент получает награду, растущую с числом отброшенным патчей, и некоторую отрицательную награду \tauв противном случае. Выбора \tau определяет trade-off между качеством решаемой задачи или уменьшением количества операций.

Для обучения этого хозяйства используется широко известный (в узких кругах) REINFORCE c advantage, где из награды R(u)для стохастической политики вычитается награда для политики, выбирающей u_i=1 , если I_{ij} > 0.5, иначе u_i=0.

Данная конструкция обучается для каждой группы в течение 30 эпох на ImageNet (т.e суммарно 90). Подобно Dynamic-ViT , \mathrm{IA}-\mathrm{RED}^2умеет фокусироваться на более релеватных для решения задачи регионах изображения. Карта внимания в их подходе выходит более локализованной по сравнению с attention исходной модели.

При замерах на Nvidia-V100 подход позволяет добиться ускорения примерно в ~1.5 раза с просадкой в качестве 0.7-0.9%.

Подход выглядит довольно интересным и эффективным, но к недостаткам стоит отнести необходимость гадать на кофейной гуще, чтобы определить нужное значение \tau для заданного ускорения (ну или провести серию экспериментов).

Evo-ViT

В работе Evo-ViT рассматривая СKA (Center Kernel Alignment) similarity между [CLS] токеном в последнем слое и токенами патчей в промежуточных слоях, авторы делают вывод о том, что представления эволюционируют постепенно и прореживание их в верхних слоях модели может негативно сказываться на качестве модели. Поэтому предлагается не избавляться полностью от патчей с меньшей информативностью, а проводить с ними меньше вычислений.

CKA similarity между токенами в данном блоке и [CLS] токеном в последнем слое
CKA similarity между токенами в данном блоке и [CLS] токеном в последнем слое
Визуализация Grad-Cam и attention c класс токеном в 5 блоке (в середине модели) и в 10-м (ближе к концу). Картинка иллюстирует тот факт, что на промежуточных слоях модель еще не умеет аккуратно выделять целевой обьект.
Визуализация Grad-Cam и attention c класс токеном в 5 блоке (в середине модели) и в 10-м (ближе к концу). Картинка иллюстирует тот факт, что на промежуточных слоях модель еще не умеет аккуратно выделять целевой обьект.

Информативность токена \mathbf{x}_iопределяется по величине attention между \mathbf{x}_i и [CLS] токеном в данном слое трансформера. k токенов с наибольшим значением внимания считаются "информативными" и их обработка происходит точно так же, как для исходной модели. А N-k менее информативных токенов считаются фоном \mathbf{x}_{\mathrm{ph}}и с помощью взвешенной суммы с обучаемыми коэффициентами \phi_{\mathrm{agg}}аггрегируются в один токен "фона" \mathbf{x}_{\mathrm{rep}}

\mathbf{x}_{\mathrm{rep}} = \phi_{agg} (\mathbf{x}_{\mathrm{ph}}) \quad  \phi_{agg} : \mathbb{R}^{(N-k) \times D} \rightarrow \mathbb{R}^{D}

Затем к полученной последовательности применяются слои трансформера точно так же, как и было бы и для исходного трансформера:

Затем к "менее информативным" токенам прибавляется одна и та же добавка - токен "фона" после attention и feedforward слоев.

То есть, менее информативные токены остаются, но не участвуют в вычислениях внутри блока трансформера по отдельности (только через \mathbf{x}_{\mathrm{rep}}

Согласно таблице, Evo-ViT при примерно том же throughput немного выигрывает по качеству у \mathrm{IA}-\mathrm{RED}^2 и Dynamic-ViT. Модель обучается с нуля в течение 300 эпох в отличие от DynamicViT и \mathrm{IA}-\mathrm{RED}^2, использовавших преобученные модели. Разбиение на информативные/неинформативные токены тоже довольно иллюстративно и наглядно:

Из ablation авторы демонстрируют, что предложенная стратегрия разбиеная на фон/обьект работает лучше, чем случайная, или средний attention c другими токенами для данного токена:

Заключение

Таким образом, прунинг токенов - довольно интересное направление для повышения эффективности трансформеров, и его можно комбинировать вместе с другими способами сжатия модели. Кратного увеличения производительности без потери от прореживания токенов без потери качества ожидать не стоит в большинстве случаев, так как значительная часть картинки или текста является полезной для интересующей задачи, но зато небольшой, но выигрыш сравнительно легко реализовать на железе.

Дополнительные ссылки

Комментарии (0)