Смотрели итоги прошедшего ICLR? Меня заинтересовала довольно провокационная, на первый взгляд, статья от Эплов — ParaRNN. Казалось бы, параллельность РНН — это их главный недостаток, благодаря которому их заменили трансформеры (в большинстве задач).

Так вот, давайте разберемся со всем, на максимально низком уровне, если знаете, что такое RNN и производная — то эта статья для вас.

1. Алгоритм DEER

DEER = Deep Equilibrium Evaluation of Recurrence (Lim et al., 2024). Базовый алгоритм, на котором строится ParaRNN.

1.1. Постановка как задача нахождения корня

Пусть у нас есть обыкновенная RNN с переходной функцией f: \mathbb{R}^D \to \mathbb{R}^D, начальным состоянием \mathbf{s}_0 и неизвестными состояниями \mathbf{s}_1, \ldots, \mathbf{s}_T. Введем остаток (residual):

\mathbf{r}(\mathbf{s}_{1:T}) := [\mathbf{s}_1 - f(\mathbf{s}_0),\ \mathbf{s}_2 - f(\mathbf{s}_1),\ \ldots,\ \mathbf{s}_T - f(\mathbf{s}_{T-1})] \in \mathbb{R}^{T \times D}

Истинная траектория \mathbf{s}^*_1, \ldots, \mathbf{s}^*_T - это единственное решение уравнения:

\mathbf{r}(\mathbf{s}^*_{1:T}) = \mathbf{0}

Когда говорят «применить RNN к последовательности», имеют в виду стандартную процедуру: взять начальное состояние \mathbf{s}_0, применить переходную функцию f, получить \mathbf{s}_1, потом еще раз применить f, получить \mathbf{s}_2, и так далее:

\mathbf{s}_1 = f(\mathbf{s}_0), \quad \mathbf{s}_2 = f(\mathbf{s}_1), \quad \ldots, \quad \mathbf{s}_T = f(\mathbf{s}_{T-1})

Соответственно, получается, что \mathbf{r} - вектор, у которого все элементы равны 0, опять же потому что при соблюдении рекуррентности \mathbf{s}_1 = f(\mathbf{s}_0) и \mathbf{s}_1 - f(\mathbf{s}_0) = 0.


1.2. Итерации Ньютона

Соответственно дальше, необходимо найти решение уравнения \mathbf{r}(\mathbf{s}) = 0, или в полном случае — вектор, решающий систему уравнений. Но для начала разберемся со скалярным случаем.

Скалярный случай: одно уравнение от одной переменной

Пусть у нас есть гладкая функция r: \mathbb{R} \to \mathbb{R} и мы хотим найти такое s^*, что r(s^*) = 0. Геометрически — найти точку, где график функции пересекает ось абсцисс.

Идея метода Ньютона строится на простой мысли: в малой окрестности точки гладкая функция почти неотличима от своей касательной. Если мы стоим в текущем приближении s^{(i)} (которое, в общем случае, не корень — там r(s^{(i)}) \neq 0), мы можем сделать вид, что r — это ее касательная в этой точке, и для такой линейной функции легко аналитически найти, где она пересекает ось.

Касательная к r в точке s^{(i)} — это первое слагаемое разложения Тейлора:

r(s) \approx r(s^{(i)}) + r'(s^{(i)})\,(s - s^{(i)})

Вспомним: Разложение Тейлора — это способ приблизить любую гладкую функцию вблизи точки s_0 многочленом:

r(s) = r(s_0) + r'(s_0)(s - s_0) + \frac{r''(s_0)}{2!}(s - s_0)^2 + \frac{r'''(s_0)}{3!}(s - s_0)^3 + \ldots

где каждое следующее слагаемое уточняет приближение, добавляя информацию о все более тонкой особенности формы функции (наклон, кривизна, и т.д.). Логический смысл такой: если функция гладкая, то ее поведение в окрестности точки полностью закодировано в значениях ее производных в этой одной точке — измерив несколько чисел в s_0, мы можем восстановить значения функции рядом. Делитель k! возникает естественно из требования, чтобы в точке s_0 совпадали все производные многочлена и самой функции (он сокращается с факториалом, выскакивающим при k-кратном дифференцировании (s - s_0)^k).


