Всем привет! Меня зовут Айбек Аланов. Я — аспирант ВШЭ, а также научный сотрудник группы «Вероятностные методы машинного обучения» AIRI. Сегодня мне хотелось бы поделиться с вами успехами, которых добилась наша научная группа в вопросе адаптации генеративно-состязательных сетей на новые домены.

Генеративно-состязательные сети

Генеративно-состязательной сетью (Generative adversarial network, GAN) называют алгоритм машинного обучения без учителя, который основан на комбинации двух нейронных сетей — генератора и дискриминатора, — настроенных на работу друг против друга. Первая генерирует новые образцы на основе исходных, задача второй — распознать, что это подделка. После каждого цикла генерации и распознавания, происходит обновление весов каждой сети на основе общей функции потерь, которая минимизируется генератором и максимизируется дискриминатором. Такая антагонистическая игра позволяет генератору все лучше и лучше подделывать образцы до такой степени, что к концу обучения они становятся неотличимы от реальных образцов.

Источник: WelcomeAIOverlords / YouTube
Источник: WelcomeAIOverlords / YouTube

Впервые такая модель была предложена легендой машинного обучения Яном Гудфеллоу в 2014 году. Менее, чем за декаду она показала себя эффективным инструментом для генерации и улучшения изображений, благодаря чему GAN активно используются для решения самых разных задач в компьютерном зрении: улучшения и изменения изображений, междоменных преобразований «image-to-image» и многого другого.

Проблема малой выборки

Сегодня, чтобы обучить генеративную модель на каком-то домене (то есть наборе схожих по признакам изображений), нам нужен доступ к большой выборке высокого качества из него. Под «большой» я понимаю «очень большой». Например, если мы хотим научиться генерировать реалистичные лица высоком разрешении, то нам нужно иметь датасет типа FFHQ, содержащий около 70 000 изображений, которые были специально для этого отобраны и имеют разрешение как минимум 1024×1024 пикселей.

Источник: Tero Karras et al., Proceedings of the IEEE/CVF conference (2020)
Источник: Tero Karras et al., Proceedings of the IEEE/CVF conference (2020)

Но такие большие датасеты встречаются не всегда. Скажем, для генерации мордочек котиков может помочь датасет AFHQ, но там их всего пять тысяч. Еще хуже обстоят дела с другими наборами данных. Например, в датасете MetFaces содержится всего чуть более тысячи лиц людей с картин разных художников, а в датасете FaceSketches, собранном из карандашных скетчей, всего около 300 объектов. Если мы попытаемся обучить наш GAN на таких маленьких выборках, результат будет неудовлетворительный.

Если обучать GAN на датасете из 2 тысяч лиц, получается не очень. Источник: Tero Karras et al., arXiv (2020)
Если обучать GAN на датасете из 2 тысяч лиц, получается не очень. Источник: Tero Karras et al., arXiv (2020)

Трансферное обучение

Но все не так плохо. В какой-то момент специалисты по машинному обучению поняли, что эту проблему можно решить, адаптируя нейросеть, обученную на домене, где есть много картинок, к узкому домену. Такой подход получил название трансферного обучения (transfer learning, TL).

В случае с лицами первоначальное обучение проводят на все том же датасете FFHQ. Обученный таким образом GAN можно дообучить на небольшом числе картин людей, и он начнет рисовать лица людей в стиле известных художников. Существуют также методы доменной адаптации, которые позволяют делать это на основе текстовых описаний или одного стилевого изображения.

Процесс адаптации GAN, обученной на FFHQ, к домену с портретами от разных художников. Источник: Ojha et al., Proceedings of the IEEE/CVF conference (2021)
Процесс адаптации GAN, обученной на FFHQ, к домену с портретами от разных художников. Источник: Ojha et al., Proceedings of the IEEE/CVF conference (2021)

Но проблема всех этих методов в том, что, чтобы адаптировать генератор, обычно обучают все веса модели. В своей работе мы рассматривали в основном модель StyleGAN2 — так вот для нее необходимо настраивать порядка 30 миллионов весов! На самом деле, это довольно избыточный труд в случае, если новый домен довольно близок к основному. Например, если вы всего лишь хотите превратить ваших друзей в персонажей аниме.

Схема традиционной адаптации модели StyleGAN2. Аниме на аве за 30 миллионов весов
Схема традиционной адаптации модели StyleGAN2. Аниме на аве за 30 миллионов весов

Мы с моими коллегами, Дмитрием Ветровым и Вадимом Титовым, поставили себе цель облегчить эту процедуру и уменьшить число обучаемых параметров. Для этого нам потребовалось глубже разобраться, как устроена архитектура генератора StyleGAN2.

Модифицируем схему работы StyleGAN2

