Модель T5 – это нейросеть, которая уже обучена хорошо понимать и генерировать текст, и которую можно дообучить на собственную задачу, будь то перевод, суммаризация текстов, или генерация ответа чат-бота.

В этом посте я рассказываю про первую многозадачную модель T5 для русского языка и показываю, как её можно обучить на новой задаче.

Русскоязычная модель T5 худо-бедно решает десяток разных задач
Русскоязычная модель T5 худо-бедно решает десяток разных задач

Зачем нужна русскоязычная T5

T5 – нейросетевая модель для понимания и генерации текста. Изобрели её в работе от Google два года назад, и расшифровывается это название как text-to-text transfer transformer. Трансформер – это архитектура нейросетей, позволяющая извлекать из текста довольно объёмную информацию. Благодаря этой архитектуре модели типа BERT круто понимают тексты, а модели типа GPT весьма правдоподобно их генерируют. Text-to-text означает, что модель T5 принимает на вход тексты и "читает" их энкодером (как BERT), а потом "пишет" декодером новые тексты и отдаёт на выход. Слово transfer говорит о цели этой модели: она предобучалась восстанавливать пропущенные фрагменты текста, но при желании её можно дообучить на новые, более полезные задачи: перевод, перефразирование, суммаризация текстов, генерация диалоговых ответов, и т.п.

Гугл выпустил две версии T5: первая понимает только английский язык, зато дообучалась на 24 разных задачах, а вторая понимает 101 язык (включая русский), но умеет только заполнять пропуски в тексте. Поэтому я решил сначала ужать мультиязычную модель T5 (mT5) до двух языков: русского и английского, выкинув ненужные токены из её словаря и соответствующие строки из матриц входных и выходных эмбеддингов. Процесс подробно описан в этом посте, а в результате модель "похудела" с 2.2 до 0.9 ГБ, а значит, стала более удобной для применения. Эту уменьшенная модель я выложил под именем cointegrated/rut5-base. А дальше я пошёл по пути Google и дообучил свою русскую T5 решать одновременно несколько разных русских и англоязычных задач.

Многозадачная модель

Для каждой задачи, как и в гугловской статье, я использовал свой префикс, который надо писать, отделив символом |, перед входным текстом. Сами задачи такие:

Перевод (префикс translate ru-en или translate en-ru). Обучал на датасете opus_wikipedia.

print(generate(
    'translate ru-en | Каждый охотник желает знать, где сидит фазан.'))
# Each hunter wants to know, where he is.

print(generate(
    'translate en-ru | Every hunter wants to know where the duck is.'))
# Все охотники хотят знать, где находится птица.

Перефразирование (paraphrase). Обучал на датасете tapaco.

print(generate(
    'paraphrase | Каждый охотник желает знать, где сидит фазан.', 
    encoder_no_repeat_ngram_size=1, repetition_penalty=0.5, 
    no_repeat_ngram_size=1
))
# В любом случае каждый рыбак мечтает познакомиться со своей фермой

Заполнение пропусков в тексте (fill). Пропуски можно обозначать как ___ или  _3_, где 3  – примерное количество слов, которые надо вставить. Обучал на корпусе ru_web-public_2019_1M из Лейпцигской коллекции.

print(generate('fill | Каждый охотник _3_, где сидит фазан.'))
# смотрит на озеро

Восстановление текста из зашумлённого мешка слов (assemble). Обучал также на Лейпцигском корпусе.

print(generate('assemble | охотник каждый знать фазан сидит'))
# Каждый охотник знает, что фазан сидит.

Упрощение текстов (simplify). Обучал на данных дорожки RuSimpleSentEval.

print(generate(
    'simplify | Местным продуктом-специалитетом с защищённым ' \
    'географическим наименованием по происхождению считается ' \
    'люнебургский степной барашек.', max_length=32))
# Местным продуктом-специалитетом считается люнебургский степной барашек.

Диалоговый ответ (reply для ответов в стиле художественной литературы, обученных на корпусе Козиева, и  answer для ответов, обученных на otvet.mail.ru).

print(generate('reply | Помогите мне закадрить девушку'))
# Что я хочу?

print(generate('answer | Помогите мне закадрить девушку'))
# я хочу познакомиться с девушкой!!!!!!!!

Ответ на вопросы по тексту (comprehend). Обучал на датасете SberQUAD.