Приравниваем эту линейную аппроксимацию к нулю и находим, где она пересекает ось:

r(s^{(i)}) + r'(s^{(i)})\,(s - s^{(i)}) = 0

Решаем относительно s :

s = s^{(i)} - \frac{r(s^{(i)})}{r'(s^{(i)})}

Это и объявляем следующим приближением:

s^{(i+1)} = s^{(i)} - \frac{r(s^{(i)})}{r'(s^{(i)})}

Тут показан графически шаг s^{(i)} \to s^{(i+1)}, и на графике видно, что корень уравнения (нас интересует пересечение функции с осью абсцисс) поменялся с 2 до 1, что показывает улучшение, так как эталонное значение — 0.

Можно переписать это через приращение \Delta s^{(i+1)} := s^{(i+1)} - s^{(i)}, что окажется удобнее при обобщении:

r'(s^{(i)})\,\Delta s^{(i+1)} = -r(s^{(i)})

То есть «найти такое приращение, чтобы линейная поправка r' \cdot \Delta s скомпенсировала текущий остаток r».

Многомерный случай: N уравнений от N переменных

Теперь обобщаем. Вместо одной функции от одной переменной \mathbf{r}: \mathbb{R}^N \to \mathbb{R}^N — векторнозначная функция векторного аргумента, и ищем такой вектор \mathbf{s}^* \in \mathbb{R}^N, что \mathbf{r}(\mathbf{s}^*) = \mathbf{0}.

Логика остается такой же, меняются только объекты или их размерности:

Скаляр

Вектор

функция r(s)

вектор-функция \mathbf{r}(\mathbf{s})

производная r'(s) — число

якобиан J(\mathbf{s}) - матрица N \times N

касательная (прямая)

касательная гиперплоскость

деление на r'

умножение на J^{-1} (то есть, решение линейной системы)

Где J(\mathbf{s}) = \dfrac{\partial \mathbf{r}}{\partial \mathbf{s}} — якобиан, многомерный аналог обычной производной (об этом подробнее ниже).

Якобиан J(\mathbf{s}) — это просто матрица всех частных производных: в позиции (i, j) стоит \partial r_i / \partial s_j. Он играет роль производной — показывает, как малое изменение \mathbf{s} влияет на \mathbf{r} в линейном приближении.

Линеаризация \mathbf{r} вокруг точки \mathbf{s}^{(i)}:

\mathbf{r}(\mathbf{s}) \approx \mathbf{r}(\mathbf{s}^{(i)}) + J(\mathbf{s}^{(i)})\,(\mathbf{s} - \mathbf{s}^{(i)})

Приравниваем линейную аппроксимацию к нулевому вектору:

\mathbf{r}(\mathbf{s}^{(i)}) + J(\mathbf{s}^{(i)})\,(\mathbf{s} - \mathbf{s}^{(i)}) = \mathbf{0}

И обозначив приращение \Delta \mathbf{s}^{(i+1)} := \mathbf{s} - \mathbf{s}^{(i)}, получаем линейную систему относительно \Delta \mathbf{s}^{(i+1)}:

J(\mathbf{s}^{(i)})\,\Delta\mathbf{s}^{(i+1)} = -\mathbf{r}(\mathbf{s}^{(i)})

Решив систему и получив \Delta \mathbf{s}^{(i+1)}, обновляем приближение:

\mathbf{s}^{(i+1)} = \mathbf{s}^{(i)} + \Delta\mathbf{s}^{(i+1)}

Также можно записать более компактно через обратную матрицу:

\mathbf{s}^{(i+1)} = \mathbf{s}^{(i)} - J(\mathbf{s}^{(i)})^{-1}\,\mathbf{r}(\mathbf{s}^{(i)})

— это та же самая формула, просто короче. Запись с J^{-1} — чисто нотационная: на практике обратную матрицу никто никогда не вычисляет, потому что это и дорого, и численно неустойчиво. Вместо этого решают систему J \cdot \Delta\mathbf{s} = -\mathbf{r} напрямую — например, через LU-разложение, или через прямую подстановку, если J имеет специальную структуру (что и происходит у нас).

1.3. Применение к нашей задаче с RNN

В случае RNN все ровно по этому шаблону, только размерности конкретные:

  • \mathbf{s} = (\mathbf{s}_1, \ldots, \mathbf{s}_T) \in \mathbb{R}^{TD} - все скрытые состояния, склеенные в один длинный вектор длины TD.

  • \mathbf{r}: \mathbb{R}^{TD} \to \mathbb{R}^{TD} — вектор всех одношаговых остатков, той же длины.

  • J(\mathbf{s}) \in \mathbb{R}^{TD \times TD} — якобиан остатка по состоянию.

Применяем тот же ньютоновский шаг:

J(\mathbf{s}^{(i)})\,\Delta\mathbf{s}^{(i+1)} = -\mathbf{r}(\mathbf{s}^{(i)})

И тут возникает вопрос: «А разве решить линейную систему размера TD \times TD — это не та же самая последовательная задача? Где здесь параллелизация?»

Если бы J была произвольной плотной матрицей, то да — наивно решение стоило бы O((TD)^3), и никакой выгоды бы не было. Но J не произвольная. Из-за марковости RNN (каждый шаг f видит только предыдущее состояние \mathbf{s}_{t-1}, а не всю историю) в якобиане подавляющее большинство блоков — нули. Конкретно: в блочной строке номер t ненулевые элементы есть только в столбцах t и t-1. Получается блочно-бидиагональная структура:

J(\mathbf{s}) = \begin{pmatrix}I_D & 0 & 0 & \cdots & 0 \\-\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_1) & I_D & 0 & \cdots & 0 \\0 & -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_2) & I_D & \cdots & 0 \\\vdots & & \ddots & \ddots & \vdots \\0 & 0 & \cdots & -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_{T-1}) & I_D\end{pmatrix}

