Во времена повсеместного заполонения трансформерами, которые пожирали в себя все больше и больше кремниевых чипов; когда казалось, что лучше уже не будет и за каждый новый токен нужно платить в квадрате от предыдущих, в эту холодную зимнюю пору появилась она - Мамба.
Во времена повсеместного заполонения трансформерами, которые пожирали в себя все больше и больше кремниевых чипов; когда казалось, что лучше уже не будет и за каждый новый токен нужно платить в квадрате от предыдущих, в эту холодную зимнюю пору появилась она - Мамба.

Трансформеры произвели настоящий фурор в области Deep Learning и демонстрируют выдающуюся эффективность. Однако у них существует серьезное ограничение по длине входной последовательности (контекста) из-за квадратичной вычислительной сложности. Большинство моделей работают с контекстом длиной менее 10 000, что делает их малоприменимыми в задачах с большими объемами входных данных. И хотя ходили различные слухи, было бы странно увидеть сильный искусственный интеллект, который можно за пару минут заболтать до беспамятства.

Вычислительная сложность

Длина контекста L

Пропускная способность

Трансформер

L^2

\sim 10^3

x

Мамба

L

\sim 10^6

5x

Мамба основывается на принципиально другом подходе - SSM, который, хоть и сильно старше трансформера, в контексте глубокого обучения не показывал достаточной эффективности, особенно в качестве языковой модели. Мамба имеет линейную вычислительную зависимость и в 5 раз выше пропускную способность, чем у трансформеров. Авторы проверили свое детище на серии моделей только до 2.8 млрд. параметров, что еще мало похоже на Chatgpt, но уже утерли нос текущим топам языковых моделей в своей весовой категории. Длина контекста при этом была выбрана как у соответствующего трансформера, так что контекст размером в миллион был проверен только на простых синтетических тестах, что, однако, тоже немаловажно, так как ни трансформеры, ни свертки с этими тестами не справились. В этой статье мы детально рассмотрим всю математику новой архитектуры, заметая под ковер преимущества и недостатки.

Линейная модель пространства состояний (SSM)

Непрерывный случай

Модель пространства состояний, на которой построена вся идея, в непрерывном виде выглядит так:

\begin{array}{lcl} \boldsymbol{\dot h}(t) = \boldsymbol{Ah}(t)+\boldsymbol{Bx}(t)\\ \boldsymbol{y}(t) = \boldsymbol{Ch}(t)+\boldsymbol{Dx}(t) \end{array}

С помощью такой модели можно записать дифференциальное уравнение N-го порядка как N уравнений первого порядка в матричном виде, где \boldsymbol{h}(t) - вектор состояния, содержащий производные по возрастанию порядка от 0 до N-1, x(t) - входной сигнал, y(t) - выходной сигнал. Таким образом, сложность описываемой системы нелинейно растет от N.

По-другому на модель можно смотреть так - одномерный сигнал x(t) отображается в N- мерное латентное состояние \boldsymbol{h}(t), а затем проецируется в выходной сигнал y(t).

Чтобы взаимодействовать с моделью в векторном виде конечной размерности, нужно дискретизировать ее.

Дискретизация

Умножим первое уравнение на e^{-\boldsymbol{A}t}:

e^{-\boldsymbol{A}t} \boldsymbol{\dot h}(t) - e^{-\boldsymbol{A}t} \boldsymbol{Ah}(t) = e^{-\boldsymbol{A}t} \boldsymbol{Bx}(t)\\ \frac{d}{dt} (e^{-\boldsymbol{A}t} \boldsymbol{h}(t)) = e^{-\boldsymbol{A}t} \boldsymbol{Bx}(t)

Тогда общее решение для непрерывной модели будет:

\boldsymbol{h}(t) = e^{\boldsymbol{A}t} \boldsymbol{h}(0) + \int_0^t{e^{\boldsymbol{A}(t-\tau)}} \boldsymbol{Bx}(\tau) d\tau

Шаг дискретизации \Delta \Rightarrow \boldsymbol{x_k} = \boldsymbol{x}(k \Delta), \boldsymbol{h_k} = \boldsymbol{h}(k \Delta), \boldsymbol{y_k} = \boldsymbol{y}(k\Delta):

