Наверняка вы мечтали поговорить с великим философом: задать ему вопрос о своей жизни, узнать его мнение или просто поболтать. В наше время это возможно за счет чат-ботов, которые поддерживают диалог, имитируя манеру общения живого человека. Подобные чат-боты создаются благодаря технологиям обработки естественного языка и генерации текста. Уже сейчас существуют обученные модели, которые неплохо справляются с данной задачей.

В этой статье я расскажу о своем опыте обучения алгоритма генерации текста, основанного на высказываниях великих личностей. В датасете для обучения модели используются цитаты десяти известных философов, писателей и ученых. 

Конечный текст будет генерироваться на основе высказываний всех десяти мыслителей.Но если вы захотите “пообщаться” с кем-то конкретным, например, с Сократом или Ницше, то Google Colab, в котором велась работа, прилагается в конце статьи. С его помощью можно будет поэкспериментировать только с генерацией выбранного вами философа.

План работы

С задачей генерации текста неплохо справляются модели по типу \textbf{LSTM},\textbf{GRU}, но все же зачастую лучшего всего работают Трансформеры. У них получается более осмысленный и понятный для человека текст. Проблема в том, что Трансформеры очень “тяжеловесные”, и их нужно долго обучать.

У нас есть датасет из цитат философов, на котором нужно обучить модель, чтобы та генерировала тексты, схожие с реальными высказываниями этих людей. Вопрос в том, какую модель брать, и как обрабатывать данные для нее? Как оценивать результат модели, как ее обучать? 

Как настоящие исследователи-практики, мы начнем с более простой модели - \textbf{LSTM}. Вдруг она сразу покажет достойный результат, и нам не придется использовать Трансформеры. 

Посмотрим на результаты уже предобученных моделей и дообучим их на нашем датасете. В качестве лосса будем использовать незамысловатую кросс-энтропию, а в качестве метрики качества возьмем простую точность, в данном случае она вполне хорошо заходит. 

После попробуем сгенерировать концовку для простого предложения, и оценим с человеческой точки зрения, насколько модель хорошо составляет текст. Ведь хорошая метрика - это, конечно, замечательно, но если сгенерированные предложения не будут иметь никакого смысла, то какой от них прок?

Шаг 1: Сбор данных

Для обучения модели я воспользовался open source датасетом компании LabelMe. Он представлен в виде таблицы, в котором собраны цитаты 

  • Конфуция, 

  • Ницше, 

  • Сократа, 

  • Омара Хайяма, 

  • Марка Аврелия, 

  • Лао Цзы, 

  • Аристотеля, 

  • Иммануила Канта, 

  • Цицерона,

  • Платона.

В общей сложности более 1500 цитат. Такой вот концентрат многовековой мудрости. Вы можете сказать, что можно было собрать и больше. Но ключевая особенность этого датасета – это не бездумный парсинг. В нем есть структура и логика отбора, благодаря чему конечный результат будет лучше, чем если нахватать больше рандомных цитат с citaty.info.

Узнать о датасете больше и скачать его можно здесь.

Шаг 2: Обработка данных

Для \textbf{LSTM} объединим данные из датасета и посмотрим на предложение, токенизируем его. В Google Colab можно увидеть, что получилось слишком большое количество токенов – 8760:

Количество токенов при изначальной токенизации
Количество токенов при изначальной токенизации

В чем же проблема большого числа токенов? Дело в том, что при предсказании и генерации следующего слова в тексте модель подсчитывает для каждого слова вероятность того, что оно окажется следующим в предложении.

Получается, если слов очень много, то и количество вероятностей будет соответствующим. Вероятность мы вычисляем с помощью софтмакса, но брать софтмакс от вектора такой большой размерности затруднительно. 

Есть несколько способов решения проблемы:

  • Отбросить редко встречающиеся слова (токены), и заменить их все на соответствующий токен UNK, присваиваемый всем токенам, которых нет в словаре. Понятно, что это очень простой и не особо эффективный вариант, так как мы сильно урезаем словарь.

  • Использовать в качестве токенов символы, а не слова. Этот простой подход имеет явный недостаток — модели, обучаемой на отдельных символах, гораздо сложнее выучить соотношения между буквами в каждом отдельном слове. Из-за этого может возникнуть ситуация, при которой значительная часть слов, сгенерированных моделью, не будут являться словами языка, на котором написаны входные тексты.

  • Использовать Byte Pair Encoding. Более современный и мощный метод, реализованный ребятами из VK. В прикрепленном Google Colab кратко описан принцип работы данного метода, а также прикреплены ссылки на смежные статьи и GitHub. Идея решения проблемы заключается в том, что мы находим часто встречающиеся n-gram-ы и заменяем их на какой-то символ, тем самым кодируя их. Очевидно, что в объемных текстах мы сможем найти огромное количество таких замен.

Вот небольшая поясняющая картинка:

Пример работы Byte Pair Encoding
Пример работы Byte Pair Encoding

Конечно же мы выберем последний метод работы с данными.

Шаг 3: Обучение модели

Само обучение модели будет проходить стандартным способом. Для этого нам понадобится к каждому скрытому состоянию RNN применить классификационную нейронную сеть - скорее всего, линейную. На каждом этапе времени должно получиться свое предсказание вероятности для следующего токена.

После этого необходимо усреднить лосс по полученным предсказаниям. Вкратце, это будет происходить следующим образом: на вход приходит последовательность x_1, x_2, ..., x_{n}. Требуется на каждом этапе времени t по символам x_1, ..., x_{t - 1} предсказывать символ x_t. Для этого на вход нужно подать последовательность BOS, x_1, x_2..., x_{n} и в качестве таргетов взять последовательность x_1, x_2, ..., x_{n}, EOS