Если кратко, эта модель устроена следующим образом. Случайный шум, который определяет наш выходной объект, подается с помощью специального стилевого вектора. Он получается из исходного шума путем нескольких преобразований и подается в каждый сверточный слой генератора. В слоях происходит модуляция входных каналов свертки путем умножения на компоненты этого вектора. Таким образом осуществляется контроль дисперсии каждого входного канала. Другими словами, стилевой вектор контролирует все семантические признаки выходного изображения: пол, возраст и другие свойства. Наконец, в слое происходит демодуляция, которая нормирует наши выходные каналы.

Схема внедрения стилевого вектора в сверточные слои генератора StyleGAN2
Схема внедрения стилевого вектора в сверточные слои генератора StyleGAN2

Мы решили внести в эту схему дополнительную операцию, а именно доменную модуляцию. Она заключается в введении дополнительного доменного вектора той же размерности, что и стилевой вектор. Его компоненты также домножаются на входные каналы, прежде чем происходит демодуляция. Доменная адаптация в данном случае заключается в том, что для нового домена нам достаточно дообучать только этот новый вектор, а не все веса модели, а его размерность — это всего лишь 6 тысяч, то есть в пять тысяч раз меньше, чем в традиционном подходе.

Схема внедрения доменного вектора
Схема внедрения доменного вектора

В своей недавней статье, опубликованной в сборнике трудов конференции NeurIPS 2022, мы показали, что этого действительно достаточно, чтобы адаптировать предобученную StyleGAN2 на самые разные домены. Вы можете сами убедиться, что оба подхода визуально показывают одинаковый результат для адаптации по текстовому описанию:

Адаптация по текстовому описанию. Слева — обучение генератора целиком, справа — только доменного вектора
Адаптация по текстовому описанию. Слева — обучение генератора целиком, справа — только доменного вектора

Для one-shot-адаптации — то есть адаптации по одному изображению — результат так же хороший (на примере модели MindTheGap):

Адаптация по одной картинке. Слева — обучение генератора целиком, справа — только доменного вектора
Адаптация по одной картинке. Слева — обучение генератора целиком, справа — только доменного вектора

Мульти-доменная адаптация с помощью гиперсети

Вдохновившись таким радикальным снижением пространства параметров, мы задались вопросом, а нельзя ли обучить дополнительную сеть, которая автоматически выдавала бы нам доменный вектор под каждый домен? Другими словами, можно ли решить таким способом задачу мульти-доменной адаптации GAN?

Оказалось, что можно. Для этого мы разработали гиперсеть, которая получила название HyperDomainNet. На ее вход мы подаем эмбеддинг текста, который описывает целевой домен —его можно получить с помощью дополнительного предобученного энкодера. На выходе гиперсеть выдает соответствующие доменные векторы. В результате GAN рисует изображения в каждом домене, который был использован в обучении гиперсети.

Схема работы мульти-доменной адаптации с помощью HyperDomainNet
Схема работы мульти-доменной адаптации с помощью HyperDomainNet

Такой подход дает результаты не хуже, чем если бы мы дообучали модель на каждый домен по-отдельности.

Результат адаптации с помощью HyperDomainNet
Результат адаптации с помощью HyperDomainNet

В ходе работы с гиперсетью мы обнаружили еще один необычный эффект: когда число доменов, на которых она обучается, становится достаточно велико, алгоритм смог адаптировать GAN к новым доменам (unseen domains), то есть к доменам, которые не были представлены в обучении.

Если вам интересно позапускать нашу модель, можете зайти на наш репозиторий. Там есть ссылка на наш Colab, где вы можете использовать наши предобученные модели на своих фотографиях, либо самостоятельно дообучить StyleGAN2 на новых доменах с помощью малого числа параметров.

На этом все. Если есть вопросы, с удовольствием отвечу на них в комментариях!

Результат работы HyperDomainNet, примененный к нашим фотографиям
Результат работы HyperDomainNet, примененный к нашим фотографиям

Комментарии (3)


  1. berng
    10.07.2023 13:42

    Поподробнее можно про внедрение S и D в сворточные слои? Вы каким-то образом корректируете элементы оригинального сверточного ядра?


    1. aibrain Автор
      10.07.2023 13:42
      +2

      Это происходит стандартным образом через операцию modulation, которая была предложена в статье https://arxiv.org/abs/1912.04958 (формула 1). Например, если мы подаем в операцию modulation вектор S, это означает, что веса свертки умножатся на элементы вектора S. Это умножение будет происходить следующим образом. Допустим w - это веса свертки, тогда w будет 4-мерным тензором с размерностями c_in x c_out x k x k, где c_in - это число входных слоев, c_out - число выходных слоев, k - размер ядра свертки. В таком случае вектор S должен иметь размер c_in, и мы умножаем каждый элемент S_i на весь подтензор w[i, :, :, :].


      1. berng
        10.07.2023 13:42

        Идея понятна, спасибо!