В данный момент мы находимся на стадии развития глубинного обучения, когда просто увеличивать кластера для более качественного результата становится проблематично. А потому все начинают спускаться на уровень ниже. И одна из таких ниш для улучшения это, конечно, оптимайзеры.
И хотя за ночными разговорами на тихих улочках Санкт-Петербурга мелькают идеи о создании быстрых оптимизаторов второго порядка (Софочка, София, ты как ностальгия…), результат нам нужен здесь и сейчас, а потому в мире, к сожалению, все еще главенствуют Adam и AdamW. Но у них есть ряд проблем, которые исследователи усердно пытаются решить, и одна из них – это память. AdEMAMix предлагает максимально топорное решение данной проблемы просто путем внедрения двух импульсов с различными β. Но от этого оно менее эффективным не становится!
Ну, а если вам покажется, что никакого rocket science тут нет, то милости прошу посетить курс по вычислительной линейной алгебре от Ивана Оселедца, который знает толк во вкусной математике, и как с ее помощью делать больно.
Готовим чаёк и погнали. :)
Интро
В самом сердце парадигмы нейронных сетей лежит задача оптимизации сложных невыпуклых функций потерь с использованием зашумленных оценок градиента. Этот процесс оптимизации обычно выполняется с использованием вариантов стохастического градиентного спуска (SGD) или адаптивных методов, таких как Adam и AdamW. Ключевым компонентом многих из этих алгоритмов итеративной оптимизации является импульс, который, как уже давно было доказано, ускоряет сходимость и часто приводит к решениям с превосходными обобщающими свойствами.
Импульс представляет собой экспоненциальное скользящее среднее (EMAs) предыдущих градиентов и обладает двумя ключевыми преимуществами:
С практической точки зрения, рекурсивная формула EMA позволяет создавать эффективные реализации, которые не требуют сохранения в буфер прошлых градиентов.
С теоретической точки зрения, градиентный спуск с импульсом приводит к оптимальным скоростям сходимости для квадратичных уравнений. Однако эти результаты не гарантируют какой-либо оптимальности для общих неквадратичных случаев.
Круто, мы смогли как-то прикрутить память в эту всю историю. И если для того, чтобы добавить механизмы памяти в роботы, обучаемые с помощью обучения с подкреплением, нам нужны тяжеловесные трансформеры, то тут мы можем обойтись… хранением всего-то двух копий весов всей модели. Мда уж, не особо лучше, не правда ли?!
Но сейчас не об этом. У нас появилась память и у нас возникают конфликты о том, на чем делать упор: на краткосрочном знании или на долгосрочном? Ну вообще хочется и то, и то, но если уж надо выбирать, то эмпирически предпочитают значения порядка 0.9 для β.
Мы наблюдаем, что один ЕМА не может одновременно придать значительный вес недавним градиентам и незначительный вес более старым градиентам. Однако линейная комбинация между «быстро меняющимся» (например, β = 0,9) и «медленно меняющимся» (например, β = 0,9999) EMA позволяет итерации извлекать выгоду:
Значительное ускорение, обеспечиваемое большим (медленно меняющимся) импульсом := постоянно копаем в одном месте.
Все еще реагируем на небольшие изменения в структуре потерь (быстро меняющейся).
Все это калифорнийские исследователи заворачивают в следующую очаровательную мысль:
Хотя изменить направление медленного ленивого импульса довольно сложно, любое ортогональное уточнение все еще вносит высокий вклад, что способствует быстрому продвижению по извилистым ландшафтам, похожим на каньоны.
Вот так вот, не зря вам дают курсы по философии в университете!
Непосредственно AdEMAMix
В базовом AdamW при β = 0.9 мы имеем, что 50% массы импульса приходится на последние 6 градиентов. При этом в реальности значение β редко превышает ∼ 0.9. В наших экспериментах с AdamW увеличение β существенно ухудшало производительность.
Давайте взглянем на проблему под другим углом. Мы поняли, что старые градиенты мы помним плохо, но насколько плохо?
Очень плохо!
Давайте посмотрим на то, какой вес имеет каждый градиент на итерации 10000 в базовом AdamW.
Небольшое значение β (например, 0,9) придаст большой вес ближайшему прошлому и незначительный - более ранним временным шагам. Напротив, высокое значение β (например, 0,9999) придает относительно равномерный вес всем градиентам, но перестает быть достаточно чувствительным. Никакое значение β не может одновременно придать высокий вес ближайшему прошлому и существенный вес очень старым временным интервалам.
Разработанный AdEMAMix намного лучше помнит прошлое. Это позволяет использовать гораздо большие значения β, например, 0,9999. Для сравнения, при β=0,9999, половина массы приходится уже на 6930 последних временных шагов.
Видим, что по сравнению с AdamW добавляется еще один идентичный импульс. Байес и соответствующая поправка, к сожалению, вынесены за пределы данной прикладной статьи.
В наших экспериментах, хотя значения β1 , β2 остаются аналогичными значениям уравнения AdamW, мы часто используем β3 = 0,9999. Мы обнаружили, что α ∈ [4, 10] хорошо работает на практике.
Устранение проблем с ранним обучением
Опять же, сама по себе фундаментальная идея не несет в себе никаких кардинальных прорывов, но тот факт, что авторы много на что посмотрели и много что учли даже в самой первой итерации своей научной работы делает ее намного более привлекательной. И важное место тут играет работа с warm-up.
Итак, у нас два гиперпараметра. Допустим, что мы хотим линейный рост для α на разогреве. Но как быть с β3? Напомню, что самоцель этого исследования не в изучении различных β3. Мы хотим увеличивать thalf, обозначающий, на какое число последних градиентов приходится половина массы. И если мы хотим уже линейного роста именно на эту компоненту, то warm-up выглядит как-то так:
Computational overheads
Компьют, по сравнению с forward-backward, увеличивается незначительно. Более того, если мы говорим о масштабном распределенном обучении, то AdEMAMix не требует больше коммуникаций по сравнению с Adam, а значит overhead уменьшается еще сильнее в случаях, когда bottleneck заключается в переносе данных. Да-да, уважаемые RAG-мастера, не FlashAttention-ом едины!
В этом плане намного более острая проблема – выделение памяти под m2. Но мы его всегда можем распределить между нодами, например, с помощью Fully-Sharded-Data-Paralellism (FSDP).
Эксперименты
Computer Vision
Тренируем классический ViT (24M/86M) на ImageNet-1k и ImageNet-21k. Отмечу, что для ViT ключевую роль играют аугментации и в данном случае авторы использовали mixup.
AdEMAMix для различных capacity/data ratio. Мы используем три различных сочетания объема модели и данных. При таких настройках становится сложно превзойти базовые показатели при очень высоких capacity/data ratios. Эти эксперименты показывают, что AdEMAMix, по-видимому, лучше всего работает в сценариях с большими объемами данных. В целом, мы обнаружили, что AdEMAMix стабильно снижает потери при обучении более эффективно, чем AdamW.
Авторы описывают это следующим образом: Когда снижение потерь при обучении коррелирует с уменьшением потерь при тестировании, AdEMAMix превосходит базовый уровень AdamW. По-моему это скорее следствие свойства, связанного с capacity/data ratio.
LLM
Обучаем базовый трансформер на 1024 токенах на датасете RedPajama v2.
Претрейн от AdamW против обучения AdEMAMix с нуля. Просто добавляем компоненту при переключении с AdamW на AdEMAMix компоненту m2 = 0 и далее работаем уже с ней как обычно. Примечательно, что авторы экспериментально выявили, что в таких случаях warm-up для m2 не нужен. Как по мне, это скорее в минус авторам, поскольку является контринтуитивным в и без того крайне не теоретической статье.
Чем раньше вы переключаетесь, тем больше выигрыш при уменьшающейся отдаче. Это указывает на то, что улучшение AdEMAMix не может быть связано исключительно с динамикой обучения в начале, скорее, динамика обучения, в том числе в конце, играет важную роль. Это также подтверждается обратным экспериментом, в ходе которого в середине обучения переключение с AdEMAMix на AdamW показало снижение производительности.
Модели AdEMAMix забывают обучающие данные медленнее. В попытке понять причину улучшений AdEMAMix по сравнению с AdamW, мы изучили, насколько быстро забывается тренировочный батч после его использования во время тренировки. Мы сосредоточились на анализе одного конкретного батча B из тестового датасета. Мы добавляем его на одну итерацию tB, а потом снова убираем из обучения.
Направления для будущих исследований
Авторы не обозначили направления для будущих исследований, видимо пожелав оставить эту статью как вещь в себе. Тем не менее, идея весьма прикладная, и было бы интересно посмотреть, как она будет себя вести в более диковинных доменах, таких как генеративные модели или тот же вышеупомянутый Reinforcement Learning.
В то же время, разбивка оптимайзера на долгую и короткую память приводит к желанию делать так не только на макроуровне, но и на микро, например, как-то соответствующе упорядочивая веса в лосс-функции для различных по номеру слоев аттеншена.
Перевод и комментарии Михаила Трегубова сделаны специально для Хабра и для телеграм-канала Контур.AI.