Аннотация

Одной из наиболее примечательных особенностей Large Language Models (LLM) является их способность к in-context learning — обучению в контексте. В частности, на этапе инференса LLM может усваивать новые паттерны без какого-либо дополнительного обновления весов, если эти паттерны представлены в виде примеров в промпте, даже если эти паттерны не встречались во время обучения. Механизмы, за счёт которых это возможно, всё ещё во многом остаются неизвестными.

В данной работе мы показываем, что комбинация слоя self-attention с MLP позволяет трансформер-блоку неявно модифицировать веса MLP-слоя в зависимости от контекста. Мы утверждаем на основе теоретического анализа и экспериментов, что этот простой механизм может объяснять, почему LLM способны обучаться в контексте, а не только во время тренировки модели. В частности, мы демонстрируем, что при ряде упрощающих допущений трансформер-блок неявно преобразует контекст в low-rank обновление весов MLP-слоя.

1. Введение

Large Language Models (LLM) и архитектура трансформер произвели революцию в области машинного обучения и, вероятно, окажут такое же влияние на множество других сфер — промышленность, науку и искусство. Однако, несмотря на столь значительный эффект, механизмы, за счёт которых LLM приобретают эмерджентные свойства, делающие их столь полезными, остаются в значительной степени теоретической загадкой.

В этой работе мы фокусируемся на способности LLM к in-context learning (ICL) — обучению в контексте — после того, как процесс тренировки модели полностью завершён. Речь идёт о способности усваивать знания из примеров, которые не встречались в обучающей выборке, но предоставляются модели уже на этапе инференса через промпт.

Исторически в машинном обучении умение извлекать паттерны из серии примеров трактовалось как динамический процесс обновления весов модели по мере того, как она «потребляет» эти примеры в рамках некоторой оптимизационной процедуры. Однако в случае in-context learning отсутствует явное обновление весов, которое могло бы объяснить эмерджентную динамическую природу обученных LLM, способных реорганизовывать или переконфигурировать себя под влиянием инструкций из пользовательского промпта.

Эта загадочная и крайне полезная способность LLM привела исследователей к гипотезе о существовании неявного механизма обновления весов, происходящего во время инференса, когда модель «потребляет» промпт. Более того, недавние исследования показали, что трансформер-блоки могут неявно реализовывать разновидность стохастического градиентного спуска при обработке контекста.

В данной работе мы исследуем эту интуицию о неявных обновлениях весов, но идём противоположным путём. Вместо того чтобы использовать путь сильной абстракции и рассматривать только простые toy-модели, мы предлагаем сфокусироваться на ключевых свойствах контекстной обработки и показать, что attention-слои играют в этом центральную роль.

Мы рассматриваем обобщение трансформер-блока, которое называем контекстным блоком. Мы показываем, что слои с таким контекстуальным свойством, будучи объединёнными со стандартными нейронными сетями, неявно преобразуют контекст в обновление весов самого первого слоя стека нейросети.

Мы выводим явную формулу для этого неявного обновления, и оказывается, что это обновление можно выразить как low-rank матрицу ранга 1. Это приводит нас к выводу, что такие контекстные слои — например, self-attention — в комбинации с нейронной сетью фактически выполняют неявный fine-tuning весов MLP, где апдейт вычисляется напрямую из контекста.

Основные элементы работы:

  • Мы вводим понятие контекстного блока, образованного контекстным слоем, «надстроенным» над нейросетью, — тем самым обобщая блок трансформера.

  • Мы показываем, что для контекстных блоков выход токена в присутствии контекста совпадает с выходом той же нейросети без контекста, но с матрицей весов, обновлённой low-rank матрицей.

  • Вывод явной формулы для неявного обновления весов нейросети, индуцированного контекстом, и анализ влияния контекста на параметры модели.

  • Показ связи обработки токенов с динамикой обучения: мы демонстрируем, что процесс «потребления» токенов эквивалентен неявному градиентному спуску в пространстве весов нейросети.

1.1 Смежные работы

Начиная с определённого масштаба, LLM демонстрируют способность обучаться на примерах, предоставленных в промпте. Эта эмерджентная способность была впервые чётко показана ещё в GPT-3 [3] и получила название In-Context Learning (ICL) [4].

В [12] работе авторы формулируют фундаментальный вопрос: действительно ли происходит истинное обучение во время инференса когда модель обрабатывает промрт, или же примеры в контексте просто помогают модели активировать уже выученные на этапе pre-training способности, без какого-либо нового обучения в момент инференса.

Более того, в работе[13] утверждается, что примеры в промпте служат лишь формой байесовского обусловливания (Bayesian conditioning), а не собственно обучением. В том же направлении [14] показывают, что замена меток в примерах промпта на случайные не приводит к значительному падению качества ICL, что подтверждает гипотезу: модель не учится заново, а извлекает уже заученные во время pre-training знания.

Однако [15] переосмысляют эту идею и показывают, что хотя это верно для малых моделей, крупные LLM начинают действительно обучаться на случайно переставленных метках внутри промпта. Аналогично, [16] демонстрируют, что возникновение настоящего ICL также сильно зависит от разнообразия данных на этапе pre-training в контексте LLM.

С другой стороны, [8] показывают, что трансформер-модели, прошедшие pre-training на задачах регрессии, способны по контексту «на лету» осваивать такие разные функции, как линейные функции, деревья решений и двухслойные нейросети. Эти эксперименты предоставляют контролируемую среду для проверки гипотезы о существовании истинного ICL.