\boldsymbol{h_k} = e^{\boldsymbol{A}k\Delta} \boldsymbol{h}(0) + \int_0^{k\Delta}{e^{\boldsymbol{A}(k\Delta-\tau)}} \boldsymbol{Bx}(\tau) d\tau \boldsymbol{h_{k+1}} = e^{\boldsymbol{A}\Delta}  \left[ e^{\boldsymbol{A}k\Delta} \boldsymbol{h}(0)+ \int_0^{k\Delta}{e^{\boldsymbol{A}(k\Delta-\tau)}} \boldsymbol{Bx}(\tau) d\tau\right] + \int_{k\Delta}^{(k+1)\Delta} e^{\boldsymbol{A}\left[(k+1)\Delta-\tau\right]} \boldsymbol{Bx}(\tau) d\tau =

Подставляем выражение для k и учитываем, что x=const внутри интервала Δ:

= e^{\boldsymbol{A}\Delta} \boldsymbol{h_k} + \left[ \int_0^{\Delta} e^{\boldsymbol{A}\nu} d\nu \right] \boldsymbol{Bx_k} =  e^{\boldsymbol{A}\Delta} \boldsymbol{h_k} + \frac{1}{\boldsymbol{A}} (e^{\boldsymbol{A}\Delta}-\boldsymbol{I})\boldsymbol{Bx_k}

Таким образом получаем дискретную SSM модель:

\begin{array}{lcl} \boldsymbol{h_k} = \overline{\boldsymbol{A}} \boldsymbol{h_{k-1}} + \overline{\boldsymbol{B}} \boldsymbol{x_k}\\  \boldsymbol{y_k} = \boldsymbol{\overline{C} h_k} + \overline{\boldsymbol{D}} \boldsymbol{x_k}\\  \\ \boldsymbol{\overline{A}} = e^{\boldsymbol{A}\Delta}\\ \boldsymbol{\overline{B}} = \frac{1}{\boldsymbol{A}} (e^{\boldsymbol{A}\Delta}-\boldsymbol{I})\boldsymbol{B} \approx \Delta \boldsymbol{B}\\ \boldsymbol{\overline{C}} = \boldsymbol{C}\\ \boldsymbol{\overline{D}} = \boldsymbol{D} \end{array}

Если в параметре \boldsymbol{\overline{B}} разложить экспоненту до первого порядка, происходит очень удачное упрощение, поэтому авторы пренебрегают точностью этого, не самого важного, параметра в пользу уменьшения вычислений:

\boldsymbol{x_k}- вход модели, \boldsymbol{y_k}- выход модели, \boldsymbol{h_k}- скрытое состояние или память модели,  \boldsymbol{\overline{A}} - главный из параметров, отвечает за то, как мы преобразуем память с течением времени - или параметр запоминания, \boldsymbol{\overline{B}} - параметр преобразования входа, \boldsymbol{\overline{C}} - параметр преобразования выхода, \boldsymbol{\overline{D}} - своего рода skip connection или skip параметр, \Delta - шаг дискретизации.

В простейшем случае имеем такие размерности:

\boldsymbol{\overline{A}} (N, N), \; \boldsymbol{\overline{B}} (N, 1), \; \boldsymbol{\overline{C}} (1, N), \; \boldsymbol{\overline{D}} (1, 1), \; \boldsymbol{x_k} (1, 1), \; \boldsymbol{y_k} (1, 1), \; \boldsymbol{h_k}(N, 1), \; \Delta = const

Таким образом, мы получили простую рекуррентную систему, сохранив при этом всю математическую силу пространства состояний. Интуицию по стандартной SSM можно получить здесь.

Селективная SSM

Отличительная особенность Мамбы от предыдущих глубоких SSM в этой ветке эволюции состоит в добавлении селективности. Иначе говоря, мы хотим, чтобы в скрытое состояние \boldsymbol{h_k} попадали только значимые из всех \boldsymbol{h_i} [i \lt k], а остальные отсеивались.

Обозначения

N - размерность скрытого состояния
L - длина входной последовательности
b - размер батча
d - глубина модели
E=2 - коэффициент расширения
d_{in} = Ed - глубина модели в Mamba блоке
A,B,C,D - параметры SSM
\Delta - размер шага дискретизации
\Delta_R = \frac{d}{16} - размерность проекции

Параметризация

Итак, чтобы модель могла акцентировать внимание на определенных элементах входной последовательности, сделаем три параметра зависимыми от входа:

