Введение
Привет, Habr! Меня зовут Андрей Атаманюк, я Data Scientist в R&D команде рекомендательных систем Wildberries & Russ. В этой статье я разберу тонкости обучения двухбашенных моделей (без специфики к домену рекомендаций), которые могут существенно влиять на качество рекомендаций, но часто остаются за кадром. Речь пойдёт о систематическом росте норм эмбеддингов популярных товаров — эффекте, который противоречит интуитивным ожиданиям от косинусных лоссов.

Предыстория
Всё началось, когда я обучал DSSM-модель для рекомендаций товаров маркетплейса пользователям с использованием стандартного подхода:
Двухбашенная архитектура с раздельными энкодерами
Косинусный лосс (InfoNCE, ещё известен как Full Product Softmax loss) для обучения с in-batch негативами
Ожидание: косинусный лосс должен нивелировать влияние популярности, потому что косинус зависит только от угла между айтемами и, как я думал раньше, обучение на него не должно «мотивировать» эмбеддинги «впитывать» популярность в их норму.
Но практика преподнесла сюрприз: при анализе эмбеддингов популярных товаров их нормы систематически росли, в то время как нишевые товары демонстрировали стагнацию. Причем эффект усиливался с ростом частоты встречаемости товара в данных. При составлении рекомендаций я пробовал использовать и косинус, и скалярное произведение между эмбеддингами юзера и айтема — при использовании скалярного произведения возникали систематические искажения. О них можно прочитать ниже.
Почему это критично для индустрии?
Ранжирование: большинство систем используют скалярное произведение для оценки релевантности:
(
— эмбеддинг пользователя,
— эмбеддинг товара)
-
Систематическое искажение: рост
для популярных товаров искусственно завышает их релевантность:
# Псевдокод ранжирования scores = [] for item in candidate_items: score = dot(user_embedding, item_embedding) # Зависит от нормы! scores.append(score) top_items = sort(scores)[:10]
Эффект «богатые богатеют»: хиты получают ещё больше показов → нормы растут → цикл усиливается
Подавление long-tail: нишевые товары систематически недопредставляются
Что вы найдёте в статье
Статья объяснит:
Почему популярные работы (NormFace и др.) упускают ключевые факторы
Их доказательства не учитывают архитектуру энкодеровКак архитектура модели влияет на динамику эмбеддингов
На примере 9 экспериментов покажу, когда возникает ортогональность движения эмбеддингов и когда — корреляция норм эмбеддингов айтемов с популярностью айтемов (часто буду называть это явление popularity bias)Аналитический вывод зависимости нормы эмбеддинга объекта от его популярности
Разнообразные эксперименты усиливают уверенность в полученных выводах
Ценность для практиков:
Сэкономите месяцы на отладке «необъяснимых» артефактов обучения
Получите воспроизводимые примеры на PyTorch для собственных экспериментов
О чем статья?
В результате исследований процессов обучения двухбашенных моделей (в моем случае была задача коллаборативной фильтрации) я определил причины, из-за которых у популярных объектов наблюдается систематическое увеличение нормы эмбеддингов в процессе обучения с использованием функции потерь, использующей косинусную меру близости. У этого эффекта есть концептуальное объяснение, выходящее из анализа формул обратного распространения ошибки. Этот анализ объясняет, при выполнении каких факторов будет наблюдаться такое влияние популярности объектов на результат процесса обучения.
Это исследование уточняет границы применимости результатов некоторых популярных статей: NormFace: L2 Hypersphere Embedding for Face Verification (930 цитирований), The Hidden Pitfalls of the Cosine Similarity Loss (5 цитирований) и On the Importance of Embedding Norms in Self-Supervised Learning (2 цитирования). Я аналитически вывожу условия, при которых в процессе обучения моделей эмбеддинги объектов движутся ортогонально и при которых проявляется эффект popularity bias — корреляция норм эмбеддингов айтемов с популярностью айтемов.
В конце статьи я покажу результаты экспериментов, подтверждающие сказанное, а также там будут воспроизводимые примеры кода соответствующих экспериментов.
Одна формула, на которой строится весь анализ
Формулы обратного распространения ошибки
Определим общий вид модели, процесс обучения которой нас интересует: это двухбашенная нейронная сеть, у которой есть левая башня (левый энкодер) и правая башня (правый энкодер), которые разделены и не имеют общих параметров. Очень часто решают задачу выучивания эмбеддингов именно с помощью такой модели: выход левой башни приближается или отдаляется относительно выхода правой башни (например, используя функцию потерь, в которой мы минимизируем/максимизируем скалярное произведение эмбеддинга юзера и эмбеддинга купленного товара — выходов левого и правого энкодеров).
Самый простой пример такой модели — Alternating Least Squares, где левый и правый энкодер являются просто матрицами юзерных и айтемных эмбеддингов. Мы же начнём рассматривать общую теорию для энкодеров любой архитектуры. По ходу повествования будем вводить дополнительные ограничения на архитектуру — в тех местах, где дальнейшая математика невозможна без таких ограничений.
Такие модели обучаются методом обратного распространения ошибки. Согласно общей теории, как работает этот метод (раздел 6.5.2 в книге Deep Learning Textbook), для конкретно взятого энкодера (неважно, левого или правого) мы можем посмотреть, как там считаются формулы для обновления его параметров:
Отсюда понятно, как выглядит формула градиента лосса по параметрам энкодера:
То есть в результате SGD-шага на этом батче мы получим следующее изменение параметров энкодера:
Мы описали, как совокупность всех примеров в батче влияет на изменение параметров энкодера. Теперь посмотрим, как это изменение повлияет на выход энкодера для конкретного -го примера в батче. Выразим это в виде линейной аппроксимации первого порядка:
Именно эта формула является отправной точкой дальнейшего анализа.
Первый важный вывод про процесс изменения эмбеддингов
Исследуем popularity bias только для энкодеров, линейных по параметрам
Для начала заметим, что формула выше, описывающая изменение эмбеддинга в процессе обучения, является линейной аппроксимацией. То есть получить строгий аналитический вывод, при выполнении каких факторов будет наблюдаться popularity bias (в данном контексте под этим термином я подразумеваю корреляцию норм эмбеддингов с популярностью айтемов), мы можем только для линейных по параметрам архитектур энкодеров, потому что именно для них эта линейная аппроксимация совпадает с точным изменением эмбеддинга.
Что такое энкодер, линейный по параметрам
Это энкодер, у которого Якобиан выхода J не зависит от θ
Для ясности определений: назовём как
— выход энкодера c параметрами
для
-го объекта,
— как
.
Мы будем считать энкодер линейным по параметрам, если его выход для любого входного примера можно записать как:
где матрица зависит только от входа (и, возможно, от фиксированных гиперпараметров), но не зависит от параметров
.
При таком определении остается константой на всем пространстве параметров. То есть формула
выше является уже не просто линейной аппроксимацией, а точной формулой изменения
. На её основе можно строить дальнейший анализ.
Какие архитектуры можно считать линейными по параметрам?
Рассмотрим примеры архитектур, в которых J не зависит от параметров энкодера θ:
1. Энкодер состоит из единственного embedding-слоя
В этом случае в везде, кроме столбцов <
-ый эмбеддинг,
-ая координата>, стоят нули, а в блоке из этих столбцов — единичная матрица:
2. Энкодер состоит из единственного линейного слоя (без bias)
В этом случае в в каждом из
блоков (ширины
) на строке
стоит копия вектора
; все остальные элементы нулевые:
Если присутствует bias, то к просто приклеивается справа единичная матрица.
Пример архитектуры, которая не является линейной по параметрам
Энкодер состоит из двух последовательно идущих линейных слоёв (без bias)
Посмотрим на энкодер, состоящий из двух последовательно идущих линейных слоёв (без bias):
Выходит, что два линейных слоя — это нелинейная по параметрам модель.
Это линейная по входу модель (linear_layer_№1 * linear_layer_№2 * x = linear_layer_№3 * x). Но если держать в голове то, о чем нам говорит Якобиан (насколько сдвинется выход, если немного подвинуть параметр), то несложно догадаться, что если мы подвинем вес в первом линейном слое, то на выходе это скажется как-то с учетом весов второго линейного слоя: между первым слоем и выходом как раз находится домножение на эту матрицу второго линейного слоя. А раз матрица зависит от параметров, значит, что такая архитектура не является линейной по параметрам и мы не можем получить строгий аналитический вывод про popularity bias.
Для некоторых линейных архитектур изменение эмбеддинга коллинеарно градиенту лосса по выходу
Для этого входы должны быть ортогональны
Вспомним формулу изменения эмбеддинга для энкодеров, линейных по параметрам:
После изучения, как именно выглядят Якобианы для некоторых примеров архитектур, становится понятно их свойство, при котором эмбеддинг будет двигаться коллинеарно его градиенту ( ||
): важно, чтобы в модели у каждого уникального входа была «собственная» строка параметров, которая не пересекается c чужими. Это просто понять, если посмотреть на примеры.
Рассмотрим случай с единственным embedding-слоем:
Шаг SGD:
Теперь случай с линейным слоем без bias:
Шаг SGD:
Из-за того, что разные входы в энкодер проходят через одни и те же параметры, выход энкодера уже не меняется строго коллинеарно его градиенту. Но если все сделать ортогональными друг другу (например, если
— это one-hot эмбеддинг), то вклад остальных объектов обнулится и будет коллинеарность. В сущности, nn.Linear над one-hot фичами (без bias в слое) и nn.Embedding — это одно и то же, если не брать в расчет производительность (matmul vs lookup).
Градиент любого косинусного лосса по выходу энкодера ортогонален ему
Если эмбеддинг объекта фигурирует в лоссе только через косинусы, то градиент лосса ортогонален этому эмбеддингу
Пусть эмбеддинг , выходящий из энкодера, участвует в лоссе только через косинусные близости с другими эмбеддингами, порожденными двухбашенной моделью. Тогда лосс можно записать так:
Тут производная лосса по косинусу — это какой-то скаляр. Мы сможем сказать, что градиент лосса по ортогонален
, если:
Запишем:
То есть
для любого косинусного лосса. Это значит, что если
||
, то
— в процессе обучения эмбеддинг движется ортогонально:

