Модель T5 – это нейросеть, которая уже обучена хорошо понимать и генерировать текст, и которую можно дообучить на собственную задачу, будь то перевод, суммаризация текстов, или генерация ответа чат-бота.
В этом посте я рассказываю про первую многозадачную модель T5 для русского языка и показываю, как её можно обучить на новой задаче.
Зачем нужна русскоязычная T5
T5 – нейросетевая модель для понимания и генерации текста. Изобрели её в работе от Google два года назад, и расшифровывается это название как text-to-text transfer transformer. Трансформер – это архитектура нейросетей, позволяющая извлекать из текста довольно объёмную информацию. Благодаря этой архитектуре модели типа BERT круто понимают тексты, а модели типа GPT весьма правдоподобно их генерируют. Text-to-text означает, что модель T5 принимает на вход тексты и "читает" их энкодером (как BERT), а потом "пишет" декодером новые тексты и отдаёт на выход. Слово transfer говорит о цели этой модели: она предобучалась восстанавливать пропущенные фрагменты текста, но при желании её можно дообучить на новые, более полезные задачи: перевод, перефразирование, суммаризация текстов, генерация диалоговых ответов, и т.п.
Гугл выпустил две версии T5: первая понимает только английский язык, зато дообучалась на 24 разных задачах, а вторая понимает 101 язык (включая русский), но умеет только заполнять пропуски в тексте. Поэтому я решил сначала ужать мультиязычную модель T5 (mT5) до двух языков: русского и английского, выкинув ненужные токены из её словаря и соответствующие строки из матриц входных и выходных эмбеддингов. Процесс подробно описан в этом посте, а в результате модель "похудела" с 2.2 до 0.9 ГБ, а значит, стала более удобной для применения. Эту уменьшенная модель я выложил под именем cointegrated/rut5-base. А дальше я пошёл по пути Google и дообучил свою русскую T5 решать одновременно несколько разных русских и англоязычных задач.
Многозадачная модель
Для каждой задачи, как и в гугловской статье, я использовал свой префикс, который надо писать, отделив символом |
, перед входным текстом. Сами задачи такие:
Перевод (префикс translate ru-en
или translate en-ru
). Обучал на датасете opus_wikipedia.
print(generate(
'translate ru-en | Каждый охотник желает знать, где сидит фазан.'))
# Each hunter wants to know, where he is.
print(generate(
'translate en-ru | Every hunter wants to know where the duck is.'))
# Все охотники хотят знать, где находится птица.
Перефразирование (paraphrase
). Обучал на датасете tapaco.
print(generate(
'paraphrase | Каждый охотник желает знать, где сидит фазан.',
encoder_no_repeat_ngram_size=1, repetition_penalty=0.5,
no_repeat_ngram_size=1
))
# В любом случае каждый рыбак мечтает познакомиться со своей фермой
Заполнение пропусков в тексте (fill
). Пропуски можно обозначать как ___
или _3_
, где 3
– примерное количество слов, которые надо вставить. Обучал на корпусе ru_web-public_2019_1M из Лейпцигской коллекции.
print(generate('fill | Каждый охотник _3_, где сидит фазан.'))
# смотрит на озеро
Восстановление текста из зашумлённого мешка слов (assemble
). Обучал также на Лейпцигском корпусе.
print(generate('assemble | охотник каждый знать фазан сидит'))
# Каждый охотник знает, что фазан сидит.
Упрощение текстов (simplify
). Обучал на данных дорожки RuSimpleSentEval.
print(generate(
'simplify | Местным продуктом-специалитетом с защищённым ' \
'географическим наименованием по происхождению считается ' \
'люнебургский степной барашек.', max_length=32))
# Местным продуктом-специалитетом считается люнебургский степной барашек.
Диалоговый ответ (reply
для ответов в стиле художественной литературы, обученных на корпусе Козиева, и answer
для ответов, обученных на otvet.mail.ru).
print(generate('reply | Помогите мне закадрить девушку'))
# Что я хочу?
print(generate('answer | Помогите мне закадрить девушку'))
# я хочу познакомиться с девушкой!!!!!!!!
Ответ на вопросы по тексту (comprehend
). Обучал на датасете SberQUAD.
print(generate(
'comprehend | На фоне земельного конфликта между владельцами овец и ' \
'ранчеро разворачивается история любви овцевода Моргана Лейна, ' \
'прибывшего в США из Австралии, и Марии Синглетон, владелицы ' \
'богатого скотоводческого ранчо. Вопрос: откуда приехал Морган?'))
# из Австралии
Задавание вопросов по тексту (ask
), обучал также на SberQUAD.
print(generate(
'ask | На фоне земельного конфликта между владельцами овец и ' \
'ранчеро разворачивается история любви овцевода Моргана Лейна, ' \
'прибывшего в США из Австралии, и Марии Синглетон, владелицы ' \
'богатого скотоводческого ранчо.', max_length=32))
# Что разворачивается на фоне земельного конфликта
# между владельцами овец и ранчеро?ро?
Генерация заголовка к новостной статье (headline
). Обучал на датасете Ильи Гусева по мотивам соревнования Телеграма.
print(generate(
'headline | На фоне земельного конфликта между владельцами овец ' \
'и ранчеро разворачивается история любви овцевода Моргана Лейна, ' \
'прибывшего в США из Австралии, и Марии Синглетон, владелицы ' \
'богатого скотоводческого ранчо.', max_length=32))
# На фоне земельного конфликта разворачивается история любви
# овцевода Моргана Лейна и Марии Синглетон
Как же работает эта магическая функция generate
? Стандартный python
и код из transformers, и дальше вы можете запускать любой из примеров выше. Попробовать это вы можете в демо-блокноте.
# !pip install transformers sentencepiece --quiet
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
model_name = "cointegrated/rut5-base-multitask"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
def generate(text, **kwargs):
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
hypotheses = model.generate(**inputs, num_beams=5, **kwargs)
return tokenizer.decode(hypotheses[0], skip_special_tokens=True)
Весь код, которым я дообучал модель, есть в этом блокноте; он не очень чистый, ибо я только экспериментировал, и содержит несколько задач, не вошедших в опубликованную версию модели. А сама модель выложена в каталог Huggingface: cointegrated/rut5-base-multitask.
Как обучать T5 на собственных данных
Текущая версия модели задумывалась как демонстрация возможностей предобученной seq2seq модели и как болванка для последующего дообучения. Поэтому я особо не возился ни с подбором гиперпараметров, ни с качеством датасетов. Следовательно, если более внимательно поработать с датасетом и гиперпараметрами и дообучить модель на какую-то одну задачу, работать она будет ещё лучше. Например, так я уже дообучал её на задаче перефразирования. А здесь я покажу, как обучить модель отвечать на вопросы из кроссвордов, используя фреймворки transformers и pytorch (при необходимости, модель можно потом сконвертировать в tensorflow или другой формат). Полный код примера есть всё в том же демо блокноте.
Инициализировать модель можно многозадачной версией. Тут, как и везде в библиотеке transformers
, model
– это сама нейросеть, а tokenizer
– это часть модели, ответственная за сопоставление текстов со словарём: разбиение текстов на числовые токены, и сбор текстов из токенов обратно. optimizer
– это объект, отвечающий за градиентный спуск; он нужен только на время обучения.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
raw_model = 'cointegrated/rut5-base-multitask'
model = T5ForConditionalGeneration.from_pretrained(raw_model).cuda();
tokenizer = T5Tokenizer.from_pretrained(raw_model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
Ниже – полный код обучения модели. Предполагается, что переменная pairs
– это список из пар, состоящих из вопроса и ответа. На GPU одна эпоха обучения на 75 тысячах примеров занимает около 15 минут.
from tqdm.auto import trange
import random
import numpy as np
batch_size = 16 # сколько примеров показываем модели за один шаг
report_steps = 200 # раз в сколько шагов печатаем результат
epochs = 3 # сколько раз мы покажем данные модели
model.train()
losses = []
for epoch in range(epochs):
print('EPOCH', epoch)
random.shuffle(pairs)
for i in trange(0, int(len(pairs) / batch_size)):
batch = pairs[i * batch_size: (i + 1) * batch_size]
# кодируем вопрос и ответ
x = tokenizer([p[0] for p in batch], return_tensors='pt', padding=True).to(model.device)
y = tokenizer([p[1] for p in batch], return_tensors='pt', padding=True).to(model.device)
# -100 - специальное значение, позволяющее не учитывать токены
y.input_ids[y.input_ids == 0] = -100
# вычисляем функцию потерь
loss = model(
input_ids=x.input_ids,
attention_mask=x.attention_mask,
labels=y.input_ids,
decoder_attention_mask=y.attention_mask,
return_dict=True
).loss
# делаем шаг градиентного спуска
loss.backward()
optimizer.step()
optimizer.zero_grad()
# печатаем скользящее среднее значение функции потерь
losses.append(loss.item())
if i % report_steps == 0:
print('step', i, 'loss', np.mean(losses[-report_steps:]))
Три эпохи я выбрал от балды. Размер батча я выбрал опытным путём: максимальный, при котором хватает памяти на GPU. learning_rate
, равный 1e-5
, я выставил, исходя из опыта: обычно при нём модель обучается не очень быстро, но качественно.
Код для генерации ответа обученной моделью весьма прост:
model.eval()
def answer(x, **kwargs):
inputs = tokenizer(x, return_tensors='pt').to(model.device)
with torch.no_grad():
hypotheses = model.generate(**inputs, **kwargs)
return tokenizer.decode(hypotheses[0], skip_special_tokens=True)
Посмотрим, насколько хорошо модель выучила свою тренировочную выборку.
Какое животное раньше называли камелопардом?
answer: Жираф
model: акула
---
Грамматическая категория глагола, выражающая отношение действия к действительности (в лингвистике)
answer: наклонение
model: действие
---
О чём поется в песне Greenday – «Wake Me Up When September Ends» (Разбуди меня, когда сентябрь кончится)?
answer: О смерти его отца
model: о лете
---
Соседка Земли по Солнечной системе
answer: Венера
model: Африка
---
Отношение размеров на чертеже, карте и т. п. к действительным размерам на местности, предмете
answer: масштаб
model: пропорциональность
Что ж, модель пытается, но часто "мажет". Возможно, стоит поучить её в течение ещё нескольких эпох. А теперь посмотрим, насколько хорошо модель справляется с ответами на вопросы, которые она не видела. Кажется, довольно пристойно.
Минерал, сульфид марганца
answer: алабандин
model: сульфид
---
Где находится родина табака?
answer: Южная Америка
model: Бразилия
---
Старинный русский головной убор с приподнятым вверх спереди и сзади околышем
answer: кораблик
model: шнур
---
Почетное звание у тюрков и монголов, дававшееся за воинские подвиги
answer: батыр
model: аким
---
Счетный прибор
answer: арифмометр
model: таблица
Обученную модель можно сохранить на диск, а потом, при желании, даже выложить на хаб Huggingface. Я выложил её под названием cointegrated/rut5-base-quiz.
new_model_name = 'rut5-base-quiz' # название папки для сохранения
model.save_pretrained(new_model_name)
tokenizer.save_pretrained(new_model_name)
Вместо заключения
Предобученные seq2seq модели – это здорово. Сейчас ими можно решать много разных задач NLP, а Google считает, что вообще чуть ли не все. Моя моделька показывает, что отчасти это уже верно и для русского языка.
Комментарии (4)
Gorodecki
07.10.2021 14:47Интересно попробовать её обучить для восстановления пунктуации и больших букв. Насколько я помню там 512 токенов на входе можно подать?
cointegrated Автор
07.10.2021 14:50Нет, T5 использует relative position embeddings, поэтому тексты могут быть неограниченной длины. Правда, сложность там всё равно квадратичная от длины входа, поэтому на длинных текстах модель будет работать медленно.
А так вообще вчера специально для этой задачи Silero выложили модель: https://habr.com/ru/post/581946/
averkij
Для кроссвордов надо, наверное, датасет побогаче, чтобы разные формулировки усвоить. Или парафраз на вопросы, кстати, делать.
cointegrated Автор
Это правда, можно и собрать датасет на порядок больше, и позаниматься аугментацией разного рода, и позаниматься генерацией пар (вопрос, ответ) из текстов. Улучшать много чего есть. Но конкретно в этом посте мне хотелось показать, что файнтюнить T5 - просто, и я пренебрёг возможными улучшениями)