Почти такой же заголовок носит и моя предыдущая статья, с той лишь разницей, что тогда я создавал линзы для SnapChat алгоритмически, используя dlib и openCV, а сегодня хочу показать, как можно добиться результата, используя машинное обучение. Этот подход позволит не заниматься ручным проектированием алгоритма, а получать итоговое изображение прямо из нейронной сети.
Вот что мы получим:
Что такое pix2pix?
Это способ преобразования изображения в изображение с помощью состязательных сетей (широко известный как pix2pix).
Название «pix2pix» означает, что сеть обучена преобразовывать входное изображение в соответствующее ему выходное изображение. Вот примеры таких преобразований:
Самой крутой особенностью pix2pix является универсальность подхода. Вместо создания нового алгоритма или новой модели для каждой из задач выше, достаточно просто использовать разные датасеты для тренировки сети.
В отличие от подходов, применявшихся ранее, pix2pix учится решать задачи гораздо быстрее и на меньшей обучающей выборке. Так, например, результаты ниже были получены при обучении с использованием Pascal Titan X GPU на датасете из 400 пар изображений и менее чем за два часа.
Как работает pix2pix?
pix2pix использует две нейронные сети, обучающиеся параллельно:
- Генератор
- Дискриминатор
Генератор пытается сгенерировать выходное изображение из входных обучающих данных, а дискриминатор пытается определить, является ли результат настоящим или сгенерированным.
Когда у генератора получаются изображения неотличимые (дискриминатором) от настоящих, мы начинаем тренировать дискриминатор на них и на настоящих изображениях. Когда у дискриминатора получается успешно отличать реальные изображения от сгенерированных, мы вновь начинаем тренировать генератор, чтобы тот снова научился обманывать дискриминатор.
Такая «гонка вооружений» приводит к тому, что отличить реальные изображения от сгенерированных становится сложно и человеку.
Практика
Тренировать наш генератор фильтров для SnapChat мы будем на изображениях 256x256 (большие размеры потребуют большего количества видеопамяти). Для создания датасета воспользуемся кодом из предыдущего туториала.
Я скачал множество изображений лиц и применил к каждому фильтр «Thug Life Glasses». Получится что-то типа таких пар:
Для создания модели возьмём репозиторий pix2pix на основе TensorFlow. Склонируйте его и установите Tensorflow.
Команда для запуска обучения будет такой:
python pix2pix.py --mode train --output_dir dir_to_save_checkpoint --max_epochs 200 --input_dir dir_with_training_data --which_direction AtoB
Параметр which_direction задаёт направление обучения. AtoB означает что мы хотим превратить изображение A (слева, без очков) в изображение B (справа, с очками). Кстати, обратите внимание, что pix2pix может успешно научиться восстанавливать исходное изображение из изображения с фильтром, достаточно лишь поменять направление обучения.
Отслеживать прогресс обучения можно с помощью tensorboard, для чего надо выполнить команду:
tensorboard --logdir=dir_to_save_checkpoint
Как только вы видите, что результаты на тренировочных данных стали достаточно хорошими, можно остановить обучение и проверить работу модели на произвольных данных. Продолжить тренировку с последней контрольной точки можно так:
python pix2pix.py --mode train --output_dir dir_to_save_checkpoint --max_epochs 200 --input_dir dir_with_training_data --which_direction AtoB --checkpoint dir_of_saved_checkpoint
Заключение
Появление генеративных сетей типа pix2pix открывает большие перспективы для универсального и простого решения всевозможных задач обработки изображений.