Что такое якобиан вообще

Когда есть обычная функция от одной переменной r: \mathbb{R} \to \mathbb{R}, ее производная r'(s) - это одно число, которое говорит «насколько быстро меняется выход при малом изменении входа». Оно играет роль локального коэффициента пропорциональности: если сместить s на маленькое \delta, то r изменится примерно на r'(s) \cdot \delta.

Теперь представим, что функция, у которой и вход, и выход — векторы. Скажем, \mathbf{r}: \mathbb{R}^N \to \mathbb{R}^M: на вход подаем вектор из N чисел, на выход получаем вектор из M чисел. Понятие «производной» здесь усложняется, потому что теперь надо отвечать на N \times M вопросов одновременно: «как меняется i-я компонента выхода при изменении j-й компоненты входа?». Ответы на все эти вопросы естественно собираются в матрицу размера M \times N - это и есть якобиан:

J(\mathbf{s}) = \frac{\partial \mathbf{r}}{\partial \mathbf{s}} = \begin{pmatrix}\frac{\partial r_1}{\partial s_1} & \frac{\partial r_1}{\partial s_2} & \cdots & \frac{\partial r_1}{\partial s_N} \\\frac{\partial r_2}{\partial s_1} & \frac{\partial r_2}{\partial s_2} & \cdots & \frac{\partial r_2}{\partial s_N} \\\vdots & \vdots & \ddots & \vdots \\\frac{\partial r_M}{\partial s_1} & \frac{\partial r_M}{\partial s_2} & \cdots & \frac{\partial r_M}{\partial s_N}\end{pmatrix}

В позиции (i, j) стоит число \partial r_i / \partial s_j — частная производная i-й компоненты выхода по j-й компоненте входа. То есть якобиан буквально — это полная карта чувствительностей: каждая ячейка отвечает на конкретный вопрос «насколько чутко эта выходная координата реагирует на эту входную координату».