print(generate(
    'comprehend | На фоне земельного конфликта между владельцами овец и ' \
    'ранчеро разворачивается история любви овцевода Моргана Лейна, ' \
    'прибывшего в США из Австралии, и Марии Синглетон, владелицы ' \
    'богатого скотоводческого ранчо. Вопрос: откуда приехал Морган?'))
# из Австралии

Задавание вопросов по тексту (ask), обучал также на SberQUAD.

print(generate(
    'ask | На фоне земельного конфликта между владельцами овец и ' \
    'ранчеро разворачивается история любви овцевода Моргана Лейна, ' \
    'прибывшего в США из Австралии, и Марии Синглетон, владелицы ' \
    'богатого скотоводческого ранчо.', max_length=32))
# Что разворачивается на фоне земельного конфликта 
# между владельцами овец и ранчеро?ро?

Генерация заголовка к новостной статье (headline). Обучал на датасете Ильи Гусева по мотивам соревнования Телеграма.

print(generate(
    'headline | На фоне земельного конфликта между владельцами овец ' \
    'и ранчеро разворачивается история любви овцевода Моргана Лейна, ' \
    'прибывшего в США из Австралии, и Марии Синглетон, владелицы ' \
    'богатого скотоводческого ранчо.', max_length=32))
# На фоне земельного конфликта разворачивается история любви 
# овцевода Моргана Лейна и Марии Синглетон

Как же работает эта магическая функция generate? Стандартный python и код из transformers, и дальше вы можете запускать любой из примеров выше. Попробовать это вы можете в демо-блокноте.

# !pip install transformers sentencepiece --quiet
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
model_name = "cointegrated/rut5-base-multitask"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

def generate(text, **kwargs):
    inputs = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        hypotheses = model.generate(**inputs, num_beams=5, **kwargs)
    return tokenizer.decode(hypotheses[0], skip_special_tokens=True)

Весь код, которым я дообучал модель, есть в этом блокноте; он не очень чистый, ибо я только экспериментировал, и содержит несколько задач, не вошедших в опубликованную версию модели. А сама модель выложена в каталог Huggingface: cointegrated/rut5-base-multitask.

Как обучать T5 на собственных данных

Текущая версия модели задумывалась как демонстрация возможностей предобученной seq2seq модели и как болванка для последующего дообучения. Поэтому я особо не возился ни с подбором гиперпараметров, ни с качеством датасетов. Следовательно, если более внимательно поработать с датасетом и гиперпараметрами и дообучить модель на какую-то одну задачу, работать она будет ещё лучше. Например, так я уже дообучал её на задаче перефразирования. А здесь я покажу, как обучить модель отвечать на вопросы из кроссвордов, используя фреймворки transformers и pytorch (при необходимости, модель можно потом сконвертировать в tensorflow или другой формат). Полный код примера есть всё в том же демо блокноте.

Инициализировать модель можно многозадачной версией. Тут, как и везде в библиотеке transformers, model – это сама нейросеть, а tokenizer – это часть модели, ответственная за сопоставление текстов со словарём: разбиение текстов на числовые токены, и сбор текстов из токенов обратно. optimizer – это объект, отвечающий за градиентный спуск; он нужен только на время обучения.