В [17] показано, что трансформеры способны обучаться в режиме meta-optimization — когда сама модель начинает вести себя как мета-оптимизатор. Эта гипотеза проверяется в [6], где (в той же контролируемой постановке регрессии) трансформер с линейным вниманием, обученный в режиме gradient flow, сходится к мета-оптимизатору и ведёт себя как градиентный спуск.

Одновременно работы [7], [9] и [10] демонстрируют теоретические механизмы, в которых потребление примеров через промпт во время инференса может быть неявно связано с шагами градиентного спуска, но реализуется через implicit weight updates. Недавняя работа [11] также показывает, что при использовании chain-of-thought в промпте эффект аналогичен многократным шагам стохастического градиентного спуска.

Однако все эти теоретические модели предполагают узкий класс сценариев, например, линейные слои и prompt, построенные на основе регрессионных пар. Как показано в [18] и [19], эти предположения недостаточно реалистичны, и авторы выявили различия между истинным ICL и градиентным спуском, выполняемым через fine-tuning на примерах из prompt. Совсем недавно [20] подтвердили, что ICL демонстрирует генерализационное преимущество по сравнению с традиционным fine-tuning.

В данной работе мы развиваем концепцию ICL, объясняя её через механизм implicit weight updates, соответствующих своеобразной неявной динамике обучения. Однако мы отказываемся от ограничений классических моделей, которые делают предположения о том, что контекстное обучение происходит в слоях self-attention. Вместо этого мы предлагаем более общую модель, в которой механизм обновления весов переносится на MLP-слой внутри трансформер-блока, а не на attention-механизм [7, 9, 10].

2. Контекстные блоки

В этом разделе мы абстрагируем некоторые ключевые свойства трансформеров. В частности, вводим понятие контекстного слоя (contextual layer), которое обобщает слой self-attention в трансформерных блоках. В этой постановке контекстный блок (contextual block) — это композиция контекстного слоя со стандартной нейросетью, обобщающая понятие трансформерного блока. Далее мы доказываем основной теоремный результат: контекст в контекстных блоках действует как низкоранговый fine-tuning-апдейт весов нейросети. Для простоты мы формулируем результаты для нейросети без skip-connection; случай со skip-connection аналогичен, но технически сложнее и полностью разобран в Приложении A.

Мы называем контекстным слоем слой сети A(⋅), который может принимать на вход одиночный вектор x и выдавать выход A(x); либо, опционально, A может дополнительно принимать контекст C (например, последовательность токенов, изображение и т. п.) вместе с вектором x, выдавая выход A([C,x]). Заметим, что далее мы часто будем опускать явное обозначение конкатенации [C,x] во входе контекстного слоя и просто писать A(C,x), имея в виду A([C,x]).

Как прототипический и направляющий пример контекстного слоя рассмотрим слой self-attention в трансформерном блоке, где контекст C — это инструкционный промпт, состоящий из последовательности контекстных токенов C=[c_1,…,c_n], а x — query-токен, по которому LLM делает предсказание. Вместе C и x образуют контекстуализированный входной промпт [C,x]=[c_1,…,c_n,x], то есть конкатенацию контекстных токенов и query-токена. Мы полагаем A(C,x) выходом слоя self-attention для последнего токена x. Таким образом, и A(C,x), и A(x) лежат в одном и том же выходном векторном пространстве. Контекстные слои порождают контекстуальные векторы, что позволяет определить разность ΔA(C):=A(C,x)−A(x) между выходом слоя с контекстом и без него.

Мотивируясь этим обобщением слоя self-attention как контекстного слоя, мы теперь обобщаем и понятие целого трансформерного блока, вводя контекстный блок:

Определение 2.1. Контекстный блок — это композиция T_W=M_W∘A, где A — контекстный слой, а M_W​ — нейросеть; т. е. M_W(z)=f_θ(Wz+b), где W и b — веса первого полносвязного (dense) слоя, а f_θ(z) — «остальная часть» нейросети.

Следующая теорема утверждает, что контекстный блок преобразует подмножество Y⊂C контекста C в неявное обновление весов нейросети так, что W становится W+ΔW(Y), где информационное содержание Y переносится в веса через апдейт ΔW(Y). В некотором смысле, контекстные слои загружают в веса сети параметры, соответствующие контекстной части Y, неявно добавляя к весам W низкоранговое обновление ΔW(Y). А именно, выход контекстного блока на полном контексте C совпадает с выходом того же блока на контексте C\setminus Y1, если Y убран из C, но «вшит» в веса через апдейт ΔW(Y).

Теорема 2.2. Рассмотрим контекстный блок T_W=M_W∘A, как выше: он образован контекстным слоем A и полносвязным слоем M_W с матрицей весов W. Для заданных контекста C и входа x влияние некоторой части Y⊂C контекста C на выход контекстного блока неявно соответствует обновлению весов ранга-1 первого слоя M_W​: W+ΔW(Y). А именно,

T_{W}(C,x) = T_{\,W+\Delta W(Y)}\!\bigl(C\setminus Y,\, x\bigr)\text{where }\; \Delta W(Y) = \frac{\bigl(W\,\Delta A(Y)\bigr)\,A(C\setminus Y,x)^T}        {\|A(C\setminus Y,x)\|^{2}} , \tag{1}