\boldsymbol{B} = \boldsymbol{xW_B}, \; \boldsymbol{C} = \boldsymbol{xW_C}, \; \Delta = Softplus[\boldsymbol{xW_{\Delta1} W_{\Delta2}}+\Delta_{bias}]

Параметр \Delta управляет балансом между тем, насколько сильно фокусироваться или игнорировать текущий входной сигнал. Большой \Delta сбрасывает состояние \boldsymbol{h_k} и фокусируется на текущий вход \boldsymbol{x_k}, в то время как маленький \Delta сохраняет состояние и игнорирует текущий вход. Параметры \boldsymbol{B} и \boldsymbol{C} позволяют более тонко контролировать, вводить ли вход \boldsymbol{x_k} в состояние \boldsymbol{h_k} или состояние в выход \boldsymbol{y_k}.

\boldsymbol{A} и \boldsymbol{D} остаются независимыми от входа, но сами становятся параметрами.\boldsymbol{A} будем хранить в логарифмической форме \boldsymbol{A_{log}}(см. S4D инициализацию):

\boldsymbol{A} = -\exp^{\boldsymbol{A_{log}}}

Здесь и далее все экспоненты и логарифмы поэлементные. Таким образом, обучаемые параметры для селективного блока:

\boldsymbol{A_{log}}(d_{in}, N), \boldsymbol{W_{B}}(d_{in}, N), \boldsymbol{W_{C}}(d_{in}, N), \boldsymbol{D}(d_{in}), \boldsymbol{W_{\Delta1}}(d_{in}, \Delta_R), \boldsymbol{W_{\Delta2}}(\Delta_R, d_{in}), \Delta_{bias}(d_{in})

Введем сразу остальные параметры, которые будут использоваться в архитектуре:

\boldsymbol{W_{in}}(d, 2d_{in}), \boldsymbol{W_{out}}(d_{in}, d), \boldsymbol{W_{emb}}(vocab\:size, d)=\boldsymbol{W_{vocab}}^T, \boldsymbol{W_{conv1d}}(d_{in}, 1, K)

Инициализация параметров

Каждый из вышеописанных параметров инициализируется по своему:

\boldsymbol{A_{log}} = \ln\begin{pmatrix}1 & 2 & 3 & ... & N\\1 & 2 & 3 & ... & N\\  &   &...\end{pmatrix}, \; \boldsymbol{D} = \overline{\boldsymbol{1}}\Delta_{bias} = Softplus^{-1}\left[Uniform(10^{-3}, 10^{-1}) \right]

Параметр \boldsymbol{W_{conv1d}} задается стандартной инициализацией conv1d слоя с bias=True, тогда как все оставшиеся веса задаются Linear слоем с bias=False.

Инференс селективной SSM с аппаратным ускорением

Рисунок 1: Устройство Selective SSM блока (Mamba).
Рисунок 1: Устройство Selective SSM блока (Mamba).
\boldsymbol{x}(b, L, d_{in}), \boldsymbol{h_t}(b, d_{in}, N) \rightarrow \boldsymbol{y}(b, L, d_{in})

Для ускорения вычислений авторы разделили инференс selective SSM блока на два этапа - сначала подготовка (трехмерных массивов) на обычной (медленной) памяти видеокарты, затем дискретизация и вычисление рекурсии (четырехмерных массивов) в быстрой памяти видеокарты:

1) Подготовка (GPU HBM):

Возвращение \boldsymbol{A} в человеческий вид и проекция входа:

\begin{array}{ccc}\boldsymbol{A}(d_{in}, N) = -\exp^{\boldsymbol{A_{log}}}\\\boldsymbol{B}(b, L, N) = \boldsymbol{xW_B}\\\boldsymbol{C}(b, L, N) = \boldsymbol{xW_C}\\\Delta(b, L, d_{in}) = Softplus[\boldsymbol{xW_{\Delta1} W_{\Delta2}}+\Delta_{bias}]\end{array}

2) Selective scan (GPU SRAM):

Инициализация скрытого состояния:

\boldsymbol{h_{-1}} = \overline{\boldsymbol{0}}

Дискретизация:

\begin{array}{ccc}\boldsymbol{\overline{A}}(b, L, d_{in}, N) = e^{\Delta \boldsymbol{A}}\\\boldsymbol{\overline{B}x}(b, L, d_{in}, N) = \Delta \boldsymbol{Bx}\end{array}