Вспомним определение остатка:

\mathbf{r}_t(\mathbf{s}_{1:T}) = \mathbf{s}_t - f(\mathbf{s}_{t-1})

Это выражение зависит только от двух переменных: от \mathbf{s}_t (через первое слагаемое) и от \mathbf{s}_{t-1} (через второе). Все остальные \mathbf{s}_k в формуле просто не присутствуют. А производная по переменной, которой в формуле нет, равна нулю.

Разберем по случаям, какой блок \partial \mathbf{r}_t / \partial \mathbf{s}_k получается при разных k:

Случай 1: k = t. Берем производную \mathbf{s}_t - f(\mathbf{s}_{t-1}) по \mathbf{s}_t. Только первое слагаемое зависит от \mathbf{s}_t, и его производная по самому себе — единичная матрица. Получаем:

\frac{\partial \mathbf{r}_t}{\partial \mathbf{s}_t} = I_D

Случай 2: k = t - 1. Берем производную по \mathbf{s}_{t-1}. Первое слагаемое от нее не зависит, второе — это -f(\mathbf{s}_{t-1}), и его производная — это -\partial f / \partial \mathbf{s}, вычисленная в точке \mathbf{s}_{t-1}:

\frac{\partial \mathbf{r}_t}{\partial \mathbf{s}_{t-1}} = -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_{t-1})

Случай 3: все остальные k (то есть k \neq t и k \neq t - 1). Переменная \mathbf{s}_k в формуле для \mathbf{r}_t просто не встречается. Значит:

\frac{\partial \mathbf{r}_t}{\partial \mathbf{s}_k} = 0_{D \times D} \quad (\text{нулевая матрица})

Вот и все. Из T^2 блоков ненулевых ровно T + (T-1) = 2T - 1: T единичных матриц на главной диагонали и T-1 якобианов перехода на поддиагонали. Все остальное — нули. Если расписать всю матрицу:

J(\mathbf{s}) = \begin{pmatrix}I_D & 0 & 0 & \cdots & 0 \\-\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_1) & I_D & 0 & \cdots & 0 \\0 & -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_2) & I_D & \cdots & 0 \\\vdots & & \ddots & \ddots & \vdots \\0 & 0 & \cdots & -\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_{T-1}) & I_D\end{pmatrix}

Где здесь марковость

Ключевая причина, по которой эта структура появилась — марковское свойство RNN. Переходная функция f в каждом шаге смотрит только на предыдущее состояние \mathbf{s}_{t-1}, а не на всю историю \mathbf{s}_1, \ldots, \mathbf{s}_{t-1}. Из-за этого остаток \mathbf{r}_t оказывается «локальным» объектом: он зависит только от двух соседних состояний — текущего \mathbf{s}_t и предыдущего \mathbf{s}_{t-1}.

Сколько нам реально нужно памяти

Хотя формально якобиан имеет TD \times TD ячеек, нам нужно хранить только ненулевые блоки. Это:

  • T единичных матриц I_D — но их даже хранить не нужно, мы знаем, что это I_D и можем подставлять на лету;

  • T - 1 якобианов перехода \partial f / \partial \mathbf{s}(\mathbf{s}_t) размера D \times D — это (T-1) \cdot D^2 чисел, что для T = 1000, D = 256 дает около 65 миллионов чисел вместо 65 миллиардов. Уже выполнимо.

Как структура якобиана дает параллелизм

Теперь главный вопрос: почему такая структура позволяет решить систему J \cdot \Delta\mathbf{s} = -\mathbf{r} параллельно? Здесь важно различать два уровня:

Уровень 1: структура позволяет решить систему через прямую подстановку. Возьмем систему J \cdot \Delta\mathbf{s} = -\mathbf{r} и распишем ее построчно. Первая блочная строка матрицы J — это (I_D, 0, 0, \ldots, 0), поэтому первое уравнение системы это:

I_D \cdot \Delta\mathbf{s}_1 = -\mathbf{r}_1
  • то есть просто \Delta\mathbf{s}_1 = -\mathbf{r}_1. Получили первый кусок ответа практически даром.

