Привет, меня зовут Андрей Казначеев, я NLP engineer в компании MTS AI. В этой статье я расскажу, как создал лонгформер для русского языка. Все началось с того, что мне подкинули задачу по классификации длинных диалогов. Тексты длинные, а большинство популярных моделей имеют строгое ограничение по длине входной последовательности. Хотелось сделать решение умнее, чем просто побить текст на куски, однако ничего готового для русского языка не нашел. Тогда я задумался, а так ли сложно сделать свою собственную версию лонгформера под русский язык? Оказалось, совсем не сложно.

Архитектура лонгформера

Лонгформер - модель, основанная на архитектуре Transformer, адаптированной для обработки длинных текстов. Ни для кого не секрет, что в силу O(n2) сложности вычисления матрицы аттеншна, популярные transformer-based модели (BERT, RoBERTa, и т.д.) имеют короткий контекст (чаще всего 512 токенов). Авторы оригинального лонгформера использовали несколько трюков, которые снизили сложность вычисления до O(n):

  • Sliding window
    В то время как обычный аттеншн перемножает эмбеддинги по принципу "все на все" (Рис. 1a), в лонгформере используется скользящее окно фиксированного размера w (обычно это 512), поэтому сложность вычисления становится O(n x w) (Рис. 1b).

  • Dilated sliding window
    Можно в аттеншн паттерн добавить промежутки (gap) по аналогии с dilated-CNN. Если добавить gap размера d, для для слоя l receptive field будет размера l x d x w (Рис. 1с). В multi-head attention для разных голов задают разное значение d (некоторым 0). Это позволяет им учить разную информацию: какие-то концентрируются на локальном контексте, какие-то на дальнем.

  • Global attention
    Оказалось, что dilated и sliding аттеншн недостаточно гибок для различных задач: MLM концентрируется на локальном контексте для предсказания маскированного токена, для текстовой классификации вся информация о тексте агрегируется в один [CLS] токен, по которому происходит предсказание, для QA ответ на вопрос ищется в конкатенированном тексте. Поэтому было решено для некоторых заранее определенных токенов добавить глобальный аттеншн - они видят все токены последовательности. Для текстовой классификации это [CLS] токен, для MLM - [MASK] - токен, для QA - токены вопроса. В силу того, что этих токенов константное количество, общая сложность вычислений остается линейной.

    Рис. 1. Аттеншн паттерны
    Рис. 1. Аттеншн паттерны

Как я делал лонгформер для русского

Не имея желания возиться с обучением большой модели с нуля, я попробовал проявить творческий подход и переиспользовать уже обученные модели для русского языка, адаптировав их под длинные тексты. В официальном репозитории лонгформера есть ноутбук с примером конвертации роберты в лонгформер (правда он сделан под transformers==3.0.2, поэтому пришлось повозиться, чтобы адаптировать под 4.2.0+ версию).

Общая идея:

  1. Инициализировать новый класс лонгформера с такими же параметрами, как у роберты (там, где они общие).

    roberta_model = RobertaForMaskedLM.from_pretrained('ai-forever/ruRoberta-large')
    roberta_config = roberta_model.config
    longformer_config = LongformerConfig.from_pretrained(
            "allenai/longformer-large-4096",
            attention_window=512,
            hidden_size=roberta_config.hidden_size,
            num_hidden_layers=roberta_config.num_hidden_layers,
            num_attention_heads=roberta_config.num_attention_heads,
            intermediate_size=roberta_config.intermediate_size,
            hidden_act=roberta_config.hidden_act,
            hidden_dropout_prob=roberta_config.hidden_dropout_prob,
            attention_probs_dropout_prob=roberta_config.attention_probs_dropout_prob,
            max_position_embeddings=4098,
            type_vocab_size=roberta_config.type_vocab_size,
            initializer_range=roberta_config.initializer_range,
            layer_norm_eps=roberta_config.layer_norm_eps,
            gradient_checkpointing=roberta_config.gradient_checkpointing,
            pad_token_id=roberta_config.pad_token_id
        )
    
    longformer_model = LongformerForMaskedLM(longformer_config)
  2. Расширить позишн энкодинг роберты просто перекопировав обученный position_embedding на позиции выше 512 друг за другом - авторы оригинальной статьи утверждают, что это наиболее эффективный способ.

    config = longformer_model.config
    tokenizer = RobertaTokenizerFast.from_pretrained("ai-forever/ruRoberta-large", model_max_length=4096)
    tokenizer.model_max_length = config.max_position_embeddings-2
    tokenizer.init_kwargs['model_max_length'] = config.max_position_embeddings-2
    
    current_max_pos, embed_size = roberta_model.roberta.embeddings.position_embeddings.weight[:512,:].shape
    
    new_pos_embed = longformer_model.longformer.embeddings.position_embeddings.weight
    new_pos_embed.requires_grad=False
    
    k = 2
    step = current_max_pos
    
    while k < config.max_position_embeddings-1:
        new_pos_embed[k:(k + step)] = roberta_model.roberta.embeddings.position_embeddings.weight[:512]
        k += step
        
    roberta_model.roberta.embeddings.position_embeddings.weight.data = new_pos_embed
    roberta_model.roberta.embeddings.position_ids.data = torch.tensor([i for i in range(config.max_position_embeddings)]).reshape(1, config.max_position_embeddings)
  3. Поместить обученные веса роберты в соответствующие слои в лонгформере.

    longformer_model.longformer.load_state_dict(roberta_model.roberta.state_dict(), strict=False)
    longformer_model.lm_head.load_state_dict(roberta_model.lm_head.state_dict(), strict=False)
  4. Переиспользовать веса attention (query, key, value) для инициализации глобального аттеншна.

    import copy
    for i, (roberta_layer, longformer_layer) in enumerate(zip(roberta_model.roberta.encoder.layer, longformer_model.longformer.encoder.layer)):
        longformer_layer.attention.self.query_global = copy.deepcopy(roberta_layer.attention.self.query)
        longformer_layer.attention.self.key_global = copy.deepcopy(roberta_layer.attention.self.key)
        longformer_layer.attention.self.value_global = copy.deepcopy(roberta_layer.attention.self.value)
  5. Дообучить получившуюся модель на Masked Language Modeling задаче.