В цикле по t вдоль оси L (по каждому токену) пересчет всех скрытых состояний \boldsymbol{h} и соответствующих им выходов \boldsymbol{y}:

\begin{array}{lcl}\boldsymbol{h_t} = \overline{\boldsymbol{A_t}} \boldsymbol{h_{t-1}} + \boldsymbol{(\overline{B} x)_t}\\ \boldsymbol{y_t} = \boldsymbol{C_t h_t} + \boldsymbol{Dx_t}\\ \end{array}

Архитектура Mamba

Рисунок 2: Устройство архитектуры Mamba
Рисунок 2: Устройство архитектуры Mamba

Mamba

Устройство архитектуры не сильно отличается от трансформерной:

  1. На входе имеем последовательность длиной L, которая может представлять из себя хоть текстовые токены, хоть элементы изображения.

  2. Векторизуем элементы последовательности матрицой эмбеддингов \boldsymbol{W_{emb}}, получая тот самый \boldsymbol{x}(b, L, d).

  3. Прогоняем его через n_{layers} мамба-слоев, сохраняя при этом размерность.

  4. Возвращаем размерность (b, L, vocab\;size) матричным умножением на \boldsymbol{W_{vocab}=W_{emb}^T} - той же матрицей, что и на входе.

  5. И, наконец, получаем вероятности для каждого токена по словарю.

Mamba Layer

Слой Мамба представляет из себя:

  1. Нормализацию по слою

  2. Непосредственно сам Мамба блок

  3. Skip connection

Mamba Block

Принцип блока основан на gated MLP, который при помощи дополнительной ветки с линейным слоем, активацией и последующим Element-wise умножением может управлять потоком информации основной ветки, определяя какая часть должна быть сохранена, а какая подавлена.

По основной же ветке идет, так называемый, inverted bottleneck:

  • Расширение (\boldsymbol{W_{in}}) \rightarrowdepthwise convolution (в данном случае одномерная) \rightarrow проекция (\boldsymbol{W_{out}}),

с добавлением активации и основного блока - selective SSM из предыдущего раздела.

Заключение

Модель Мамба успешно унаследовала ключевые характеристики от трансформеров, такие как внимание к контексту и мультимодальность, открывая при этом новые перспективы для будущего развития. Способность Мамба эффективно работать в различных доменах, особенно в модальностях, где требуется учет большого объема контекста, таких как геномика, аудио и видео, выделяет ее среди передовых разработок.

Хотя данный обзор сосредоточен исключительно на математических аспектах нового подхода, результаты показывают, что Мамба может стать мощным кандидатом на роль нового общего мультимодального бэкбона. Подробности про синтетические тесты, результаты и сравнения в областях LLM, аудио и геномики доступны в оригинальной статье (ссылка).

Интересного нам 2024 года!

Материалы