где ΔA(Y)=A(C,x)−A(C\setminus Y,x) — контекстный вектор, ассоциированный с Y. Заметим, что ΔW(Y) имеет ранг 1, поскольку WΔA(Y) — столбцовый вектор, а A(C\setminus Y,x)^T — строковый вектор.

Доказательство. Утверждение следует из прямого вычисления, где мы используем обозначение M_W(z)=f_θ(Wz+b), где W и b — веса первого dense-слоя сети M, а f_θ​ — остальная часть сети. В этих обозначениях по определению

T_{\,W+\Delta W(Y)}\!\bigl(C\setminus Y, x\bigr) = M_{\,W+\Delta W(Y)}\!\left(A(C\setminus Y, x)\right) \tag{2}= f_{\theta}\!\left((W+\Delta W(Y))\,A(C\setminus Y, x) + b\right) \tag{3}= f_{\theta}\!\left(   W\,A(C\setminus Y, x)   + \Delta W(Y)\,A(C\setminus Y, x)   + b \right). \tag{4}

Подставляя теперь ΔW(Y) из определения в формуле (1) и используя тождество  \frac{z^T}{\lVert z\rVert^{2}}\, z = 1, получаем

T_{\,W+\Delta W(Y)}\!\bigl(C\setminus Y, x\bigr) = f_{\theta}\!\left(   W\,A(C\setminus Y, x)   + \frac{(W\,\Delta A(Y))\,A(C\setminus Y, x)^T}          {\|A(C\setminus Y, x)\|^{2}}\,A(C\setminus Y, x)   + b \right) \tag{5}= f_{\theta}\!\left(   W\big(A(C\setminus Y, x) + \Delta A(Y)\big) + b \right) \tag{6}

Поскольку по определению контекстного вектора A(C\setminus Y,x)+ΔA(Y)=A(C,x), в итоге имеем

T_{\,W+\Delta W(Y)}\!\bigl(C\setminus Y, x\bigr) = f_{\theta}\!\left(W\,A(C,x)+b\right) = M_{W}\!\left(A(C,x)\right) = T_{W}(C,x) \tag{7}

Замечание 2.3. Наша теорема утверждает, что любой контекстный слой порождает неявную «передачу» веса из промпта в первый слой нейросети, тем самым неявно модифицируя поведение предобученной нейросети. Среди возможных реализаций контекстных слоёв (например, self-attention, RNN или рекуррентные слои с локальным вниманием, как в [21]) одни могут лучше вносить полезные модификации весов, чем другие. Представляется интересным оценить «порождающую» силу контекстного слоя с точки зрения специфической формы неявных обновлений весов, заданной нашей теоремой, и структуры A, задаваемой контекстным слоем.

Отметим, что при Y=C (весь контекст) теорема даёт формулу, позволяющую «переложить» всю контекстную информацию в матрицу весов W; а именно:

Следствие 2.3.1. В обозначениях выше полный контекст C может быть перенесён в веса нейросети следующим обновлением:

T_{W}(C,x) = T_{\,W+\Delta W(C)}(x), \quad \text{with }\; \Delta W(C) = \frac{(W\,\Delta A)\,A(x)^{T}}{\|A(x)\|^{2}} , \tag{8}

где ΔA=A(C,x)−A(x) — контекстный вектор, а ΔW имеет ранг 1, поскольку WΔA — столбцовый вектор, а A(x)^T — строковый.

Замечание 2.4. Формулу «передачи весов» (1) можно переписать через объединение/конкатенацию контекстов, положив D=C\setminus Y; тогда

T_{W}(D \cup Y, x) = T_{\,W+\Delta W(Y)}(D, x). \tag{9}

В Приложении A мы обобщаем Теорему 2.2 на сети M со skip-соединениями, что обычно имеет место для стандартных трансформерных блоков. В разделе 4 мы экспериментально проверяем теоретические результаты на классическом конкретном примере.

Динамика неявного обучения в ICL

Когда контекст C=[c_1,…,c_n]— это последовательность токенов, итеративное применение Следствия 2.3.1 выявляет неявную динамику обучения, порождённую влиянием каждого контекстного токена на выход контекстного блока. Начав с начальной матрицы весов W_0 для первого полносвязного слоя нейросети M_W​, можно рассматривать обновления весов, соответствующие поэтапному добавлению токена в контекст:

T_{W_{0}}\!\bigl(c_{1}, x\bigr) = T_{\,W_{0}+\Delta W_{0}(c_{1})}\!(x)T_{W_{0}}\!\bigl(c_{1}, c_{2}, x\bigr) = T_{\,W_{0}+\Delta W_{0}(c_{1}, c_{2})}\!(x)\begin{aligned} &\vdots\\ T_{W_{0}}(c_{1},\ldots,c_{n},x) &= T_{\,W_{0}+\Delta W_{0}(c_{1},\ldots,c_{n})}(x) \end{aligned}

Отсюда получаем следующую последовательность «контекстных» весов:

W_{1} = W_{0} + \Delta W_{0}(c_{1}) \tag{10}W_{2} = W_{0} + \Delta W_{0}(c_{1}, c_{2}) \tag{11}\begin{gather} \vdots \tag{12}\\ W_{n} = W_{0} + \Delta W_{0}(c_{1},\ldots,c_{n}) \tag{13} \end{gather}