Все просто и никакой особой магии. Однако я решил пойти дальше.

Дело в том, что лонгформер создается из роберты и можно сделать версии base и large. Однако эти модели будут не очень применимы на практике в силу своего внушительного размера (large-версия занимала порядка 38Gb видеопамяти обучаясь на векторах в 4096 токенов длиной, обучение base со скрипом влезало на 16Gb карточку в fp16). И я подумал: "А почему бы не сделать тайни версию, ведь существует такой замечательный rubert-tiny2". Оказалось, что ее, с парой дополнительных строчек кода, также несложно адаптировать под длинный контекст. Но и на этом я не остановился. Планомерно увеличивая длину контекста, я дошел до значения в 16384 токенов, и эту модель я считаю самой полезной из трех выложенных русскоязычных лонгформеров.

Ноутбук с созданием ru-longformer-tiny-16384 тут.

MLM файнтюнинг

Для файнтюнинга лонгформера нам потребуется датасет с длинными текстами, чтобы модель научилась извлекать информацию из расширенного контекста. Я решил велосипед не изобретать и собрать свой датасет из готовых дампов Википедии, новостей, выгрузки постов с Habr и корпуса русских книг, отсекая тексты короче 10000 токенов. Получился файлик размером 2.5Gb и длиной примерно в 200M токенов. Ну и еще один файлик поменьше для валидации.

Как подсказал здравый смысл и подтвердили эксперименты, во время обучения логично заморозить все веса модели, кроме глобальных аттеншнов и позишн энкодинга, так как веса предобученных моделей и так оптимальны (кроме roberta-base, её как раз дотюнивал со всеми размороженными весами и значительно улучшил метрики на бенчмарках). Финальный сетап параметров обучения выглядел так:

training_args = TrainingArguments(
    output_dir="./longformer_mlm",
    overwrite_output_dir=True,
    do_eval=True,
    do_predict=True,
    fp16=False,
    gradient_checkpointing = True,
    num_train_epochs=2,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=6*1e-4,
    optim="adafactor",
    warmup_ratio=0.2,
    save_steps=100,
    save_total_limit=2,
    evaluation_strategy="steps",
    prediction_loss_only=True,
    logging_steps=10
)

Вообще это обычное дообучение модели на MLM задаче, но есть пара нюансов:

  1. Кастомный DataCollator, в который добавляется глобал аттеншн, который объявляется для всех токенов [MASK].

    class DataCollatorWithGlobalAttention(DataCollatorForLanguageModeling):
        def __call__(self, examples):
            batch = super().__call__(examples)
            global_attention_mask = [
                [1 if token_id == tokenizer.mask_token_id else 0 for token_id in input_ids]
                for input_ids in batch["input_ids"]
            ]
    
            global_attention_mask = [
                mask if any(mask) else [1] + [0]*(len(mask)-1)
                for mask in global_attention_mask
            ]
    
            batch["global_attention_mask"] = torch.tensor(global_attention_mask)
            return batch
  2. Во время экспериментов на определенном этапе обучения лосс начинал расходиться, поэтому я пришел к выводу, что необходимо тюнить модель на смеси длинных и коротких текстов. Поэтому кастомный класс Dataset выглядит так:

    class LargeTextDataset(Dataset):
        # for 16k tokens set CONTEXT_LEN = 80k symbols 
        def __init__(self, text_file_path, tokenizer, max_sequence_length):
            self.max_sequence_length = max_sequence_length
            with open(text_file_path, "r") as file:
                self.large_text = file.read()
            self.tokenizer = tokenizer
    
        def __len__(self):
            return len(self.large_text) // self.max_sequence_length
    
        def __getitem__(self, idx):
            if random.randint(0,3) > 0: #33.3% short texts
                max_length = self.max_sequence_length
            else:
                max_length = random.randint(312, self.max_sequence_length) #312 min length
    
            start = random.randint(0, len(self.large_text) - self.max_sequence_length)
            end = start + max_length
            start = random.randint(0, len(self.large_text) - self.max_sequence_length)
            end = start + self.max_sequence_length
            chunk = self.large_text[start:end]
            tokens = self.tokenizer(chunk, truncation=True, padding='max_length', max_length=CONTEXT_LEN)
            return tokens