Вторая блочная строка J это (-\partial f / \partial \mathbf{s}(\mathbf{s}_1), I_D, 0, \ldots, 0), поэтому второе уравнение:

-\frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_1) \cdot \Delta\mathbf{s}_1 + I_D \cdot \Delta\mathbf{s}_2 = -\mathbf{r}_2

Откуда:

\Delta\mathbf{s}_2 = \frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_1) \cdot \Delta\mathbf{s}_1 - \mathbf{r}_2

И вообще, для произвольного

t > 1:

\Delta\mathbf{s}_t = \frac{\partial f}{\partial \mathbf{s}}(\mathbf{s}_{t-1}) \, \Delta\mathbf{s}_{t-1} - \mathbf{r}_t

Это и есть линейная рекуррентность, в которую превратилась наша гигантская система TD \times TD. Заметим, что в общем случае решение линейной системы стоит O(N^3) — но мы здесь обошлись без обращения какой-либо матрицы, благодаря тому что J блочно-бидиагональна. Система решается простой пробежкой по уравнениям сверху вниз. Это и называется forward substitution.

Если бы дело закончилось здесь, мы бы получили лишь последовательный алгоритм за O(T) шагов — каждый \Delta\mathbf{s}_t зависит от \Delta\mathbf{s}_{t-1}, и пробежать рекурренцию надо строго по порядку. То же самое, что просто прогнать RNN последовательно. Параллелизм рождается на следующем уровне.

Уровень 2: рекуррентность линейна, и потому ассоциативна. В этом — главный фокус. Обратим внимание на принципиальную разницу между двумя ситуациями:

  • Исходная RNN: \mathbf{s}_t = f(\mathbf{s}_{t-1}, \mathbf{x}_t) — функция f нелинейна, поэтому распараллелить такую рекуррентность нельзя: приходится считать каждый шаг честно по очереди.

  • Рекуррентность для \Delta\mathbf{s}: \Delta\mathbf{s}_t = A_t \cdot \Delta\mathbf{s}_{t-1} + \mathbf{b}_t (где A_t = \partial f / \partial \mathbf{s}(\mathbf{s}_{t-1}), \mathbf{b}_t = -\mathbf{r}_t) - она линейна. Это значит, что из нее можно получить замкнутую формулу:

\Delta\mathbf{s}_t = A_t A_{t-1} \cdots A_2 \cdot \Delta\mathbf{s}_1 + (A_t \cdots A_3 \cdot \mathbf{b}_2) + (A_t \cdots A_4 \cdot \mathbf{b}_3) + \ldots + \mathbf{b}_t

Все эти произведения матриц A_t \cdot A_{t-1} \cdot \ldots можно вычислить в любом порядке (умножение матриц ассоциативно: (AB)C = A(BC)). А значит, можно построить дерево вычислений, в котором мы сначала параллельно считаем все попарные произведения A_2 \cdot A_1, A_4 \cdot A_3, A_6 \cdot A_5, …, затем все четверки A_4 \cdot A_3 \cdot A_2 \cdot A_1, A_8 \cdot A_7 \cdot A_6 \cdot A_5, …, и так далее. За \log_2 T уровней дерева мы получаем все нужные накопленные произведения, и из них собираем все \Delta\mathbf{s}_t одновременно.

Это и есть параллельный скан (он же parallel prefix sum в обобщенной форме). Аналогия с обычным сложением: если надо сложить миллиард чисел, последовательно — миллиард шагов, а попарным деревом — всего \log_2(10^9) \approx 30 уровней. Тот же прием работает для любой ассоциативной операции, а композиция линейных отображений (умножение их матриц) — ассоциативна.

Итог по сложности: один шаг Ньютона выполняется за O(\log T) параллельной глубины (вместо O(T) последовательных шагов), а все применение RNN — за O(\text{iters} \cdot \log T), где iters — число итераций Ньютона.

Комментарии (0)