что по построению сходится к эффекту полного контекста на веса MLP; а именно

T_{W_{n}}(x) = T_{W_{0}}(c_{1},\ldots,c_{n}). \tag{14}

Следующее утверждение показывает, что эта неявная динамика обучения подобна онлайн-градиентному спуску, где токены играют роль обучающих примеров, а функция потерь на каждом шаге меняется в зависимости от рассматриваемого в этот момент токена.

Предложение 3.1. В указанной выше нотации итеративный процесс обновления весов можно представить в виде стохастических шагов градиентного спуска

W_i = W_{i-1} - h\, \nabla_{W} L_i\!\left(W_{i-1}\right) \tag{15}

где скорость обучения h = 1/\lVert A(x)\rVert^{2}, а потери на шаге i заданы

L_i(W) = \operatorname{trace}\!\big(\Delta_i^{T} W\big), \tag{16}

причем \Delta_i = W_{0}\Big( A(c_{1},\ldots,c_{i},x) - A(c_{1},\ldots,c_{i+1},x) \Big)\, A(x)^{T}.

Доказательство. Во-первых, для последовательности W_i, определённой в (10)–(13), имеем

W_{i+1} - W_{i} = \Delta W_{0}(c_{1},\ldots,c_{i+1})   - \Delta W_{0}(c_{1},\ldots,c_{i}) \tag{17}= \frac{\,W_{0}\Big(A(c_{1},\ldots,c_{i+1},x)-A(c_{1},\ldots,c_{i},x)\Big)\,A(x)^{T}\,}        {\lVert A(x)\rVert^{2}} \tag{18}= -\,h\,\Delta_i, \tag{19}

где h = 1/\lVert A(x)\rVert^{2}и \Delta_i = W_{0}\Big( A(c_{1},\ldots,c_{i},x) - A(c_{1},\ldots,c_{i+1},x) \Big)\, A(x)^{T}.

Отсюда

W_{i+1} = W_i - h\,\Delta_i = W_i - h\,\nabla_{W}\,\operatorname{trace}\!\big(\Delta_i^{T} W\big), \tag{20}

поскольку в общем случае \nabla_{W}\,\operatorname{trace}\!\big(A^{T}W\big) = A.

Заметим, что Δi​ измеряет вклад добавления токена c_{i+1} к частичному контексту c_1,…,c_i. Если c_i​ не влияет на выход, т.е. A(c_1,…,c_i,x)−A(c_1,…,c_{i+1},x)=0, то и соответствующее обновление ∇_WL_i(W)=Δ_i зануляется. На Рис. 2 показано на простом эксперименте, что эти градиенты затухают по мере того, как динамика обучения сходится к использованию полного контекста.

Замечание 3.2. Интересно, что можно вывести другую, но схожую, неявную динамику обучения W_0,W_1,…,W_n​, рассматривая частичные обновления, которые на каждом шаге сохраняют неизменным выход контекстного блока при совместном использовании с оставшимися токенами: T_{W_{i}}\!\bigl(c_{i+1}, \cdots, c_{n}, x\bigr) = T_{W_{0}}\!\bigl(c_{1}, \ldots, c_{n}, x\bigr). Эта динамика описана в Приложении B. Отличие в том, что в общем случае её уже нельзя представить как градиентные шаги, но она приводит к факторизационной формуле для итоговой матрицы весов W_n​, такой что T_{W_{n}}(x) = T_{W_{0}}(c_{1},\ldots,c_{n},x).

4. Эксперименты

Чтобы проверить Теорему 2.2 на практике, мы рассматриваем корректно поставленную задачу обучения класса функций по примерам из контекста (in-context). Эта задача независимо изучалась в [6, 22]. В этих работах показано, что трансформер можно обучить с нуля выполнять in-context-обучение линейных функций. Иными словами, если модель трансформера была обучена на классе линейных функций, то после обучения она способна по одним лишь примерам в промпте выучивать новые, ранее невиданные линейные функции (выбранные из распределения, близкого к использованному при обучении) с качеством, сопоставимым с оптимальным оценивателем наименьших квадратов.

В [6, 22] авторы сосредотачивались на том, насколько трансформеры устойчивы (или, напротив, неустойчивы) к сдвигам распределения между обучающими данными модели и промптами на этапе инференса. Это не наша цель. Поскольку те работы уже подтвердили, что трансформеры действительно умеют учиться в контексте для линейных моделей, мы используем здесь схожий экспериментальный протокол, чтобы проверить, что контекстные промпты можно эффективно «перенести» в обновление весов по формуле (8). Мы проверяем, что предсказание обученной модели при наличии in-context-промпта идентично предсказанию модели с весами MLP, модифицированными согласно формуле (8), но без доступа к самому in-context-промпту.

4.1 Постановка эксперимента (Setup)

На высоком уровне, по аналогии с [6], мы обучаем простой трансформер на примерах промптов, состоящих из пар вход–выход вида (x_{1},\, h(x_{1}),\, \ldots,\, x_{N},\, h(x_{N}),\, x_{\text{query}}) , где x_i,x_{query}​ сэмплируются независимо и одинаково распределёнными i.i.d. из распределения D_x​, а функция h независимо сэмплируется из распределения над функциями в классе \mathcal{H}.В частности, мы берём \mathcal{H} как класс линейных функций, так что h(x) = \langle w, x \rangle, причём x_i,\; x_{\text{query}},\; w \sim \mathcal{N}(0, I_d).Цель обучающегося в контексте (in-context learner) — по промпту из таких пар предсказать \hat{y}\!\left(x_{\text{query}}\right) так, чтобы \hat{y}\!\left(x_{\text{query}}\right) \approx h\!\left(x_{\text{query}}\right).