import torch 
from transformers import T5ForConditionalGeneration, T5Tokenizer
raw_model = 'cointegrated/rut5-base-multitask' 
model = T5ForConditionalGeneration.from_pretrained(raw_model).cuda();
tokenizer = T5Tokenizer.from_pretrained(raw_model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

Ниже – полный код обучения модели. Предполагается, что переменная pairs – это список из пар, состоящих из вопроса и ответа. На GPU одна эпоха обучения на 75 тысячах примеров занимает около 15 минут.

from tqdm.auto import trange
import random
import numpy as np

batch_size = 16  # сколько примеров показываем модели за один шаг
report_steps = 200  # раз в сколько шагов печатаем результат
epochs = 3  # сколько раз мы покажем данные модели

model.train()
losses = []
for epoch in range(epochs):
    print('EPOCH', epoch)
    random.shuffle(pairs)
    for i in trange(0, int(len(pairs) / batch_size)):
        batch = pairs[i * batch_size: (i + 1) * batch_size]
        # кодируем вопрос и ответ 
        x = tokenizer([p[0] for p in batch], return_tensors='pt', padding=True).to(model.device)
        y = tokenizer([p[1] for p in batch], return_tensors='pt', padding=True).to(model.device)
        # -100 - специальное значение, позволяющее не учитывать токены
        y.input_ids[y.input_ids == 0] = -100
        # вычисляем функцию потерь
        loss = model(
            input_ids=x.input_ids,
            attention_mask=x.attention_mask,
            labels=y.input_ids,
            decoder_attention_mask=y.attention_mask,
            return_dict=True
        ).loss
        # делаем шаг градиентного спуска
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # печатаем скользящее среднее значение функции потерь
        losses.append(loss.item())
        if i % report_steps == 0:
            print('step', i, 'loss', np.mean(losses[-report_steps:]))

Три эпохи я выбрал от балды. Размер батча я выбрал опытным путём: максимальный, при котором хватает памяти на GPU. learning_rate, равный 1e-5, я выставил, исходя из опыта: обычно при нём модель обучается не очень быстро, но качественно.

Код для генерации ответа обученной моделью весьма прост:

model.eval()

def answer(x, **kwargs):
    inputs = tokenizer(x, return_tensors='pt').to(model.device)
    with torch.no_grad():
        hypotheses = model.generate(**inputs, **kwargs)
    return tokenizer.decode(hypotheses[0], skip_special_tokens=True)

Посмотрим, насколько хорошо модель выучила свою тренировочную выборку. 

Какое животное раньше называли камелопардом?
answer: Жираф
model:  акула
---
Грамматическая категория глагола, выражающая отношение действия к действительности (в лингвистике)
answer: наклонение
model:  действие
---
О чём поется в песне Greenday – «Wake Me Up When September Ends» (Разбуди меня, когда сентябрь кончится)?
answer: О смерти его отца
model:  о лете
---
Соседка Земли по Солнечной системе
answer: Венера
model:  Африка
---
Отношение размеров на чертеже, карте и т. п. к действительным размерам на местности, предмете
answer: масштаб
model:  пропорциональность

Что ж, модель пытается, но часто "мажет". Возможно, стоит поучить её в течение ещё нескольких эпох. А теперь посмотрим, насколько хорошо модель справляется с ответами на вопросы, которые она не видела. Кажется, довольно пристойно.

Минерал, сульфид марганца
answer: алабандин
model:  сульфид
---
Где находится родина табака?
answer: Южная Америка
model:  Бразилия
---
Старинный русский головной убор с приподнятым вверх спереди и сзади околышем
answer: кораблик
model:  шнур
---
Почетное звание у тюрков и монголов, дававшееся за воинские подвиги
answer: батыр
model:  аким
---
Счетный прибор
answer: арифмометр
model:  таблица

Обученную модель можно сохранить на диск, а потом, при желании, даже выложить на хаб Huggingface. Я выложил её под названием cointegrated/rut5-base-quiz.

new_model_name = 'rut5-base-quiz'  # название папки для сохранения
model.save_pretrained(new_model_name)
tokenizer.save_pretrained(new_model_name)

Вместо заключения

Предобученные seq2seq модели – это здорово. Сейчас ими можно решать много разных задач NLP, а Google считает, что вообще чуть ли не все. Моя моделька показывает, что отчасти это уже верно и для русского языка.

Комментарии (4)


  1. averkij
    06.10.2021 18:43
    +1

    Для кроссвордов надо, наверное, датасет побогаче, чтобы разные формулировки усвоить. Или парафраз на вопросы, кстати, делать.

    image


    1. cointegrated Автор
      06.10.2021 19:41
      +1

      Это правда, можно и собрать датасет на порядок больше, и позаниматься аугментацией разного рода, и позаниматься генерацией пар (вопрос, ответ) из текстов. Улучшать много чего есть. Но конкретно в этом посте мне хотелось показать, что файнтюнить T5 - просто, и я пренебрёг возможными улучшениями)


  1. Gorodecki
    07.10.2021 14:47

    Интересно попробовать её обучить для восстановления пунктуации и больших букв. Насколько я помню там 512 токенов на входе можно подать?


    1. cointegrated Автор
      07.10.2021 14:50

      Нет, T5 использует relative position embeddings, поэтому тексты могут быть неограниченной длины. Правда, сложность там всё равно квадратичная от длины входа, поэтому на длинных текстах модель будет работать медленно.

      А так вообще вчера специально для этой задачи Silero выложили модель: https://habr.com/ru/post/581946/