Комментарии (19)


  1. rajce
    14.01.2024 12:04
    +4

    Спасибо за глубокий анализ Мамбы - это действительно открывает новые горизонты в области Deep Learning.


  1. Arastas
    14.01.2024 12:04
    +4

    Давно ждал разбора SSM на Хабре. Спасибо за статью, но, к сожалению, есть что улучшить. Первая часть, до SSM, это мат. аппарат, которому лет 50-60 уже, знакомый любому, у кого был курс Введение в ТАУ. А начиная с SSM изложение комкается, начинают плыть обозначения. Жаль. Может попробуете дополнить, с фокусом именно на второй части?


    1. kashokhin Автор
      14.01.2024 12:04

      Спасибо за обратную связь, подумаю как исправить. Про обозначения не совсем понял.


      1. Arastas
        14.01.2024 12:04

        Например, вы пишите, что A_{log} это матрица d_{in} на N, так? И при этом A = -\exp(A_{log}). Для прямоугольной матрицы экспонента, видимо, должна быть поэлементная, а не матричная? Но в первом разделе экспоненты матричные, лучше явно написать, что обозначение меняется.

        Но тогда получается, что A это тоже матрица d_{in} на N, правильно? А в первом разделе это была матрица N на N, квадратная, как и положено матрице состояний. Что я упускаю?


        1. kashokhin Автор
          14.01.2024 12:04

          С экспонентой действительно можно запутаться, подправил. Хотя в контексте DL обычно ясно, что логарифмическая форма значений параметра используется для лучшей сходимости при обучении.

          После параметризации A - больше не матрица, а тензор параметров, который задают авторы по своему усмотрению. Идея квадратной матрицы заключалась в том, чтобы отображать скрытое состояние \boldsymbol{h}в него же обратно (N,N)@(N,1)=(N,1). Здесь логика сохраняется, но уже с дискретным параметром \overline{A_t}, который при поэлементном умножении в главном цикле Selective scan также сохраняет размерность для \boldsymbol{h}.

          Сложнее, например, вопрос обстоит с вычислением \overline{A}, которое происходит так:(b, L, d_{in})(d_{in}, N) \rightarrow (b, L, d_{in}, N), что не является ни поэлементным, ни матричным умножением. Однако загромождать эти моменты пояснениями не стал, так как это вопрос уже технический.


          1. Arastas
            14.01.2024 12:04
            +1

            Ну так это и есть поплывшие обозначения. Вы меняете размерность объекта, обозначенного буквой A, явно это не проговаривая. :)

            Здесь и далее все экспоненты и логарифмы поэлементные.

            А в следующем разделе формула \overline{A}(b, L, d_{in}, N) = e^{\Delta A}. Это точно поэлементная экспонента?


            1. kashokhin Автор
              14.01.2024 12:04

              Здесь никакой ошибки нет. Я описал математическую модель в классическом виде, указав дефолтные размерности для понимания. Затем в новой главе обозначил переход к глубокому обучению, вводя уже параметры, действительно, под старыми обозначениями. В этом и смысл аналогии перехода. Это статья по DL, поэтому и контекст соответствующий. Подскажите, где вы в DL видели матричную экспоненту? :)


              1. Arastas
                14.01.2024 12:04

                Я и не говорил, что у вас ошибка. Я сказал, что у вас поплыли обозначения. :)


                1. kashokhin Автор
                  14.01.2024 12:04

                  Обозначения не поплыли, а были явно переинициализированы в новом контексте с указанием новых размерностей. Словами также проговорено.


          1. Arastas
            14.01.2024 12:04

            Здесь логика сохраняется, но уже с дискретным параметром \overline{A_t}

            Сложнее, например, вопрос обстоит с вычислением \overline{A}, которое происходит так: (b, L, d_{in})(d_{in}, N) \rightarrow (b, L, d_{in}, N), что не является ни поэлементным, ни матричным умножением.

            Как \overline{A_t} получается из \overline{A}? На сколько я понимаю, матрица \overline{A_t} должна быть квадратной N на N?


            1. kashokhin Автор
              14.01.2024 12:04


              Как уже сказано здесь, индекс t указывает на индекс элемента тензора вдоль оси L:

              В цикле по t вдоль оси L (по каждому токену) пересчет всех скрытых состояний \boldsymbol{h} и соответствующих им выходов \boldsymbol{y}:

              Соответственно, \overline{A_t} имеет размерность (b, d_{in}, N).


              1. Arastas
                14.01.2024 12:04

                Тогда я снова теряюсь в обозначениях. Как мне читать h_{t+1} = \overline{A_t} h_t? Это же не тензорное произведение? Недоумение тем сильнее, что у вас точно такое же выражение записано ранее для SSM, только с h_k вместо h_t...

                В итоге пришлось идти читать оригинальную статью на ArXiV, чтобы понять, что вы имели в виду. Спасибо за мотивацию! :)


  1. ExternalWayfarer
    14.01.2024 12:04
    +19

    Не знаю, при чем тут приложение для знакомств...


    1. iggr63
      14.01.2024 12:04
      +1

      Оба приложения характеризуются большой длиной контекста:)?


  1. Jeshua
    14.01.2024 12:04
    +1

    Присоединюсь к благодарностям и к просьбам раскрыть тему, если можно, с практическим примером.


    1. kashokhin Автор
      14.01.2024 12:04

      Уже в процессе, спасибо!


  1. SatCat
    14.01.2024 12:04

    ну и ссылочку на код можно вставить https://github.com/state-spaces/mamba


    1. kashokhin Автор
      14.01.2024 12:04

      Действительно, спасибо!


  1. Juranja
    14.01.2024 12:04

    Я тут один такой , который читает и ни фига не понимает. Статья класс , хотя для чего она мне в жизни пригодится не знаю , но все же читаю . Теперь я точно знаю , что мамба это не только змея и сайт знакомств!