Каждый обучающий промпт индексируется задачей τ∈N и имеет вид:

P_{\mathcal{T}} = \bigl(x_{\mathcal{T},1},\, h_{\mathcal{T}}(x_{\mathcal{T},1}),\, \ldots,\, x_{\mathcal{T},N},\, h_{\mathcal{T}}(x_{\mathcal{T},N}),\, x_{\mathcal{T},\text{query}}\bigr).

Мы можем записать такой промпт как матрицу-встраивание E_{\mathcal{T}}​, так что

E_{\mathcal{T}} := \begin{pmatrix} x_{\mathcal{T},1} & x_{\mathcal{T},2} & \cdots & x_{\mathcal{T},N} & x_{\mathcal{T},\mathrm{query}} \\ \langle w_{\mathcal{T}}, x_{\mathcal{T},1} \rangle & \langle w_{\mathcal{T}}, x_{\mathcal{T},2} \rangle & \cdots & \langle w_{\mathcal{T}}, x_{\mathcal{T},N} \rangle & 0 \end{pmatrix} \in \mathbb{R}^{(d+1)\times (N+1)}.

В нотации раздела 2 на E_{\mathcal{T}}​ удобно смотреть как на контекстуализированный входной промпт, где

C = [c_{1},\ldots,c_{N}] = \begin{pmatrix} x_{\mathcal{T},1} & x_{\mathcal{T},2} & \cdots & x_{\mathcal{T},N} \\ \langle w_{\mathcal{T}}, x_{\mathcal{T},1} \rangle & \langle w_{\mathcal{T}}, x_{\mathcal{T},2} \rangle & \cdots & \langle w_{\mathcal{T}}, x_{\mathcal{T},N} \rangle \end{pmatrix}\text{and}\quad x = \begin{pmatrix} x_{\mathcal{T},\mathrm{query}}\\[2pt] 0 \end{pmatrix}

так что E_{\mathcal{T}}=(C,x). Пусть θ — параметры модели. Предсказание модели \hat{y}\!\left(x_{\mathcal{T},\mathrm{query}}\right) для токена запроса x_{\mathcal{T},\mathrm{query}}, query​ — это последний компонент выхода по токену запроса у одного блока2 трансформера, то есть

\hat{y}\!\left(x_{\mathcal{T},\,\mathrm{query}}\right) = T_{W}(C,x)_{(d+1)} \tag{21}

Заметим, что при таком определении размерности T_W(C,x) и T_{W+\Delta W}(x) совпадают. Мы обучаем трансформер по лоссу на батче размера B,

\hat{\mathcal{L}}(\theta) = \frac{1}{2B}\sum_{\tau=1}^{B} \left(\hat{y}_{\mathcal{T},\mathrm{query}}       - \big\langle w_{\mathcal{T}},\, x_{\mathcal{T},\mathrm{query}}\big\rangle \right)^{2}.

4.2 Проверка теоремы 2.2

Пусть трансформер обучен на линейных функциях. Мы показываем, что in-context-промпт можно «перенести» в обновление весов, определённое формулой (8). А именно, хотим показать, что

T_{W}(C,x) = T_{\,W+\Delta W}(x);

или, что то же самое,

T_{W}\!\left( \begin{pmatrix} x_{\mathcal{T},1} & x_{\mathcal{T},2} & \cdots & x_{\mathcal{T},N} & x_{\mathcal{T},\mathrm{query}} \\ \langle w_{\mathcal{T}}, x_{\mathcal{T},1}\rangle & \langle w_{\mathcal{T}}, x_{\mathcal{T},2}\rangle & \cdots & \langle w_{\mathcal{T}}, x_{\mathcal{T},N}\rangle & 0 \end{pmatrix} \right) = T_{\,W+\Delta W}\!\left( \begin{pmatrix} x_{\mathcal{T},\mathrm{query}}\\[2pt] 0 \end{pmatrix} \right)

где ΔW вычисляется по формуле (8). На рис. 1 сравниваются значения валидационного лосса при предсказании с использованием in-context-промпта и при предсказании с эквивалентным обновлением весов. Лоссы для обеих настроек приведены по эпохам; также показан увеличенный фрагмент графика для ясности.

4.3 Сходимость ΔW

Мы ставим эксперименты, чтобы понять, как адаптируются веса по мере того, как модель обрабатывает in-context-промпт в рамках неявной (имплицитной) динамики обучения, описанной в Предложении 3.1. В частности, мы хотим проверить, что по мере достижения сходимости по контексту градиентные обновления стремятся к нулю.

Мы строим последовательность \bigl\{\, (\Delta W)_i \,\bigr\}_{i=1}^{N} где каждое (ΔW)_i задано формулами (10)–(13). То есть выполняется

T_{W}(C_i, x) = T_{\,W+(\Delta W)_i}(x)

где