Резюме: факторы, при которых эмбеддинг движется ортогонально
Если для энкодера в двухбашенной модели выполняется четыре фактора:
в качестве оптимайзера используется простой SGD без моментума и регуляризации
это линейный (по параметрам) энкодер
у каждого уникального входа в энкодер «собственная» строка параметров, которая не пересекается c чужими
градиент лосса по выходу энкодера (эмбеддингу) ортогонален выходу эмбеддинга
То получается аналитически доказать, что в процессе обучения эмбеддинг движется ортогонально, т.е. по касательной к гиперсфере, на которой он лежал на предыдущем шаге.
Связь с популярными статьями про ортогональность движения эмбеддингов
В статье NormFace: L2 Hypersphere Embedding for Face Verification (930 цитирований) предлагают новый способ учить глубокие энкодеры. В разделе 3.2 они доказывают, что для их выхода сети и градиента их рассматриваемого лосса
выполняется ортогональность, после чего сразу же делают логический переход "it can be inferred that after update,
always increases", подразумевая то, что раз градиент по выходу сети ортогонален, то и сдвинется выход тоже ортогонально. Но мне остался непонятным этот переход: от ортогональности градиента к ортогональности самого обновления.
Такая ортогональность шага гарантируется лишь при выполнении набора достаточно строгих условий (см. выше). Мы эти условия выписываем явно и показываем: при их соблюдении эмбеддинги действительно движутся ортогонально; вне этой области переход от ортогональности градиента к ортогональности движения эмбеддингов (и, соответственно, к непрерывно возрастающей норме эмбеддингов) может не выполняться. Следовательно, мотивацию предложенного изменения корректно трактовать как условную: она верна при выполнении описанных нами условий, но не является универсальным основанием для модификации способа обучения энкодеров — тем более глубоких.
Но при выполнении условий, аналитически доказывающих ортогональность движения эмбеддингов, действительно возникает рост нормы эмбеддингов (по теореме Пифагора). Возникает вопрос: как его оценить аналитически?
Анализ изменения нормы эмбеддинга объекта от популярности объекта
В этом разделе будем полагать, что выполняются вышеописанные 4 фактора и поэтому в процессе обучения эмбеддинг движется ортогонально. Тогда по теореме Пифагора:
Прежде чем анализировать изменение нормы эмбеддинга, я сделаю следующее утверждение и докажу его — оно понадобится далее:
Чем больше эмбеддинг, тем медленнее он растет при косинусном лоссе
Чем длиннее , тем меньше норма градиента косинусного лосса по
Почему это важно: получается сложный контр-эффект "чем больше эмбеддинг, тем медленнее он растет", и этот контр-эффект нужно будет учитывать в дальнейших формулах для анализа изменения популярных объектов.
Теперь докажем — начну с использования формулы из Приложения №1:
Посмотрим на эту : так как это матрица, то это линейное отображение. Но так как она идемпотентна (
) и симметрична (
), то она не только линейное отображение, но ещё и по определению является ортогональным проектором. Этот ортогональный проектор переводит любой вектор (на который она действует через умножение слева) в пространство, состоящее из всех векторов, перпендикулярных
.
Теперь посмотрим, какие выводы можно сделать для произвольного косинусного лосса:
Посмотрим, как можно выразить норму этого градиента:
Здесь видно, что с ростом нормы эмбеддинга норма градиента уменьшается:
Знаменатель — норма эмбеддинга
Числитель не зависит от нормы эмбеддинга
:
Когда мы считаем, то
фигурирует там только в расчете производной
по косинусу. Норма
меняется → косинус не меняется → вход функции
не изменился → производная
по косинусу не зависит от нормы
.
Наконец, сам анализ изменения нормы эмбеддинга объекта от популярности объекта:
Неотрицательная зависимость нормы эмбеддинга от популярности
Хотим показать, что величина в процессе обучения длиной в
батчей в среднем не убывает по
.
Воспользуемся методом связывания (coupling) из теории вероятностей (сборник "Modern Discrete Probability: An Essential Toolkit"):
Примечание о распределениях
Пусть исходное распределение по всем айтемам равно . На множестве «не-i» айтемов (
) зададим «остаточное» распределение
, которое получается перемасштабированием (
-ый айтем выкинули из распределения — остальные вероятности увеличились):
Рассмотрим один слот генерации. Используем общие и исход категориального распределения «не-i» айтемов
:
Тогда:
и для конкретных :
Примечание о различии батчей
В вышеописанных запусках собранные батчи отличаются на слотах, где выпало . Поэтому в ситуации, где
оказался в обоих собранных батчах, может случиться
— например, если запуск с
получил положительные и отрицательные взаимодействия с айтемом
, а запуск с
получил только положительные взаимодействия с айтемом
.
Однако просто события мало для того, чтобы порядок
сменился на
— для такого сильного перелома накопленной суммы приростов требуется
, что является маловероятным событием (см. Приложение №2, "Событие
является редким")
Вывод
Если для айтемной башни выполняются факторы:
в качестве оптимайзера используется простой SGD без моментума и регуляризации
эта башня — линейный (по параметрам) энкодер
у каждого уникального входа в энкодер «собственная» строка параметров, которая не пересекается c чужими
градиент лосса по выходу энкодера (эмбеддингу) ортогонален выходу эмбеддинга
достаточно высокий learning rate (ортогональность выполняется и без этого; это нужно, чтобы «перебить» начальную инициализацию эмбеддингов)
то эмбеддинги у популярных айтемов имеют систематически большую норму, чем у непопулярных айтемов.
Здесь видно, сколько требований должно выполняться для того, чтобы эмбеддинги двигались ортогонально, чтобы их норма монотонно росла и чтобы это приводило к popularity bias. Таким образом, наше исследование уточняет выводы вышеупомянутых статей — мы показываем, при выполнении каких факторов действительно выполняется ортогональность движения эмбеддингов. В литературе на тему обучения представлений с использованием косинусного лосса эти моменты ранее не освещались.
Подтверждение выводов в реальных экспериментах
Эксперименты, изучающие факторы, при которых проявляется ортогональность движения эмбеддингов
Определим:
Условие 1 — архитектура энкодера должна состоять из линейного слоя (но только над входными данными в формате onehot-векторов) или из эмбеддинг слоя
Условие 2 — градиент лосса по выходу энкодера должен быть ортогонален выходу энкодера.
№ |
Левый энкодер |
Лосс |
Условие 1? |
Условие 2? |
Траектория ортогональна? |
---|---|---|---|---|---|
1 |
Embedding |
InfoNCE (cos) |
✔️ |
✔️ |
✔️ |
2 |
Embedding |
InfoNCE (dot) |
✔️ |
❌ |
Юзер ❌ / Айтем ❌ |
3 |
Transformer |
InfoNCE (cos) |
❌ |
✔️ |
Юзер ❌ / Айтем ✔️ |
4 |
Embedding → Linear |
InfoNCE (cos) |
❌ |
✔️ |
Юзер ❌ / Айтем ✔️ |
5 |
Embedding (frozen) → Linear |
InfoNCE (cos) |
❌ |
✔️ |
Юзер ❌ / Айтем ✔️ |
6 |
one-hot → Linear |
InfoNCE (cos) |
✔️ |
✔️ |
✔️ |
7 |
Embedding |
✔️ |
✔️ |
✔️ |
Для начала уточню, что во всех экспериментах в айтемной башне (правом энкодере) использовался только nn.Embedding, а InfoNCE — это контрастивный софтмакс лосс с in-batch негативами, где cos или dot — это используемая операция в этом лоссе (мы оцениваем релевантность юзера к айтему или через косинусную близость, или через скалярное произведение). Везде используется SGD без моментума — при использовании других вариантов оптимайзеров не удалось обнаружить никакую зависимость в движении эмбеддингов. В эксперименте №4 архитектура совпадает с one-hot → Linear → Linear , т.е. это просто 2 линейных слоя. Архитектура из эксперимента №5 является показательным примером архитектуры, состоящей из единственного Linear слоя без доп. условия на ортогональность (one-hot) входов — тут просто какие-то числовые фичи на входе.
Что случится, если в модели будет нелинейность по параметрам:
Теперь вернёмся к высказыванию, что два линейных слоя — это нелинейная по параметрам модель: я утверждал, что из-за нелинейности нельзя получить аналитическое объяснение того, как там движутся эмбеддинги и проявляется ли popularity bias. Результат эксперимента №4 согласуется с этими рассуждениями — в отличие от экспериментов №1 и №6, здесь эмбеддинги из левого энкодера движутся хаотично.
Что случится, если в линейной по параметрам модели входы будут не ортогональны друг другу:
Вспомним, что линейности энкодера еще недостаточно для того, чтобы гарантировалось ортогональное движение эмбеддингов, и взглянем на результат эксперимента №5 — эмбеддинги из левого энкодера движутся хаотично. Это согласуется с моим высказыванием, что для ортогональности движения эмбеддингов требуется, чтобы у каждого уникального входа в энкодер «собственная» строка параметров, которая не пересекается c чужими. Я смог обнаружить только две (идентичные) архитектуры, при которых это выполняется — линейный слой над onehot-векторами (эксперимент №6) или эмбеддинг слой (эксперимент №1).
Что случится, если использовать какой-то другой лосс:
Результат эксперимента №7 не противоречит моему высказыванию «градиент любого косинусного лосса по выходу энкодера ортогонален ему».
Эти эксперименты позволили мне быть чуть увереннее в аналитических выводах про ортогональность движения эмбеддингов.
Эксперименты, проверяющие, действительно ли ортогональность движения эмбеддингов приводит к появлению popularity bias
№ |
Левый энкодер |
Лосс |
Условие 1? |
Условие 2? |
Траектория ортогональна? |
Достаточно большое ортогональное смещение? |
Popularity bias? |
---|---|---|---|---|---|---|---|
8 |
Embedding |
InfoNCE (cos) |
✔️ |
✔️ |
✔️ |
✔️ (высокая скорость обучения) |
✔️ |
9 |
Embedding |
InfoNCE (cos) |
✔️ |
✔️ |
✔️ |
❌ (низкая скорость обучения) |
❌ |
В экспериментах №8 и №9 я проверил, действительно ли выполнение этого свойства приводит к появлению popularity bias. Я взял настоящие данные о заказах, на которых учатся наши рекомендательные модели в проде, и запустил на них обучение двухбашенной нейронной сети, максимально приближенной к production версии рекомендательной модели DSSM.
Результат — да, действительно, если эмбеддинги двигаются ортогонально ввиду выполнения вышеописанных факторов, то они наращивают свою норму; таким образом, айтемы, чаще "получающие движение" (более популярные - чаще встречающиеся в датасете и, соответственно, чаще попадающие в обучающие батчи), имеют систематически большую норму - наблюдается значительная корреляция (коэффициент корреляции Пирсона 0.54) популярности айтемов и норм эмбеддингов. Но есть нюанс:
Получилось выявить важный фактор, выполнение которого необходимо для появления popularity bias при ортогональном движении эмбеддингов: эмбеддинги должны двигаться с достаточно большим ортогональным смещением, т.к. движение по касательной на малое расстояние почти не растит норму эмбеддинга (по теореме Пифагора) и этот незначительный рост не сможет "перебороть" изначальный рандом норм при инициализации эмбеддингов. В эксперименте №8 мне пришлось сильно увеличить Learning Rate у SGD оптимайзера и понизить температуру лосса, чтобы объекты двигались на бОльшие расстояния (при использовании более общепринятного LR = 0.1 и temperature = 0.1 объекты в процессе обучения двигались крайне медленно и вообще не переходили на гиперсферы большего радиуса - я наблюдал околонулевую корреляцию популярности айтемов и норм эмбеддингов в эксперименте №9).
Этот результат согласуется с анализом изменения нормы эмбеддинга айтема от популярности айтема.
Код экспериментов №1 и №4
Я выбрал эти эксперименты как самые показательные для описываемого явления; 1 эксперимент воспроизводит условия обучения, при которых эмбеддинги левого энкодера движутся ортогонально. 4 эксперимент нарушает эти условия (реализуя "общий" дизайн модели, используемый, например, в статье NormFace) и показывает, что если они нарушаются, то эмбеддинги движутся хаотично, а не ортогонально. Код позволяет мониторить динамику движения эмбеддингов через численные показатели.
Код эксперимента №1 (архитектура левого энкодера состоит только из Embedding)
import math
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import plotly.express as px
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from info_nce import InfoNCE # pip install info-nce-pytorch
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# игрушечные данные - коллаборативная модель выучит, что 5 юзер,
# 6 юзер и 7 юзер близки друг к другу из-за позитивных
# взаимодействий с одними и теми же айтемами (5i* и 6i*)
train = pd.DataFrame([], columns=["user_id", "item_id"])
train.loc[len(train)] = ["1u", "1i1"]
train.loc[len(train)] = ["1u", "1i2"]
train.loc[len(train)] = ["1u", "1i3"]
train.loc[len(train)] = ["2u", "2i1"]
train.loc[len(train)] = ["2u", "2i2"]
train.loc[len(train)] = ["2u", "2i3"]
train.loc[len(train)] = ["3u", "3i1"]
train.loc[len(train)] = ["3u", "3i2"]
train.loc[len(train)] = ["3u", "3i3"]
train.loc[len(train)] = ["4u", "4i1"]
train.loc[len(train)] = ["4u", "4i2"]
train.loc[len(train)] = ["4u", "4i3"]
train.loc[len(train)] = ["5u", "5i1"]
train.loc[len(train)] = ["5u", "5i2"]
train.loc[len(train)] = ["5u", "5i3"]
train.loc[len(train)] = ["6u", "6i1"]
train.loc[len(train)] = ["6u", "6i2"]
train.loc[len(train)] = ["6u", "6i3"]
train.loc[len(train)] = ["7u", "5i1"]
train.loc[len(train)] = ["7u", "5i2"]
train.loc[len(train)] = ["7u", "5i3"]
train.loc[len(train)] = ["7u", "5i4"]
train.loc[len(train)] = ["7u", "6i1"]
train.loc[len(train)] = ["7u", "6i2"]
train.loc[len(train)] = ["7u", "6i3"]
user_encoder = LabelEncoder().fit(train["user_id"])
item_encoder = LabelEncoder().fit(train["item_id"])
class TwoTowerDataset(Dataset):
def __init__(self,
positive_interactions_dataframe,
user_encoder,
item_encoder
):
self.positive_interactions_dataframe = \
positive_interactions_dataframe.copy()
self.positive_interactions_dataframe["encoded_user_id"] = \
user_encoder.transform(self.positive_interactions_dataframe.user_id)
self.positive_interactions_dataframe["encoded_item_id"] = \
item_encoder.transform(self.positive_interactions_dataframe.item_id)
self.user2features = {}
self.item2features = {}
self.interactions_features_list = []
for row in tqdm(self.positive_interactions_dataframe.itertuples()):
self.interactions_features_list.append(
(
(row.encoded_user_id,),
(row.encoded_item_id,)
)
)
if row.user_id not in self.user2features:
self.user2features[row.user_id] = (row.encoded_user_id,)
if row.item_id not in self.item2features:
self.item2features[row.item_id] = (row.encoded_item_id,)
self.start = 0
self.end = len(self.interactions_features_list)
def __len__(self):
return len(self.interactions_features_list)
def __getitem__(self, i):
return self.interactions_features_list[i]
train_dataset = TwoTowerDataset(train, user_encoder, item_encoder)
def collate_fn(batch):
users_features, items_features = zip(*batch)
return (
torch.IntTensor(np.array(users_features)),
torch.IntTensor(np.array(items_features)),
)
train_loader = DataLoader(
train_dataset,
batch_size=len(train_dataset),
collate_fn=collate_fn
)
class TwoTowerModel(nn.Module):
def __init__(self, user_embedding_sizes, item_embedding_sizes, device):
super(TwoTowerModel, self).__init__()
self.device = device
self.user_embeds = nn.Embedding(
user_embedding_sizes[0],
user_embedding_sizes[1]
)
self.item_embeds = nn.Embedding(
item_embedding_sizes[0],
item_embedding_sizes[1]
)
def get_user_embeddings(self, user_features):
"""Пользовательская часть TwoTower"""
user_embeddings = self.user_embeds(user_features[:, 0])
return user_embeddings
def get_item_embeddings(self, item_features):
"""Айтемная часть TwoTower"""
item_embeddings = self.item_embeds(item_features[:, 0])
return item_embeddings
def forward(self, user_features, item_features):
user_embs = self.get_user_embeddings(user_features.to(self.device))
item_embs = self.get_item_embeddings(item_features.to(self.device))
return user_embs, item_embs
infonceloss = InfoNCE(temperature=1)
def plot_embedding_shift(
old_users_embeddings,
old_items_embeddings,
new_users_embeddings,
new_items_embeddings,
users_ids,
items_ids,
title="Embedding Shift"
):
"""
Рисует сдвиг эмбеддингов с окружностью разных радиусов.
old_users_embeddings: Тензор старых эмбеддингов юзеров
old_items_embeddings: Тензор старых эмбеддингов айтемов
new_users_embeddings: Тензор новых эмбеддингов юзеров
new_items_embeddings: Тензор новых эмбеддингов айтемов
users_ids: Список id юзеров в порядке, соответствующем порядку
эмбеддингов юзеров
items_ids: Список id айтемов в порядке, соответствующем порядку
эмбеддингов айтемов
"""
old_users_embeddings = old_users_embeddings.cpu().detach().numpy()
old_items_embeddings = old_items_embeddings.cpu().detach().numpy()
new_users_embeddings = new_users_embeddings.cpu().detach().numpy()
new_items_embeddings = new_items_embeddings.cpu().detach().numpy()
# Вводим новую метрику: какая доля градиентов расположена
# перпендикулярно радиус-вектору
# для юзеров:
count_orthogonal_users = 0
for i in range(len(old_users_embeddings)):
grad = new_users_embeddings[i] - old_users_embeddings[i]
is_orthogonal = math.isclose(
grad @ old_users_embeddings[i],
0,
abs_tol=1e-6
)
if is_orthogonal:
count_orthogonal_users += 1
print(
"Доля юзеров, для которых градиент направлен строго по "
"касательной к гиперсфере, на которой изначально лежал "
"эмбеддинг этого объекта: ",
count_orthogonal_users / len(old_users_embeddings)
)
# для айтемов:
count_orthogonal_items = 0
for i in range(len(old_items_embeddings)):
grad = new_items_embeddings[i] - old_items_embeddings[i]
is_orthogonal = math.isclose(
grad @ old_items_embeddings[i],
0,
abs_tol=1e-6
)
if is_orthogonal:
count_orthogonal_items += 1
print(
"Доля айтемов, для которых градиент направлен строго по "
"касательной к гиперсфере, на которой изначально лежал "
"эмбеддинг этого объекта: ",
count_orthogonal_items / len(old_items_embeddings)
)
if old_users_embeddings.shape[1] == 2:
# Дальнейшая визуализация - только для двумерного случая:
old_embs = torch.cat(
(
torch.from_numpy(old_users_embeddings),
torch.from_numpy(old_items_embeddings)
),
0
).numpy()
new_embs = torch.cat(
(
torch.from_numpy(new_users_embeddings),
torch.from_numpy(new_items_embeddings)
),
0
).numpy()
ids=users_ids+items_ids
# Создаем DataFrame для интерактивного графика
df = pd.DataFrame({
"old_x": old_embs[:, 0],
"old_y": old_embs[:, 1],
"new_x": new_embs[:, 0],
"new_y": new_embs[:, 1],
"id": ids
})
# График со стрелками
fig, ax = plt.subplots(figsize=(20, 14))
# Рисуем стрелки
for i in range(len(df)):
ax.arrow(df["old_x"][i],
df["old_y"][i],
df["new_x"][i] - df["old_x"][i],
df["new_y"][i] - df["old_y"][i],
head_width=max(
np.sqrt((df["new_x"][i] - df["old_x"][i])**2 +
(df["new_y"][i] - df["old_y"][i])**2) * 0.2,
0.1),
head_length=max(
np.sqrt((df["new_x"][i] - df["old_x"][i])**2 +
(df["new_y"][i] - df["old_y"][i])**2) * 0.2,
0.1),
fc='blue',
ec='blue',
alpha=0.7)
# Точки эмбеддингов
ax.scatter(df["new_x"],
df["new_y"],
color='red',
label="New Positions",
s=30)
ax.scatter(df["old_x"],
df["old_y"],
color='gray',
alpha=0.5,
label="Old Positions",
s=30)
# Рисуем оси и разные окружности
line1 = plt.axline((1, 0), (0, 0),
color='black',
linestyle='dashed',
linewidth=2)
line2 = plt.axline((0, 1), (0, 0),
color='black',
linestyle='dashed',
linewidth=2)
circle1 = plt.Circle((0, 0),
1,
color='green',
fill=False,
linestyle='dashed',
linewidth=2)
ax.add_patch(circle1)
circle2 = plt.Circle((0, 0),
1.5,
color='green',
fill=False,
linestyle='dashed',
linewidth=2)
ax.add_patch(circle2)
circle3 = plt.Circle((0, 0),
2,
color='green',
fill=False,
linestyle='dashed',
linewidth=2)
ax.add_patch(circle3)
# Настройки осей и легенды
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.set_title(title)
ax.legend()
ax.set_aspect('equal')
plt.show()
# Интерактивная версия с plotly (отображение id при наведении)
fig_interactive = px.scatter(df,
x="new_x",
y="new_y",
text="id",
title="Embedding Positions with IDs",
hover_data={
"id": True,
"new_x": False,
"new_y": False
})
fig_interactive.update_layout(
autosize=False,
width=800,
height=800,
)
fig_interactive.show()
user_embedding_sizes = [len(user_encoder.classes_), 2]
item_embedding_sizes = [len(item_encoder.classes_), 2]
model = TwoTowerModel(user_embedding_sizes, item_embedding_sizes,
device).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=5, momentum=0)
users_ids, users_features = zip(*train_dataset.user2features.items())
items_ids, items_features = zip(*train_dataset.item2features.items())
def train_loop(model, optimizer, train_loader, n_epochs=100):
loss_history = list()
for epoch in range(n_epochs):
for batch in tqdm(train_loader, desc=f'Epoch {epoch}'):
user_embs, item_embs = model(*batch)
loss = infonceloss(user_embs, item_embs)
optimizer.zero_grad()
loss.backward()
old_users_embeddings = model.get_user_embeddings(
torch.IntTensor(np.array(users_features)).to(device)
)
old_items_embeddings = model.get_item_embeddings(
torch.IntTensor(np.array(items_features)).to(device)
)
optimizer.step()
new_users_embeddings = model.get_user_embeddings(
torch.IntTensor(np.array(users_features)).to(device)
)
new_items_embeddings = model.get_item_embeddings(
torch.IntTensor(np.array(items_features)).to(device)
)
loss_history.append(loss.item())
# отображаем только первые эпохи, эпоху посередине и последнюю
if (epoch < 3) or (epoch == n_epochs//2) or (epoch == n_epochs-1):
print(f'Train loss: {loss_history[-1]}')
plot_embedding_shift(
old_users_embeddings,
old_items_embeddings,
new_users_embeddings,
new_items_embeddings,
users_ids,
items_ids
)
model.train()
train_loop(model, optimizer, train_loader)
Код эксперимента №4 (архитектура левого энкодера состоит из Embedding→Linear)
import math
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import plotly.express as px
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from info_nce import InfoNCE # pip install info-nce-pytorch
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# игрушечные данные - коллаборативная модель выучит, что 5 юзер,
# 6 юзер и 7 юзер близки друг к другу из-за позитивных
# взаимодействий с одними и теми же айтемами (5i* и 6i*)
train = pd.DataFrame([], columns=["user_id", "item_id"])
train.loc[len(train)] = ["1u", "1i1"]
train.loc[len(train)] = ["1u", "1i2"]
train.loc[len(train)] = ["1u", "1i3"]
train.loc[len(train)] = ["2u", "2i1"]
train.loc[len(train)] = ["2u", "2i2"]
train.loc[len(train)] = ["2u", "2i3"]
train.loc[len(train)] = ["3u", "3i1"]
train.loc[len(train)] = ["3u", "3i2"]
train.loc[len(train)] = ["3u", "3i3"]
train.loc[len(train)] = ["4u", "4i1"]
train.loc[len(train)] = ["4u", "4i2"]
train.loc[len(train)] = ["4u", "4i3"]
train.loc[len(train)] = ["5u", "5i1"]
train.loc[len(train)] = ["5u", "5i2"]
train.loc[len(train)] = ["5u", "5i3"]
train.loc[len(train)] = ["6u", "6i1"]
train.loc[len(train)] = ["6u", "6i2"]
train.loc[len(train)] = ["6u", "6i3"]
train.loc[len(train)] = ["7u", "5i1"]
train.loc[len(train)] = ["7u", "5i2"]
train.loc[len(train)] = ["7u", "5i3"]
train.loc[len(train)] = ["7u", "5i4"]
train.loc[len(train)] = ["7u", "6i1"]
train.loc[len(train)] = ["7u", "6i2"]
train.loc[len(train)] = ["7u", "6i3"]
user_encoder = LabelEncoder().fit(train["user_id"])
item_encoder = LabelEncoder().fit(train["item_id"])
class TwoTowerDataset(Dataset):
def __init__(self,
positive_interactions_dataframe,
user_encoder,
item_encoder
):
self.positive_interactions_dataframe = \
positive_interactions_dataframe.copy()
self.positive_interactions_dataframe["encoded_user_id"] = \
user_encoder.transform(self.positive_interactions_dataframe.user_id)
self.positive_interactions_dataframe["encoded_item_id"] = \
item_encoder.transform(self.positive_interactions_dataframe.item_id)
self.user2features = {}
self.item2features = {}
self.interactions_features_list = []
for row in tqdm(self.positive_interactions_dataframe.itertuples()):
self.interactions_features_list.append(
(
(row.encoded_user_id,),
(row.encoded_item_id,)
)
)
if row.user_id not in self.user2features:
self.user2features[row.user_id] = (row.encoded_user_id,)
if row.item_id not in self.item2features:
self.item2features[row.item_id] = (row.encoded_item_id,)
self.start = 0
self.end = len(self.interactions_features_list)
def __len__(self):
return len(self.interactions_features_list)
def __getitem__(self, i):
return self.interactions_features_list[i]
train_dataset = TwoTowerDataset(train, user_encoder, item_encoder)
def collate_fn(batch):
users_features, items_features = zip(*batch)
return (
torch.IntTensor(np.array(users_features)),
torch.IntTensor(np.array(items_features)),
)
train_loader = DataLoader(
train_dataset,
batch_size=len(train_dataset),
collate_fn=collate_fn
)
class TwoTowerModel(nn.Module):
def __init__(self, user_embedding_sizes, item_embedding_sizes, device):
super(TwoTowerModel, self).__init__()
self.device = device
self.user_embeds = nn.Embedding(
user_embedding_sizes[0],
user_embedding_sizes[1]
)
self.user_linear = nn.Linear(
user_embedding_sizes[1],
user_embedding_sizes[1]
)
self.item_embeds = nn.Embedding(
item_embedding_sizes[0],
item_embedding_sizes[1]
)
def get_user_embeddings(self, user_features):
"""Пользовательская часть TwoTower"""
user_embeddings = self.user_embeds(user_features[:, 0])
user_embeddings = self.user_linear(user_embeddings)
return user_embeddings
def get_item_embeddings(self, item_features):
"""Айтемная часть TwoTower"""
item_embeddings = self.item_embeds(item_features[:, 0])
return item_embeddings
def forward(self, user_features, item_features):
user_embs = self.get_user_embeddings(user_features.to(self.device))
item_embs = self.get_item_embeddings(item_features.to(self.device))
return user_embs, item_embs
infonceloss = InfoNCE(temperature=1)
def plot_embedding_shift(
old_users_embeddings,
old_items_embeddings,
new_users_embeddings,
new_items_embeddings,
users_ids,
items_ids,
title="Embedding Shift"
):
"""
Рисует сдвиг эмбеддингов с окружностью разных радиусов.
old_users_embeddings: Тензор старых эмбеддингов юзеров
old_items_embeddings: Тензор старых эмбеддингов айтемов
new_users_embeddings: Тензор новых эмбеддингов юзеров
new_items_embeddings: Тензор новых эмбеддингов айтемов
users_ids: Список id юзеров в порядке, соответствующем порядку
эмбеддингов юзеров
items_ids: Список id айтемов в порядке, соответствующем порядку
эмбеддингов айтемов
"""
old_users_embeddings = old_users_embeddings.cpu().detach().numpy()
old_items_embeddings = old_items_embeddings.cpu().detach().numpy()
new_users_embeddings = new_users_embeddings.cpu().detach().numpy()
new_items_embeddings = new_items_embeddings.cpu().detach().numpy()
# Вводим новую метрику: какая доля градиентов расположена
# перпендикулярно радиус-вектору
# для юзеров:
count_orthogonal_users = 0
for i in range(len(old_users_embeddings)):
grad = new_users_embeddings[i] - old_users_embeddings[i]
is_orthogonal = math.isclose(
grad @ old_users_embeddings[i],
0,
abs_tol=1e-6
)
if is_orthogonal:
count_orthogonal_users += 1
print(
"Доля юзеров, для которых градиент направлен строго по "
"касательной к гиперсфере, на которой изначально лежал "
"эмбеддинг этого объекта: ",
count_orthogonal_users / len(old_users_embeddings)
)
# для айтемов:
count_orthogonal_items = 0
for i in range(len(old_items_embeddings)):
grad = new_items_embeddings[i] - old_items_embeddings[i]
is_orthogonal = math.isclose(
grad @ old_items_embeddings[i],
0,
abs_tol=1e-6
)
if is_orthogonal:
count_orthogonal_items += 1
print(
"Доля айтемов, для которых градиент направлен строго по "
"касательной к гиперсфере, на которой изначально лежал "
"эмбеддинг этого объекта: ",
count_orthogonal_items / len(old_items_embeddings)
)
if old_users_embeddings.shape[1] == 2:
# Дальнейшая визуализация - только для двумерного случая:
old_embs = torch.cat(
(
torch.from_numpy(old_users_embeddings),
torch.from_numpy(old_items_embeddings)
),
0
).numpy()
new_embs = torch.cat(
(
torch.from_numpy(new_users_embeddings),
torch.from_numpy(new_items_embeddings)
),
0
).numpy()
ids=users_ids+items_ids
# Создаем DataFrame для интерактивного графика
df = pd.DataFrame({
"old_x": old_embs[:, 0],
"old_y": old_embs[:, 1],
"new_x": new_embs[:, 0],
"new_y": new_embs[:, 1],
"id": ids
})
# График со стрелками
fig, ax = plt.subplots(figsize=(20, 14))
# Рисуем стрелки
for i in range(len(df)):
ax.arrow(df["old_x"][i],
df["old_y"][i],
df["new_x"][i] - df["old_x"][i],
df["new_y"][i] - df["old_y"][i],
head_width=max(
np.sqrt((df["new_x"][i] - df["old_x"][i])**2 +
(df["new_y"][i] - df["old_y"][i])**2) * 0.2,
0.1),
head_length=max(
np.sqrt((df["new_x"][i] - df["old_x"][i])**2 +
(df["new_y"][i] - df["old_y"][i])**2) * 0.2,
0.1),
fc='blue',
ec='blue',
alpha=0.7)
# Точки эмбеддингов
ax.scatter(df["new_x"],
df["new_y"],
color='red',
label="New Positions",
s=30)
ax.scatter(df["old_x"],
df["old_y"],
color='gray',
alpha=0.5,
label="Old Positions",
s=30)
# Рисуем оси и разные окружности
line1 = plt.axline((1, 0), (0, 0),
color='black',
linestyle='dashed',
linewidth=2)
line2 = plt.axline((0, 1), (0, 0),
color='black',
linestyle='dashed',
linewidth=2)
circle1 = plt.Circle((0, 0),
1,
color='green',
fill=False,
linestyle='dashed',
linewidth=2)
ax.add_patch(circle1)
circle2 = plt.Circle((0, 0),
1.5,
color='green',
fill=False,
linestyle='dashed',
linewidth=2)
ax.add_patch(circle2)
circle3 = plt.Circle((0, 0),
2,
color='green',
fill=False,
linestyle='dashed',
linewidth=2)
ax.add_patch(circle3)
# Настройки осей и легенды
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.set_title(title)
ax.legend()
ax.set_aspect('equal')
plt.show()
# Интерактивная версия с plotly (отображение id при наведении)
fig_interactive = px.scatter(df,
x="new_x",
y="new_y",
text="id",
title="Embedding Positions with IDs",
hover_data={
"id": True,
"new_x": False,
"new_y": False
})
fig_interactive.update_layout(
autosize=False,
width=800,
height=800,
)
fig_interactive.show()
user_embedding_sizes = [len(user_encoder.classes_), 2]
item_embedding_sizes = [len(item_encoder.classes_), 2]
model = TwoTowerModel(user_embedding_sizes, item_embedding_sizes,
device).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=5, momentum=0)
users_ids, users_features = zip(*train_dataset.user2features.items())
items_ids, items_features = zip(*train_dataset.item2features.items())
def train_loop(model, optimizer, train_loader, n_epochs=100):
loss_history = list()
for epoch in range(n_epochs):
for batch in tqdm(train_loader, desc=f'Epoch {epoch}'):
user_embs, item_embs = model(*batch)
loss = infonceloss(user_embs, item_embs)
optimizer.zero_grad()
loss.backward()
old_users_embeddings = model.get_user_embeddings(
torch.IntTensor(np.array(users_features)).to(device)
)
old_items_embeddings = model.get_item_embeddings(
torch.IntTensor(np.array(items_features)).to(device)
)
optimizer.step()
new_users_embeddings = model.get_user_embeddings(
torch.IntTensor(np.array(users_features)).to(device)
)
new_items_embeddings = model.get_item_embeddings(
torch.IntTensor(np.array(items_features)).to(device)
)
loss_history.append(loss.item())
# отображаем только первые эпохи, эпоху посередине и последнюю
if (epoch < 3) or (epoch == n_epochs//2) or (epoch == n_epochs-1):
print(f'Train loss: {loss_history[-1]}')
plot_embedding_shift(
old_users_embeddings,
old_items_embeddings,
new_users_embeddings,
new_items_embeddings,
users_ids,
items_ids
)
model.train()
train_loop(model, optimizer, train_loader)
Приложения
Приложение №1
Рассмотрим, почему скалярное произведение q на градиент косинуса по нему действительно равно нулю:
Введем обозначения:
Тогда, по правилу дифференцирования частного:
Посчитаем:
Подставляем и получаем:
Умножим полученное на вектор q:
ЧТД!
Приложение №2
Комментарии (4)
gofat
29.08.2025 13:36Статья интересная, но еще бы добавить графиков и бенчмарков для сравнений (так воспринимать проще материал), чтобы было совсем круто.
И еще клево было бы увидеть, что делать с такой проблемой, коль она уж проявилась. И как ее мониторить на основе полученных выводов на практике.
andrey_atamanyuk Автор
29.08.2025 13:36Я вот думал по поводу графиков и решил, что визуала в виде табличек в разделе "Подтверждение выводов в реальных экспериментах" будет достаточно, чтобы увидеть, что к чему. Но раз возник такой вопрос, наверное, это было не самым удобным решением :)
По поводу "что делать с такой проблемой" - тут, я считаю, самое главное - это знать, что проблема возникает и когда именно она возникает (об этом и статья). Дальше уже целое поле для фантазии, вплоть до примитивной нормализации эмбедов после каждого шага или масштабирования градиентов. Текущая статья и так вышла достаточно массивной и я решил не загромождать ее еще больше))
По поводу мониторинга. Как и с любыми хорошими метриками - это сложно, разработкой оффлайн и онлайн метрик popularity bias у нас занимаются целые команды
vy44ch
29.08.2025 13:36Странно, что авторы NormFace не учли так много факторов, которые должны выполняться, чтобы говорить про ортогональность градиента -> ортогональность обновления. Там же вроде реально основная логика улучшения рушится. Непонятно тогда, почему оно давало буст на метриках
UtrobinMV
А вывод?