В данном случае BOS, EOS — специальные символы начала и конца предложения (любые новые символы, которых нет в словаре). При обучении мы подаем на вход истинные токены и предсказываем следующий токен, но при тестировании истинных токенов нет. Поэтому при тестировании стоит в качестве следующего символа передавать предыдущий предсказанный символ. Соответственно поясняющая картинка:

Обучение RNN
Обучение RNN
 Тестирование RNN
Тестирование RNN

Также для дальнейшей работы необходимо уравнять все предложения по длине. Идея решения этой проблемы незамысловатая: выберем нужную длину предложения, и для недостающих по длине предложений добавим специальный токенPAD, означающий, что предложение расширилось, и не нужно будет после эмбеддинга пускать по таким токенам градиент. Слишком длинные предложения мы просто обрежем. Все тонкости реализации можно посмотреть в Google Colab.

В качестве модели была взята следующая архитектура:

Архитектура RNN
Архитектура RNN

В качестве первого слоя используется эмбеддинг, который будет переводить токены в удобный для машины вид. Мы специально передаем индекс токена PAD, чтобы, как уже отмечалось, по таким словам не протекал градиент. Слои свертки и нормализации используются в виде feature extactor – некоторые эксперименты с моделью показали, что это дает неплохой результат. Далее идет сама рекуррентная сеть, и в конце – линейный слой в виде классификатора. 

В итоге после 30 эпох обучения получились следующие результаты:

Результаты обучения
Результаты обучения

Вполне неплохое значение метрики. Интересно посмотреть на саму генерацию текста.

Шаг 4: Генерация текста

Есть несколько способов генерации текста. Вот некоторые из них:

  • Жадная генерация текста: на каждом шаге берем токен с наибольшей вероятностью. Очевидны проблемы с данным подходом: выбрав наиболее вероятный токен, мы можем потерять осмысленность предложения, ведь не всегда наиболее вероятное слово является самым подходящим.

  • Top k sampling: для предсказания следующего токена посмотрим на распределение вероятностей. Выбираем k токенов с максимальной вероятностью, и из них сэмплируем один токен с вероятностью, пропорциональной предсказанной для этого токена вероятности. Понятно, что это частично решает проблему жадной генерации, но не полностью.

  • Beam search: смотрим на несколько вариантов построения предложения, и в некоторые моменты времени отсекаем предложения с наименьшей в целом вероятностью, и так разрастаемся. В конце берем предложение с наибольшей вероятностью. Вот картинка поясняющая данный метод:

В данном случае мы будем использовать первые два метода из-за простоты их реализации. 

  • Жадная генерация

Выясним, что выдает наша модель в случае жадной генерации. В качестве проверки возьмем начало предложения в виде: "Красота есть во всем, но". Посмотрим, как модель продолжит высказывание. Результаты получились следующими:

  • Top k sampling

Попробуем с тем же предложением продолжить с методом генерации \textit{Top k sampling}. Результаты получились следующими:

Видно, что текст аналогично получился бессмысленным.

Шаг 5: Работа с Transformers

RNN справились плохо, поэтому попробуем более тяжеловесные решения, а именно Трансформеры. Они почти всегда выдают неплохой результат, особенно в данной задаче. Почему мы, в таком случае, сразу не воспользовались этой моделью?

Дело в том, что при всей мощности данной модели, ее необходимо долго обучать. Даже на одну эпоху некоторые модели учатся часами. Поэтому чаще используют уже обученные модели, и для специфических случаев дообучают их. Так мы и сделаем, возьмем модель с известного сайта huggingface. После недолгих поисков, можно найти следующую модель, которая хорошо подходит для нашей задачи:

Model Transformer
Model Transformer

Это уже обученная модель от Сбербанка, основанная на известном Трансформере GPT-2.

Шаг 6: Проверка предобученной модели

Трансформеры славятся тем, что хорошо работают прямо из коробки. Поэтому сразу посмотрим, как дополняет предложение из предыдущего пункта данная модель:

Результат трансформера
Результат трансформера

Текст выглядит намного более осмысленным, можно даже зачитать своей даме сердца. Но все же интересно: если дообучить модель на нашем датасете, то сможет ли она смоделировать корректные высказывания великих людей?

Шаг 7: Обработка данных

Обработка данных во многом повторяет методы, использованные в работе с RNN. Добавляем токены конца и начала предложения. Приводим все предложения к одной длине с помощью обрезки, либо же дополнения с помощью все того же символа PAD.Только в данном случае есть небольшое различие, а именно в том, что для обработки данных символов Трансформеру необходимо также передавать так называемую “маску”, в которой стоят нули там, где градиент не должен течь. То есть там, где стоят токены PAD. В остальных местах стоят единицы.

Шаг 8: Дообучение модели и проверка результатов

Дообучение модели происходит аналогично обучению для модели RNN, только намного дольше, даже несмотря на то, что эпох здесь меньше. Зато результаты получились впечатляющими, а именно продолжение все того же высказывания:

Результат генерации дообученного трансформера
Результат генерации дообученного трансформера

Получилось очень похоже на высказывание великого человека с именем "Не знаю". 

В целом результат неплох.

Выводы

К сожалению, хоть RNN и не особо тяжеловесные модели, они показывают худший результат по сравнению с Трансформерами. Мы поняли, что дообученные Трансформеры хорошо моделируют язык людей, и могут без проблем продолжить вашу мысль, как будто это была мысль великого философа. Ну или почти.

Комментарии (0)