Бенчмарки

Так как основной идеей было сделать модель, адаптированную под длинные тексты, которая также не испортится при работе с короткими, то логично, что оценивать перфоманс модели нужно как на длиннотекстовых бенчмарках, так и на обычных.

С обычными проблем не возникло, лонгформер, сделанный из тайниберта имеет смысл сравнивать с результатами самого тайниберта, чтобы понять, что модель не испортилась. Я прогнал модель на куче задач из этого поста и привел ниже сравнение итоговых метрик:

TASK

ru-longformer-tiny-16384

rubert-tiny2

STSB

0.760

0.757

Paraphraser

0.647

0.648

NLI

0.416

0.415

Sentiment Analisys

0.739

0.736

Toxicity identification

0.938

0.936

Inappropriateness

0.749

0.748

Intents

0.755

0.756

IntentsX

0.626

0.636

FactRu

0.396

0.402

Rudr

0.387

0.390

Можно заключить, что на коротких текстах модель мы не испортили, изменения минимальны.

Что касается длинных текстов, я не нашел хороших готовых бенчмарков. Поэтому решил взять lenta-ru-news датасет, который содержит новостные статьи, размеченные по темам, и убрать все тексты короче 512 токенов. На этих данных несколько раз (с разным random_state в train_test_split) обучил ru-longformer-tiny-16384 в ForSequenceClassification режиме с глобальным аттеншном для [CLS] токена и его исходную модель - rubert-tiny2, разбив текст на чанки, получая эмбеддинги чанков, усредняя их и скармливая средний вектор в MLP. Метрики по нескольким итерациям обучения усреднил и посчитал стандартное отклонение для них:

ru-longformer-tiny-16384

rubert-tiny2

F1 macro (avg over multiple runs)

0.771

0.748

std

0.009

0.017

Из вышеперечисленного можно сделать вывод, что лонгформер, благодаря длинному контексту, улучшил перфоманс на длинных текстах и все еще так же эффективен на различных задачах с короткими текстами, как и его исходная модель.

Заключение

Я выложил три версии русскоязычного лонгформера:

Безусловно, large-версия, имея сильную исходную модель, по перфомансу опередит своих "коллег", однако она слишком большая, и скорость инференса также оставляет желать лучшего. С другой стороны, крошечный тайни лонгформер с тремя скрытыми слоями не отстает от своей исходной модели и, на мой взгляд, является лидером по сочетанию скорость-размер-качество.

P.S. большое спасибо за ценные советы и рекомендации автору телеграм канала AbstractDL

Комментарии (7)


  1. Andriljo
    14.09.2023 15:27

    Отличная работа, коллеги


  1. DarthPadla
    14.09.2023 15:27

    Твердый ридинг


  1. s_f1
    14.09.2023 15:27

    для русского языка

    • матрица аттеншна
    • эмбеддинги
    • позишн энкодинг 0_0

    Просто напомню:
    https://habr.com/ru/news/717646/


    1. MarkWatney
      14.09.2023 15:27
      +2

      Слово аттеншн используется не в значении "внимание". Тут это действительно отдельные термины, у которых свое значение.

      Вообще вопрос как правильно переводить термины, тут и аттеншн выглядит криво, и матрицей внимания это не назовёшь. И русских терминов и не появится, учитывая, что вся информация на английском.


      1. s_f1
        14.09.2023 15:27

        1. MarkWatney
          14.09.2023 15:27

          Ну может быть. Однако, когда attention в данном контексте имеют ввиду именно тот самый механизм внутри Attention-based слоя. Т.е. ссылаются на термин, который широко используется в актуальных статьях. А термин "механизм внимания" не так распространён, и, например, у меня не сразу бы возникла эта ассоциация. Даже в той статье все эти термины дублируются на английском по этой причине.

          К тому же global attention - это достаточно свежий термин. В итоге придётся придумывать свой перевод для него? Глобальное внимание будет совсем непонятно звучать.


  1. MarkWatney
    14.09.2023 15:27

    del