C_i = [c_{1},\ldots,c_{i}] = \begin{pmatrix} x_{\mathcal{T},1} & \cdots & x_{\mathcal{T},i} \\ \langle w_{\mathcal{T}}, x_{\mathcal{T},1} \rangle & \cdots & \langle w_{\mathcal{T}}, x_{\mathcal{T},i} \rangle \end{pmatrix} \quad\text{and}\quad x = \begin{pmatrix} x_{\mathcal{T},\mathrm{query}}\\[2pt] 0 \end{pmatrix}.

Рисунок 1: кривые лосса на обучении и валидации. Здесь «Validation loss (computed via ΔW)» означает лосс, вычисленный с использованием T_{W+\Delta W}(x); т. е. предсказание обученной модели при подаче только x_{query}, но с весами MLP, модифицированными на ΔW согласно уравнению (8).
Слева: кривая тренировочного лосса и обе кривые валидационного лосса.
Справа: увеличенный фрагмент валидационного лосса, вычисленного обоими способами, т. е. через T_W(C,x) и через T_{W+\Delta W}(x).

Если W_0​ — это выученные веса первого полносвязного слоя, то из Следствия 2.3.1 следует, что для любого i=1,2,…, N

(\Delta W)_i = \frac{(W_{0}\,\Delta A_i)\,A(x)^{T}}{\lVert A(x)\rVert^{2}}, \qquad \text{where }\; \Delta A_i :=A(c_{1},\ldots,c_{i},x) - A(x).

Интуитивно ожидается, что по мере того как «in-context-learner» обрабатывает всё большую часть промпта, относительное изменение в (ΔW)_i​ должно уменьшаться. На Рисунке 2 мы подтверждаем, что это действительно так.

Для заданного контекста C_i=[c_1,…,c_i] длины i мы строим маржинальное изменение в (ΔW)_i при добавлении ещё одного токена контекста c_{i+1}​, что даёт (ΔW)_{i+1​} для контекста C_{i+1}=[c_1,…,c_i,c_{i+1}]. Это маржинальное изменение измеряется в L2-норме; т. е. для каждой длины контекста i по оси y откладывается величина, соответствующая обновлениям градиента из Предложения 3.1:

\left\| \nabla_{W} L_i(W) \right\|_{2} = \left\|\, (\Delta W)_{i+1} - (\Delta W)_i \,\right\|_{2}.

Мы наблюдаем на Рисунке 2, что градиентные обновления убывают и исчезают по мере того, как неявная динамика обучения продвигается к полному контексту, что соответствует сходящемуся процессу градиентного спуска.

4.4 Сравнение с fine-tuning

Мы предобучаем модель-трансформер (один стандартный блок трансформера без MLP skip-connection) на примерах вида

E_{\mathcal{T}} := \begin{pmatrix} x_{\mathcal{T},1} & x_{\mathcal{T},2} & \cdots & x_{\mathcal{T},N} & x_{\mathcal{T},\mathrm{query}} \\ \langle w_{\mathcal{T}}, x_{\mathcal{T},1} \rangle & \langle w_{\mathcal{T}}, x_{\mathcal{T},2} \rangle & \cdots & \langle w_{\mathcal{T}}, x_{\mathcal{T},N} \rangle & 0 \end{pmatrix} \in \mathbb{R}^{(d+1)\times (N+1)} .

Здесь берём d=2 и N=50.

Для fine-tuning мы создаём один новый тестовый пример с использованием ω_{test}, который модель не видела на предобучении, хотя ω_{test}​ сэмплируется из того же распределения, из которого при предобучении берутся все ω_\tau​. Обозначим этот пример \mathcal{D}_{FT}​:

\mathcal{D}_{FT} = \begin{pmatrix} x_{1} & \cdots & x_{M} & x_{\mathrm{test}} \\ \langle \omega_{\mathrm{test}}, x_{1} \rangle & \cdots & \langle \omega_{\mathrm{test}}, x_{M} \rangle & 0 \end{pmatrix}

Теперь для каждого i=1,2,…,M формируем датасет для fine-tuning, беря первые i элементов из \mathcal{D}_{FT}​ и игнорируя последний столбец, который является нашим тестовым query. То есть, для всех i=1,…,M:

\mathcal{D}^{\,i}_{FT} = \begin{pmatrix} x_{1} & x_{2} & \cdots & x_{i} \\ \langle \omega_{\mathrm{test}}, x_{1} \rangle & \langle \omega_{\mathrm{test}}, x_{2} \rangle & \cdots & \langle \omega_{\mathrm{test}}, x_{i} \rangle \end{pmatrix}

Мы инициализируем трансформер предобученными весами, затем выполняем fine-tuning с помощью стохастического градиентного спуска (learning rate 0.01), подавая по одному примеру за раз в том же порядке, в каком они обрабатываются in-context. Во время fine-tuning мы обновляем только матрицу весов слоя MLP. Поэтому для каждого i=1,…,M мы выполняем i шагов градиентного спуска с размером батча 1. После fine-tuning на всех M примерах вычисляем loss (функцию потерь) fine-tuned-модели на тестовом query (x_{test},0). Это мы называем «GD test loss после i шагов».

Рисунок 2: Сходимость (ΔW)_i. По мере обработки всё большей части контекста относительное изменение весов W стремится к нулю. Для длины контекста i>2 график выше показывает среднюю разность \left\|\, (\Delta W)_{i+1} - (\Delta W)_i \,\right\|_{2}​ и стандартную ошибку по 100 независимым прогонам.

Аналогично, для каждого i мы считаем перенос весов (weight transfer), как определено в уравнении (8), с контекстом

C_i = \begin{pmatrix} x_{1} & x_{2} & \cdots & x_{i} \\ \langle \omega_{\mathrm{test}}, x_{1} \rangle & \langle \omega_{\mathrm{test}}, x_{2} \rangle & \cdots & \langle \omega_{\mathrm{test}}, x_{i} \rangle \end{pmatrix}

и тем же тестовым запросом x=(x_{test},0). Используя ΔW из формулы переноса весов, вычисляем loss на (x_{test},0). Это мы называем «ΔW test loss» для длины контекста i.

На рисунке 3 ниже мы строим график зависимости fine-tuning GD test loss от ΔW weight-transfer test loss. На графике показано среднее по 100 независимым прогонам. Несмотря на различия, видно, что оба процесса обучения (fine-tuning и неявная динамика обновления весов) уменьшают loss сходным образом.

5. Заключение и ограничения

Наш подход к механике трансформерного блока, лежащей в основе ICL, улучшает предыдущие работы тем, что не накладывает ограничений на архитектуру слоя self-attention для извлечения неявной динамики обучения в весовом пространстве. Ранние теоретические работы, сосредоточенные на внутреннем устройстве трансформеров, выводили подобную неявную динамику, но лишь при жёстких допущениях о слое self-attention (например, линейное внимание и/или одна «голова»; см. [9–11, 19]). Фактически наши результаты остаются верными, даже если слой self-attention заменить другими формами контекстных слоёв — например, слоем RNN или любым слоем, который может принимать вход и, опционально, контекст. Это неожиданно, потому что наш анализ подсказывает: ICL в меньшей степени связан с внутренностями self-attention и в большей — с тем, что обычные нейросети способны переносить модификации в пространстве входов в структуру своих весов. Это глубокое свойство отмечалось в ряде теоретических работ и помогает понять, почему глубокие нейросети так хорошо обобщают [23–25].

Однако, хотя наш подход ближе к реальности, так как мы убираем ограничения на слой self-attention, мы всё ещё анализируем упрощённую модель в следующем смысле, что и составляет основное ограничение нашего анализа:

Рисунок 3:

  • Наше выводимое утверждение действительно только для одного трансформер-блока, так как основная теорема количественно оценивает влияние контекста только на выход самого последнего входного токена, а не на полный выход всего блока трансформера

  • Наша основная теорема анализирует влияние контекста только относительно первого сгенерированного токена. Она не охватывает полную механику генерации за пределами этого шага.

Несмотря на эти ограничения, мы надеемся, что эта работа поможет лучше понять таинственные явления, возникающие во время инференса у LLM.

A. Контекстные блоки со skip-соединениями

Рассмотрим теперь случай контекстных блоков со skip-соединениями, охватывающий стандартный Pre-LN-блок трансформера, как, например, описано в [26].

Определение A.1. Контекстный блок со skip-соединением — это слой вида

T(C,x) = x + A(C,x) + W'\, g_{\theta}\!\big(W\,A(C,x) + b\big) + b' \tag{22}

где g_θ — произвольная дифференцируемая модель, а A(C,x) — контекстный слой.

Мы можем обобщить Теорему 2.2 на этот случай, разрешив обновлять не только матрицу весов первого слоя W, но и сдвиг b' последнего слоя.

Теорема A.2. Рассмотрим контекстный блок со skip-соединением, как выше, т. е.

T(C,x) = x + A(C,x) + W'\, g_{\theta}\!\big(WA(C,x) + b\big) + b' \tag{23}

Пусть A(C,x) — контекстный слой, а g_θ(z) — любая дифференцируемая модель. Тогда влияние части Y⊂C контекста C на выход контекстного блока неявно соответствует обновлению весов ранга 1 для матрицы первого слоя W(Y)=W+ΔW(Y), а также обновлению смещения последнего слоя b'(Y) = b' + \Delta b'(Y) так, что

T_{W,\,b'}(C,x) = T_{\,W(Y),\,b'(Y)}\!\bigl(C\setminus Y,\, x\bigr), \tag{24}

Обновления параметров задаются формулами

\Delta b'(Y) = \Delta A(Y), \tag{25}\Delta W(Y) = \frac{(W\,\Delta A(Y))\,A(C\setminus Y, x)^{T}}        {\lVert A(C\setminus Y, x)\rVert^{2}}, \tag{26}

где \Delta A(Y) = A(C,x) - A(C\setminus Y, x) — контекстный вектор, соответствующий Y. Заметим, что ΔW(Y) имеет ранг 1, поскольку WΔA(Y) — столбец, а A(C\setminus Y, x)^{T} — строка.

Доказательство. Результат следует из прямого вычисления. В принятой нотации по определению

T_{W(Y),\,b'(Y)}\!\bigl(C\setminus Y, x\bigr) = x + A(C\setminus Y, x)   + W'\, g_{\theta}\!\left(\,(W+\Delta W(Y))\,A(C\setminus Y, x) + b\right)   + b' + \Delta b'(Y)=\, x + A(C\setminus Y, x) + \Delta b'(Y)+\, W' \, g_{\theta}\!\big( W\,A(C\setminus Y, x) + \Delta W(Y)\,A(C\setminus Y, x) + b \big) + b'

Подставляя теперь ΔW(Y) из определения и используя, что \frac{z^{T}}{\lVert z\rVert^{2}}, z = 1, получаем

\Delta W(Y)\,A(C\setminus Y, x) = \frac{(W\,\Delta A(Y))\,A(C\setminus Y, x)^{T}}{\lVert A(C\setminus Y, x)\rVert^{2}}\,A(C\setminus Y, x) = W\,\Delta A(Y).

Следовательно, получаем

T_{W(Y),\,b'(Y)}\!\bigl(C\setminus Y, x\bigr) = x + A(C\setminus Y, x) + \Delta A(Y)   + W'\, g_{\theta}\!\left( W\big(A(C\setminus Y, x)+\Delta A(Y)\big) + b \right)   + b'

Поскольку по определению контекстного вектора A(C\setminus Y, x) + \Delta A(Y) = A(C, x), в итоге имеем

T_{W(Y),\,b'(Y)}\!\bigl(C\setminus Y, x\bigr) = x + A(C,x) + W'\, g_{\theta}\!\big(WA(C,x) + b\big) + b' = T_{W,\,b'}(C,x)

чем доказательство завершается.

Заметим, что обновление вектора смещения \Delta b'(Y) по духу напоминает векторы функций из [27], выходы транскодера из [28] или латентные представления концептов из [29], используемые для редактирования весов трансформера. Также отметим, что данная теорема применима не только к контекстным слоям вида Pre-LN-блоков трансформера, как в [26], но и к другим типам контекстных слоёв — например, к слоям в рекуррентных моделях Griffin с локальным вниманием [21].

B. Альтернативная неявная динамика обучения ICL

В этом разделе мы описываем альтернативную неявную динамику обучения, получающуюся при итеративном применении Теоремы 2.2. Она выявляет неявную динамику обновления весов, порождённую вкладом каждого контекстного токена в выход контекстного блока. Это означает, что пока блок трансформера генерирует первый ответный токен, явного обновления весов не выполняется, однако фактический выход эквивалентен выходу того же контекстного блока без контекста, для которого в весовом пространстве произошла неявная динамика обучения. Ниже мы описываем эту динамику. А именно, начиная с начальной матрицы весов W_0​ первого полносвязного слоя нейросети M_{W_0}​​:

T_{W_{0}}(c_{1},\ldots,c_{n},x) = T_{\,W_{0}+\Delta W_{0}(c_{1})}(c_{2},\ldots,c_{n},x) \tag{27}

что даёт первое обновление весов, соответствующее эффекту токена c_1​ на матрицу весов первого слоя:

W_{1} = W_{0} + \frac{\bigl(W_{0}\,\Delta A(c_{1})\bigr)\,A(c_{2},\ldots,c_{n},x)^{T}}        {\lVert A(c_{2},\ldots,c_{n},x)\rVert^{2}} \tag{28}

Если продолжать этот процесс итеративно, получаем следующее обновление весов, соответствующее «поглощению» второго токена:

T_{W_{1}}(c_{2},\ldots,c_{n},x) = T_{\,W_{1}+\Delta W_{1}(c_{2})}(c_{3},\ldots,c_{n},x) \tag{29}

откуда

W_{2} = W_{1} + \frac{\bigl(W_{1}\,\Delta A(c_{2})\bigr)\,A(c_{3},\ldots,c_{n},x)^{T}}        {\lVert A(c_{3},\ldots,c_{n},x)\rVert^{2}} \tag{30}

Итак, итеративный процесс неявных обновлений весов для каждого следующего токена можно суммировать так.

Следствие B.0.1. В использованной выше нотации, итеративный процесс обновлений весов имеет вид

W_{i} = W_{i-1} + \frac{\bigl(W_{i-1}\,\Delta A(c_{i})\bigr)\,A(c_{i+1},\ldots,c_{n},x)^{T}}        {\lVert A(c_{i+1},\ldots,c_{n},x)\rVert^{2}} \tag{31}

где начальные веса первого полносвязного слоя W_0​ моделируют перенос информации из токена промпта c_i​ в веса контекстного блока. Иначе говоря, выполняется следующее

T_{W_{i}}(c_{i+1},\ldots,c_{n},x) = T_{W_{0}}(c_{1},\ldots,c_{n},x), \tag{32}

для всех i=1,…,n где \Delta A(c_{i}) = A(c_{i},\ldots,c_{n},x)   - A(c_{i+1},\ldots,c_{n},x).

Заметим, что ΔA(c_i) измеряет вклад токена контекста c_i в выход контекстного блока. Если c_i не влияет на выход (то есть ΔA(c_i)=0), соответствующее обновление исчезает. Обратите внимание, что обновление весов на шаге i линейно по весам; а именно, его можно переписать как

W_i = W_{i-1} + h_i\, W_{i-1} A_i      = W_{i-1}\bigl(1 + h_i A_i\bigr) \quad\text{where}\quad A_i := \Delta A(c_i)\, A(c_{i+1},\ldots,c_n,x)^{T}. \tag{33}

с адаптивной скоростью обучения

h_i:=\frac{1}{\lVert A(c_{i+1},\ldots,c_n,x)\rVert^{2}} \tag{34}

В частности, это даёт формулу факторизации для полной неявной матрицы весов, соответствующей эффекту контекста [c_1,…,c_n] на входной токен x:

W_{n} = W_{0}\,(1 + h_{1}A_{1})(1 + h_{2}A_{2})\cdots(1 + h_{n}A_{n}). \tag{35}

Комментарии (0)