Несколько дней к ряду я занимался реставрацией легаси модели ai-forever/rugpt3xl, это классическая языковая модель от SberDevices на 1.3B параметров, крошка по современным меркам, на которой сберовцы обкатывали свои научные наработки аж в далёком 2021м году. Подробнее о ней можно почитать в статье “A family of pretrained transformer language models for Russian” на Google Scholar.

Да, она foundation, то есть умеет только продолжать текст, не может выполнять инструкции или работать в режиме чата. Но обучена она на корпусе русского языка и этот самый русский генерит очень бодро. У неё есть две примечательные особенности: её обучали с нуля, архитектура представляет собой глубокую модификацию GPT-2.

Превьюшку сгенерировал через ChatGPT
Превьюшку сгенерировал через ChatGPT

Предыстория

Меня давно беспокоило, что эта модель просто пылится на полке истории, она лежит на HuggingFace, формально доступна, но запустить её почти нереально, так как нужен стек технологий пятилетней давности. По сути модель заживо замурована в своём чекпоинте, а ведь это одна из первых серьёзных русскоязычных моделей, обученных с нуля. Мне давно хотелось сдуть с неё пыль, отреставрировать и выставить в импровизированном музее, чтобы любой желающий мог её потрогать руками и даже при желании обучить на своём датасете.

До этого момента пытался трижды, но все предыдущие разы не хватало знаний и опыта работы с внутренностями моделей, да и кодовые агенты тогда были не чета нынешним.

Проблема была в том, что оригинальный чекпоинт суть сырой mp_rank_00_model_states.pt, чтобы его загрузить, нужен полный стек Megatron-LM, DeepSpeed, apex, и всё это до кучи завязано на древние версии PyTorch 1.7 и transformers 3.5, короче говоря, запустить её в 2026 году “как есть” задача нетривиальная.

Но решил попробовать ещё один раз, благо думаю где мне знаний будет мало кодовый агент поможет, поднял старые заметки, склонировал всё что может понадобиться для работы и написал подробные спецификации, предварительную работу делал Cursor, а я уже тестировал и пинал фантазии ИИ в нужную сторону.

Конвертация в HuggingFace формат

Задача агента была простая, для начала он должен был изучить исходники ai-forever/ru-gpts, исходники современной библиотеки transformers, веса и конфиги модели deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct в качестве примера того, как на HuggingFace можно закидывать кастомные классы моделей и конфигов, чтобы всё работало корректно с trust_remote_code=True.

Моя основная цель была в том, чтобы получить полностью рабочую модель, идентичную оригинальной на столько на сколько это возможно, которую можно было бы запустить через transformers, и при этом чтобы её можно было обучать, хоть через LoRA, хоть через SFT-тренеры из поставки transformers.

Спустя несколько часов пинаний и уточнений удалось понять, как устроены веса модели ruGPT3XL и как их можно конвертировать в формат, близкий к тому что описано в классе GPT2Model.

Вот как выглядит оригинальная структура модели из чекпоинта Megatron-LM:

word_embeddings.weight                 [50264, 2048]
position_embeddings.weight             [2048, 2048]
transformer.layers.{0..23}
  +-- input_layernorm.weight           [2048]
  +-- input_layernorm.bias             [2048]
  +-- attention
  |     +-- query_key_value.weight     [6144, 2048]  <- сшитые QKV
  |     +-- query_key_value.bias       [6144]
  |     +-- dense.weight               [2048, 2048]
  |     +-- dense.bias                 [2048]
  +-- post_attention_layernorm.weight  [2048]
  +-- post_attention_layernorm.bias    [2048]
  +-- mlp
        +-- dense_h_to_4h.weight       [8192, 2048]
        +-- dense_h_to_4h.bias         [8192]
        +-- dense_4h_to_h.weight       [2048, 8192]
        +-- dense_4h_to_h.bias         [2048]
transformer.final_layernorm.weight     [2048]
transformer.final_layernorm.bias       [2048]

А вот так выглядит структура GPT2 моделей:

model.embed_tokens.weight              [50264, 2048]
model.embed_positions.weight           [2048, 2048]
model.layers.{0..23}
  +-- input_layernorm.weight           [2048]
  +-- input_layernorm.bias             [2048]
  +-- self_attn
  |     +-- q_proj.weight              [2048, 2048]  <- отдельно Q
  |     +-- q_proj.bias                [2048]
  |     +-- k_proj.weight              [2048, 2048]  <- отдельно K
  |     +-- k_proj.bias                [2048]
  |     +-- v_proj.weight              [2048, 2048]  <- отдельно V
  |     +-- v_proj.bias                [2048]
  |     +-- o_proj.weight              [2048, 2048]
  |     +-- o_proj.bias                [2048]
  +-- post_attention_layernorm.weight  [2048]
  +-- post_attention_layernorm.bias    [2048]
  +-- mlp
        +-- up_proj.weight             [8192, 2048]
        +-- up_proj.bias               [8192]
        +-- down_proj.weight           [2048, 8192]
        +-- down_proj.bias             [2048]
model.norm.weight                      [2048]
model.norm.bias                        [2048]
lm_head.weight                         [50264, 2048]

Главное отличие в том, что сшитая QKV проекция размерностью [6144, 2048] разбивается на три отдельных линейных слоя Q, K, V размерностью [2048, 2048] каждый, плюс явный lm_head, который нужно будет скопировать из весов эмбеддингов.

Вот полная таблица маппинга:

Megatron-LM (оригинал)

HuggingFace (конвертированная)

word_embeddings.weight

model.embed_tokens.weight

position_embeddings.weight

model.embed_positions.weight

transformer.layers.{i}.input_layernorm.*

model.layers.{i}.input_layernorm.*

transformer.layers.{i}.attention.query_key_value.weight

model.layers.{i}.self_attn.{q,k,v}_proj.weight

transformer.layers.{i}.attention.query_key_value.bias

model.layers.{i}.self_attn.{q,k,v}_proj.bias

transformer.layers.{i}.attention.dense.*

model.layers.{i}.self_attn.o_proj.*

transformer.layers.{i}.post_attention_layernorm.*

model.layers.{i}.post_attention_layernorm.*

transformer.layers.{i}.mlp.dense_h_to_4h.*

model.layers.{i}.mlp.up_proj.*

transformer.layers.{i}.mlp.dense_4h_to_h.*

model.layers.{i}.mlp.down_proj.*

transformer.final_layernorm.*

model.norm.*

-

lm_head.weight (копия model.embed_tokens.weight)

Зная маппинг выполнить конвертацию не составлаяет уже большого труда, получился скрипт convert.py, который берёт оригинальный чекпоинт и превращает его в HuggingFace-модель с safetensors весами.

Но конвертация весов это только полдела, далее потребовалось написать кастомные классы модели и конфигурации, которые transformers смог бы загрузить.

В оригинальном репозитории ru-gpts есть класс RuGPT3XL это толстая обёртка вокруг мегатроновского GPT3Model, завёрнутого ещё и в FP16_Module. Чтобы просто загрузить модель, from_pretrained первым делом поднимает torch.distributed процесс-группу, инициализирует mpu (model parallelism utils), скачивает веса и deepspeed-конфиг, и только потом собирает модель через setup_model. KV-кеша нет, без CUDA и torch.distributed ничего не работает, красивый артефакт своей эпохи, но для современного пользователя всё это просто мёртвый груз.

Новые классы писались с нуля по примерам transformers, иерархия такая:

  • RuGPT3XLConfig - наследник PretrainedConfig из трансформерс со всеми параметрами модели (vocab_size, hidden_size, num_layers и так далее)

  • RuGPT3XLAttention - multi-head attention с отдельными Q/K/V проекциями и поддержкой DynamicCache из трансформерс

  • RuGPT3XMLP - блок персептрона (up_proj -> GELU -> down_proj)

  • RuGPT3XLDecoderLayer - слой декодера (LayerNorm -> Attention -> LayerNorm -> MLP)

  • RuGPT3XLModel - базовый трансформер модели (эмбеддинги + 24 слоя + финальный LayerNorm)

  • RuGPT3XLForCausalLM - обёртка конструктор с lm_head

Принципиальные отличия от оригинала:

  • Никаких зависимостей от Megatron-LM, DeepSpeed, apex, mpu, torch.distributed и прочей древности

  • Типовой forward() с сигнатурой, которую ожидает любой тренер, хоть LoRA/Pert, хоть SFTTrainer

  • Результаты возвращаются в CausalLMOutputWithPast из трансформерс вместо кастомного ModelOutput

  • Полноценный KV-кеш через трансформеровский DynamicCache для быстрой генерации

  • Поддержка gradient checkpointing на этапе обучения для оптимального управления памятью во время обучения

  • Ну и самое главное, моделька работает на CPU, GPU, через device_map="auto" без обязательного CUDA

По сути математика внутри осталась идентичной - тот же Pre-LayerNorm, та же GELU, те же размерности, тот же токенизатор, но вот обвязка вокруг модельки стала современной.

Веса залил на HuggingFace: evilfreelancer/ruGPT3XL.

Тестирование

Первым делом запустил пробную генерацию через generate.py скрипт, который грузит модель и скармливает ей несколько промптов:

python generate.py --model_path ./ruGPT3XL --dtype float16

На выходе три дефолтных промпта, и моделька вполне связный текст пишет - “Москва - столица”, “Искусственный интеллект - это”, “В далёком космосе”. Видно, что русский язык она знает хорошо, фразы строит грамотно, контекст держит, для foundation-модельки на 1.3B параметров прямиком из 2021 года это более чем достойно.

Дальше захотелось потыкать модельку посерьёзнее. Написал manual_test.py - скрипт на 24 промпта разного характера, от свободного продолжения текста (“Однажды в студёную зимнюю пору”, “Для приготовления борща нужно”) до простых вопрос-ответов через chat template (“Вопрос: Какая столица России? Ответ:”), гонял на RTX 4090 48Гб во float16, тут результаты.

Пара интересных наблюдений по результатам:

  • Средняя скорость генерации - 66.7 t/s на RTX 4090 (float16, batch_size=1)

  • На коротких промптах моделька часто останавливается рано (15-30 токенов), на длинных уверенно генерит до max_new_tokens

  • Вопрос-ответ через chat template работает, но ответы иногда… творческие. На “Кто написал Войну и мир?” модель уверенно ответила “Пушкин” (ну, foundation-модель, чтож поделать)

  • Хорошо справляется с продолжением фактического текста - рецепт борща, физика, история

  • На “Сколько планет в Солнечной системе?” ответила “Пять”, а потом ушла в рассуждения про цивилизации Млечного Пути

Ну и напоследок прогнал модельку через MERA, это такой открытый бенчмарк для оценки русскоязычных моделей, 23 задачи разного типа (логика, математика, знания о мире, код, этика, рассуждения). Загрузил результаты на mera.a-ai.ru, общий скор получился 0.198 (пока ещё на модерации).

Результаты MERA
Результаты MERA

Для понимания контекста - это base-модель 2021 года на 1.3B параметров, без instruction tuning, без RLHF, без всего того, что сейчас считается обязательным. Она умеет только продолжать текст. И тем не менее на некоторых задачах показывает вменяемые результаты:

  • PARus (здравый смысл) - 0.500

  • ruHateSpeech - 0.558

  • BPS (код/математика) - 0.528

  • RWSD (рассуждения) - 0.488

  • ruTiE (диалоговый контекст) - 0.502

  • ruMMLU - 0.252 (при random baseline ~0.25 для 4 вариантов ответа, модель на грани случайного угадывания, что для base-модели ожидаемо)

А вот математика (SimpleAr - 0.012, ruModAr - 0.001) и генерация кода (ruHumanEval, ruCodeEval - 0/0/0) ожидаемо в ноль, ведь модель и не учили решать подобные задачи.

Конвертация в GGUF

Но я не остановился на достигнутом. Захотелось развить успех и конвертировать модельку в GGUF формат для llama.cpp. Благодаря тому, что я потратил время на конвертацию в формат transformers, реализовать поддержку в llama.cpp удалось очень быстро. Там всего-навсего нужно было прокачать Python-конвертер convert_hf_to_gguf.py - модель по структуре похожа на GPT-2, поэтому патч объединяет отдельные q_proj, k_proj, v_proj обратно в единый QKV тензор, как того ожидает LLM_ARCH_GPT2.

Отправил PR #21011 в llama.cpp с наработками (кстати второй мой PR принятый в этот проект), прошёл ревью, были мелкие правки по хешу токенизатора, в итоге PR смерджили, а я смог конвертировать и квантовать модельки.

GGUF залил на HuggingFace: evilfreelancer/ruGPT3XL-GGUF.

Запуск инференса прямо в терминале:

llama-cli -hf evilfreelancer/ruGPT3XL-GGUF:Q4_K_M

Или поднять локальный OpenAI-совместимый сервер с веб-интерфейсом:

llama-server -hf evilfreelancer/ruGPT3XL-GGUF:Q4_K_M

Ну и под конец залил модельку на Ollama, там теперь собрано всё семейство ruGPT-3 в одном месте, все четыре размера - small (125M), medium (356M), large (760M) и xl (1.3B):

ollama run evilfreelancer/rugpt3:small
ollama run evilfreelancer/rugpt3:medium
ollama run evilfreelancer/rugpt3:large
ollama run evilfreelancer/rugpt3:xl

Такой вот музей классических русских языковых моделей в современной упаковке.

Ссылки

Послесловие

Печально, что наш бигтех вместо развития своих решений типа того же ruGPT или YaLM решили пойти по пути копирования и “инициализации из весов” других решений. По моему скромному мнению линейка ruGPT-3 была той самой точкой бифуркации, после которой мы свернули куда-то не туда, но это уже совсем другая история.

Ну я в свою очередь благодарю вас за прочтение, надеюсь мои наработки пригодятся, буду рад фидбеку в комментариях или в телеграме.

Комментарии (13)


  1. ComputerPers
    28.03.2026 02:23

    Отличная работа, статья очень зашла


    1. efreelancer Автор
      28.03.2026 02:23

      Благодарю за комментарий, мне просто очень захотелось вновь порабоать с этой моделькой и пару дней спустя поулчился этот пост)


  1. UtrobinMV
    28.03.2026 02:23

    Мое личное предположение, что архитектура DeepSeek просто более удачная, а это тоже немаловажно. Поэтому и свернули в ту сторону. Тем более что не нужно было бы писать и адаптировать инференс.


    1. efreelancer Автор
      28.03.2026 02:23

      В целом согласен, DeepSeek и правда очень удачная архитектура, она очень похожа на архитектуру Llama, но с нюансами MoE и поддержкой мультихед аттеншен, просто проявилась она на четыре года позже отечественных экспериментов над архитектурой GPT2 в моделях ruGPT, сделаю смелое предположение, но мне кажется, что если бы наш бигтех не забросил попытки делать что-то своё, то вероятно архитектура ruGPT конкурировала бы сейчас с другими решениями.

      В случае с ruGPT инференс тоже не нужно было сильно адаптировать, эта архитектура по сути глубокая модификация GPT2 и работает соответствующим образом, на её основе в 21м году стоило бы делать какие-то специализированные модели, инстракт, чат, кодовые модели и так далее, но к сожалению в паблик выпускались только foundation.


  1. mrbp_old
    28.03.2026 02:23

    Тот случай, когда задачу начинали решать математики, инженеры, программисты, одаренные и гениальные творцы, а потом пришел "эффективный менеджер" и всех увёл в другую сторону. И в результате то, что должно было стать с прекрасной каретой - превратилось в тыкву. А сама идея и подход осталось пылиться на полках как история.


    1. efreelancer Автор
      28.03.2026 02:23

      Как там было пять лет назад сказать уже сложно, но имеем что имеем, добротная архитектура и занятные модельки на её основе заброшены, а ведь люди которые их делали (в посте будет ссылочка на научную работу) потратили на них много времени и сил, мне просто не хотелось чтобы их труд пропадал даром, авось кто заинтересуется в ruGPT3 XL когда-нибудь.


  1. ash_lm
    28.03.2026 02:23

    Скрытый текст

    Мне кажется это какой-то рандомайзер фраз.


    1. Anton_Timofeev
      28.03.2026 02:23

      В статье же написано, что это классический GPT - т.е. продолжатель фраз, и на диалог её не дообучали.


      1. ash_lm
        28.03.2026 02:23

        Я читал в начале, что "умеет только продолжать текст", не совсем было понятно как она работает поэтому и поставил. Но на мой взгляд это реально рандомайзер и что там от нейросети не совсем понятно.


        1. Shannon
          28.03.2026 02:23

          Вы пытаетесь работать с base моделью как с instruct моделью.

          Base (или pretrain) - это 1 шаг обучения LLM из 3. Смысл pretrain в том, чтобы модель умела составлять буквы в слова, слова в предложения, предложения были орфографически верные, логически правильные, набирала базу знаний и так далее. Такая нейросеть умеет продолжать текст, вы пишите ей “cat = кошка, dog = собака, duck =” и она продолжает “утка”.

          Но если вы пытаетесь с ней общаться в чате, то ей будет послан шаблон чата, который может выглядеть очень не типично, например, вот так:

          "<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n{answer}<|im_end|>"

          Поэтому базовая модель в таком сценарии начинает генерировать хаотичные ответы. Чтобы base научилась нормально вести диалог, нужно обучить её шаблону.

          Для примера, обучим эту ruGPT3 XL base до уровня instruct (до очень примитивного, так как датасет должен быть куда разнообразнее). Генерируем 100 пар вопрос-ответ такого вида:

          Теперь из этих пар нужно создать датасет согласно шаблону-чата модели, в данном случае это:

          Вопрос: ...
          Ответ: ...
          

          Обучаем нейросеть как finetune целиком или как LoRA адаптер. Для примера хватит 3 эпох на 100 примерах:

          Через пару минут “ruGPT3XL-instruct” готова. Теперь можно задать какой-нибудь вопрос, которого не было в дополнительном датасете:

          Кто такой Шекспир?
          Кто такой Шекспир?
          Ты знаешь что-то про Fallout?
          Ты знаешь что-то про Fallout?

          Мы не учили модель отвечать, кто такой Шекспир, и не учили её ничему про Fallout, это уже было в модели. Мы только научили её обрабатывать шаблон чата, и, побочно, структуре ответа как из энциклопедии. Аналогично её можно научить агентным задачам и так далее.


          1. efreelancer Автор
            28.03.2026 02:23

            Только что запушил на HF исправление кода модели связанного со Sparse Attention (разреженное внимание), веса остались прежние, попробуйте переобучить модельку.


            1. Shannon
              28.03.2026 02:23

              Тут, конечно, это уже надо на нормальном датасете обучать, чтобы заметить разницу. Но архитектура старая, поддержки triton нет, обучение будет в разы дольше, чем дообучить даже, например, Qwen3.5 4B.

              Но в целом можно попробовать и так. Сходу при запуске инференса со Sparse Attn ошибка:

              cuda\TensorCompare.cu:109: block: [0,0,0], thread: [0,0,0] Assertion `input[0] != 0` failed.
                  next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                                ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
              torch.AcceleratorError: CUDA error: device-side assert triggered
              

              Если посмотреть вывод модели, то вероятности обнулены, логиты в бесконечности, а маска внимания все маркирует как False. Можно попробовать починить:

              modeling_rugpt3xl.py
              """PyTorch RuGPT-3 XL model.
              
              GPT-3-style decoder-only transformer (1.3B) trained on Russian text.
              Architecture: absolute position embeddings, pre-norm layers, GELU activation,
              tied LM head.
              """
              
              import math
              from typing import List, Optional, Tuple, Union
              
              import torch
              import torch.nn as nn
              import torch.nn.functional as F
              import torch.utils.checkpoint
              
              from transformers.activations import ACT2FN
              from transformers.cache_utils import Cache, DynamicCache
              from transformers.generation import GenerationMixin
              from transformers.modeling_outputs import (
                  BaseModelOutputWithPast,
                  CausalLMOutputWithPast,
              )
              from transformers.modeling_utils import PreTrainedModel
              from transformers.utils import logging
              
              from .configuration_rugpt3xl import RuGPT3XLConfig
              
              logger = logging.get_logger(__name__)
              
              
              def _make_sparse_layout(
                  num_heads: int,
                  num_blocks: int,
                  num_local_blocks: int,
                  num_global_blocks: int,
                  num_different_global_patterns: int,
                  device: torch.device,
              ) -> torch.Tensor:
                  """Build FixedSparsity boolean layout on *device*.
              
                  Returns [num_heads, num_blocks, num_blocks] bool tensor.
                  """
                  layout = torch.zeros(
                      num_heads, num_blocks, num_blocks, dtype=torch.bool, device=device,
                  )
              
                  for win in range(0, num_blocks, num_local_blocks):
                      end = min(win + num_local_blocks, num_blocks)
                      sz = end - win
                      layout[:, win:end, win:end] = torch.tril(
                          torch.ones(sz, sz, dtype=torch.bool, device=device)
                      )
              
                  for h in range(num_heads):
                      first = num_local_blocks - (
                          1 + h % num_different_global_patterns
                      ) * num_global_blocks
                      reg_end = num_blocks - (num_blocks % num_local_blocks)
                      for gi in range(first, reg_end, num_local_blocks):
                          layout[h, gi:, gi : gi + num_global_blocks] = True
                      if reg_end < num_blocks:
                          s = min(reg_end + first, num_blocks - num_global_blocks)
                          layout[h, s:, s : s + num_global_blocks] = True
              
                  return layout
              
              
              class RuGPT3XLAttention(nn.Module):
                  def __init__(self, config: RuGPT3XLConfig, layer_idx: int):
                      super().__init__()
                      self.config = config
                      self.layer_idx = layer_idx
                      self.hidden_size = config.hidden_size
                      self.num_heads = config.num_attention_heads
                      self.head_dim = self.hidden_size // self.num_heads
                      self.scale = self.head_dim ** -0.5
              
                      self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
                      self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
                      self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
                      self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)
              
                      self.attn_dropout = nn.Dropout(config.attention_dropout)
                      self.resid_dropout = nn.Dropout(config.output_dropout)
              
                  def forward(
                      self,
                      hidden_states: torch.Tensor,
                      attention_mask: Optional[torch.Tensor] = None,
                      position_ids: Optional[torch.LongTensor] = None,
                      past_key_value: Optional[Cache] = None,
                      output_attentions: bool = False,
                      use_cache: bool = False,
                      **kwargs,
                  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
                      bsz, q_len, _ = hidden_states.size()
              
                      query = self.q_proj(hidden_states)
                      key = self.k_proj(hidden_states)
                      value = self.v_proj(hidden_states)
              
                      query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
                      key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
                      value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
              
                      if past_key_value is not None:
                          key, value = past_key_value.update(key, value, self.layer_idx)
              
                      attn_weights = torch.matmul(query, key.transpose(2, 3)) * self.scale
              
                      if attention_mask is not None:
                          attn_weights = attn_weights + attention_mask
              
                      attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
                          query.dtype
                      )
                      attn_weights = self.attn_dropout(attn_weights)
              
                      attn_output = torch.matmul(attn_weights, value)
                      attn_output = attn_output.transpose(1, 2).contiguous()
                      attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
              
                      attn_output = self.o_proj(attn_output)
                      attn_output = self.resid_dropout(attn_output)
              
                      return (
                          attn_output,
                          attn_weights if output_attentions else None,
                          past_key_value,
                      )
              
              
              class RuGPT3XMLP(nn.Module):
                  def __init__(self, config: RuGPT3XLConfig):
                      super().__init__()
                      self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size)
                      self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size)
                      self.act_fn = ACT2FN[config.hidden_act]
                      self.dropout = nn.Dropout(config.output_dropout)
              
                  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                      return self.dropout(self.down_proj(self.act_fn(self.up_proj(hidden_states))))
              
              
              class RuGPT3XLDecoderLayer(nn.Module):
                  def __init__(self, config: RuGPT3XLConfig, layer_idx: int):
                      super().__init__()
                      self.input_layernorm = nn.LayerNorm(
                          config.hidden_size, eps=config.layer_norm_eps
                      )
                      self.self_attn = RuGPT3XLAttention(config, layer_idx)
                      self.post_attention_layernorm = nn.LayerNorm(
                          config.hidden_size, eps=config.layer_norm_eps
                      )
                      self.mlp = RuGPT3XMLP(config)
              
                  def forward(
                      self,
                      hidden_states: torch.Tensor,
                      attention_mask: Optional[torch.Tensor] = None,
                      position_ids: Optional[torch.LongTensor] = None,
                      past_key_value: Optional[Cache] = None,
                      output_attentions: bool = False,
                      use_cache: bool = False,
                      **kwargs,
                  ) -> Tuple[torch.Tensor, ...]:
                      residual = hidden_states
                      hidden_states = self.input_layernorm(hidden_states)
                      hidden_states, self_attn_weights, present_key_value = self.self_attn(
                          hidden_states=hidden_states,
                          attention_mask=attention_mask,
                          position_ids=position_ids,
                          past_key_value=past_key_value,
                          output_attentions=output_attentions,
                          use_cache=use_cache,
                          **kwargs,
                      )
                      hidden_states = residual + hidden_states
              
                      residual = hidden_states
                      hidden_states = self.post_attention_layernorm(hidden_states)
                      hidden_states = self.mlp(hidden_states)
                      hidden_states = residual + hidden_states
              
                      outputs = (hidden_states,)
                      if output_attentions:
                          outputs += (self_attn_weights,)
                      if use_cache:
                          outputs += (present_key_value,)
                      return outputs
              
              
              class RuGPT3XLPreTrainedModel(PreTrainedModel):
                  config_class = RuGPT3XLConfig
                  base_model_prefix = "model"
                  supports_gradient_checkpointing = True
                  _no_split_modules = ["RuGPT3XLDecoderLayer"]
                  _skip_keys_device_placement = ["past_key_values"]
                  _supports_cache_class = True
              
                  def _init_weights(self, module):
                      std = self.config.initializer_range
                      if isinstance(module, nn.Linear):
                          module.weight.data.normal_(mean=0.0, std=std)
                          if module.bias is not None:
                              module.bias.data.zero_()
                      elif isinstance(module, nn.Embedding):
                          module.weight.data.normal_(mean=0.0, std=std)
                          if module.padding_idx is not None:
                              module.weight.data[module.padding_idx].zero_()
                      elif isinstance(module, nn.LayerNorm):
                          module.bias.data.zero_()
                          module.weight.data.fill_(1.0)
              
              
              class RuGPT3XLModel(RuGPT3XLPreTrainedModel):
                  """Bare RuGPT-3 XL transformer outputting raw hidden states."""
              
                  def __init__(self, config: RuGPT3XLConfig):
                      super().__init__(config)
                      self.padding_idx = config.pad_token_id
                      self.vocab_size = config.vocab_size
              
                      self.embed_tokens = nn.Embedding(
                          config.vocab_size, config.hidden_size, self.padding_idx
                      )
                      self.embed_positions = nn.Embedding(
                          config.max_position_embeddings, config.hidden_size
                      )
                      self.embed_dropout = nn.Dropout(config.embedding_dropout)
              
                      self.layers = nn.ModuleList(
                          [
                              RuGPT3XLDecoderLayer(config, layer_idx)
                              for layer_idx in range(config.num_hidden_layers)
                          ]
                      )
                      self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
              
                      # Sparse attention config
                      self._sparse_layers: set = set()
                      if getattr(config, "sparse_mode", "none") == "alternating":
                          self._sparse_layers = {
                              i for i in range(config.num_hidden_layers) if i % 2 == 0
                          }
                      elif getattr(config, "sparse_mode", "none") == "all":
                          self._sparse_layers = set(range(config.num_hidden_layers))
              
                      # Sparse layout will be lazily built on first forward.
                      # NOT registered as a buffer to avoid meta-device corruption.
                      self._sparse_layout: Optional[torch.Tensor] = None
              
                      self.gradient_checkpointing = False
                      self.post_init()
              
                  def _get_sparse_layout(self, device: torch.device) -> torch.Tensor:
                      """Return sparse layout tensor on *device*, building it if necessary."""
                      if self._sparse_layout is not None and self._sparse_layout.device == device:
                          return self._sparse_layout
              
                      cfg = self.config
                      num_blocks = cfg.max_position_embeddings // cfg.sparse_block_size
                      self._sparse_layout = _make_sparse_layout(
                          num_heads=cfg.num_attention_heads,
                          num_blocks=num_blocks,
                          num_local_blocks=cfg.sparse_num_local_blocks,
                          num_global_blocks=cfg.sparse_num_global_blocks,
                          num_different_global_patterns=cfg.sparse_num_different_global_patterns,
                          device=device,
                      )
                      return self._sparse_layout
              
                  def get_input_embeddings(self):
                      return self.embed_tokens
              
                  def set_input_embeddings(self, value):
                      self.embed_tokens = value
              
                  def forward(
                      self,
                      input_ids: Optional[torch.LongTensor] = None,
                      attention_mask: Optional[torch.Tensor] = None,
                      position_ids: Optional[torch.LongTensor] = None,
                      past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
                      inputs_embeds: Optional[torch.FloatTensor] = None,
                      use_cache: Optional[bool] = None,
                      output_attentions: Optional[bool] = None,
                      output_hidden_states: Optional[bool] = None,
                      return_dict: Optional[bool] = None,
                      **kwargs,
                  ) -> Union[Tuple, BaseModelOutputWithPast]:
                      output_attentions = (
                          output_attentions
                          if output_attentions is not None
                          else self.config.output_attentions
                      )
                      output_hidden_states = (
                          output_hidden_states
                          if output_hidden_states is not None
                          else self.config.output_hidden_states
                      )
                      use_cache = use_cache if use_cache is not None else self.config.use_cache
                      return_dict = (
                          return_dict if return_dict is not None else self.config.use_return_dict
                      )
              
                      if input_ids is not None and inputs_embeds is not None:
                          raise ValueError(
                              "You cannot specify both input_ids and inputs_embeds"
                          )
                      if input_ids is not None:
                          batch_size, seq_length = input_ids.shape[:2]
                      elif inputs_embeds is not None:
                          batch_size, seq_length = inputs_embeds.shape[:2]
                      else:
                          raise ValueError(
                              "You have to specify either input_ids or inputs_embeds"
                          )
              
                      if self.gradient_checkpointing and self.training and use_cache:
                          logger.warning_once(
                              "`use_cache=True` is incompatible with gradient checkpointing. "
                              "Setting `use_cache=False`."
                          )
                          use_cache = False
              
                      past_key_values_length = 0
                      if use_cache:
                          if past_key_values is None:
                              past_key_values = DynamicCache()
                          past_key_values_length = past_key_values.get_seq_length()
              
                      if position_ids is None:
                          device = (
                              input_ids.device if input_ids is not None else inputs_embeds.device
                          )
                          position_ids = torch.arange(
                              past_key_values_length,
                              seq_length + past_key_values_length,
                              dtype=torch.long,
                              device=device,
                          ).unsqueeze(0)
              
                      if inputs_embeds is None:
                          inputs_embeds = self.embed_tokens(input_ids)
              
                      position_embeds = self.embed_positions(position_ids)
                      hidden_states = self.embed_dropout(inputs_embeds + position_embeds)
              
                      # Dense causal mask
                      causal_mask = self._build_causal_mask(
                          batch_size,
                          seq_length,
                          past_key_values_length,
                          hidden_states.dtype,
                          hidden_states.device,
                          attention_mask,
                      )
              
                      # Sparse causal mask (lazily build layout on correct device)
                      sparse_mask = None
                      if self._sparse_layers:
                          sparse_layout = self._get_sparse_layout(hidden_states.device)
                          sparse_mask = self._build_sparse_causal_mask(
                              seq_length,
                              past_key_values_length,
                              hidden_states.dtype,
                              hidden_states.device,
                              sparse_layout,
                              self.config.sparse_block_size,
                              attention_mask,
                          )
              
                      all_hidden_states = () if output_hidden_states else None
                      all_self_attns = () if output_attentions else None
                      next_decoder_cache = None
              
                      for layer_idx, decoder_layer in enumerate(self.layers):
                          if output_hidden_states:
                              all_hidden_states += (hidden_states,)
              
                          layer_mask = (
                              sparse_mask
                              if (layer_idx in self._sparse_layers and sparse_mask is not None)
                              else causal_mask
                          )
              
                          if self.gradient_checkpointing and self.training:
                              layer_outputs = self._gradient_checkpointing_func(
                                  decoder_layer.__call__,
                                  hidden_states,
                                  layer_mask,
                                  position_ids,
                                  past_key_values,
                                  output_attentions,
                                  use_cache,
                              )
                          else:
                              layer_outputs = decoder_layer(
                                  hidden_states,
                                  attention_mask=layer_mask,
                                  position_ids=position_ids,
                                  past_key_value=past_key_values,
                                  output_attentions=output_attentions,
                                  use_cache=use_cache,
                              )
              
                          hidden_states = layer_outputs[0]
                          if use_cache:
                              next_decoder_cache = layer_outputs[
                                  2 if output_attentions else 1
                              ]
                          if output_attentions:
                              all_self_attns += (layer_outputs[1],)
              
                      hidden_states = self.norm(hidden_states)
              
                      if output_hidden_states:
                          all_hidden_states += (hidden_states,)
              
                      next_cache = next_decoder_cache if use_cache else None
              
                      if not return_dict:
                          return tuple(
                              v
                              for v in [
                                  hidden_states,
                                  next_cache,
                                  all_hidden_states,
                                  all_self_attns,
                              ]
                              if v is not None
                          )
                      return BaseModelOutputWithPast(
                          last_hidden_state=hidden_states,
                          past_key_values=next_cache,
                          hidden_states=all_hidden_states,
                          attentions=all_self_attns,
                      )
              
                  @staticmethod
                  def _build_causal_mask(
                      batch_size: int,
                      seq_length: int,
                      past_length: int,
                      dtype: torch.dtype,
                      device: torch.device,
                      attention_mask: Optional[torch.Tensor] = None,
                  ) -> torch.Tensor:
                      total_length = past_length + seq_length
                      causal = torch.full(
                          (seq_length, total_length),
                          torch.finfo(dtype).min,
                          device=device,
                      )
                      causal = causal.masked_fill(
                          torch.arange(total_length, device=device).unsqueeze(0)
                          <= torch.arange(
                              past_length, past_length + seq_length, device=device
                          ).unsqueeze(1),
                          0.0,
                      )
                      causal = causal.unsqueeze(0).unsqueeze(0)
              
                      if attention_mask is not None:
                          pad_mask = (
                              (1 - attention_mask[:, None, None, :].to(dtype))
                              * torch.finfo(dtype).min
                          )
                          causal = causal + pad_mask
              
                      return causal
              
                  @staticmethod
                  def _build_sparse_causal_mask(
                      seq_length: int,
                      past_length: int,
                      dtype: torch.dtype,
                      device: torch.device,
                      sparse_layout: torch.Tensor,
                      block_size: int,
                      attention_mask: Optional[torch.Tensor] = None,
                  ) -> torch.Tensor:
                      total_length = past_length + seq_length
                      num_blocks = sparse_layout.shape[1]
              
                      q_block = (
                          torch.arange(past_length, past_length + seq_length, device=device)
                          // block_size
                      ).clamp(max=num_blocks - 1)
                      k_block = (
                          torch.arange(total_length, device=device) // block_size
                      ).clamp(max=num_blocks - 1)
              
                      block_ok = sparse_layout[:, q_block][:, :, k_block]
              
                      q_pos = torch.arange(
                          past_length, past_length + seq_length, device=device
                      ).unsqueeze(1)
                      k_pos = torch.arange(total_length, device=device).unsqueeze(0)
                      causal_ok = k_pos <= q_pos
              
                      allowed = block_ok & causal_ok.unsqueeze(0)
              
                      min_val = torch.finfo(dtype).min
                      mask = torch.where(allowed, 0.0, min_val).to(dtype).unsqueeze(0)
              
                      if attention_mask is not None:
                          pad_mask = (
                              (1 - attention_mask[:, None, None, :].to(dtype)) * min_val
                          )
                          mask = mask + pad_mask
              
                      return mask
              
              
              class RuGPT3XLForCausalLM(RuGPT3XLPreTrainedModel, GenerationMixin):
                  _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
                  _supports_cache_class = True
              
                  def __init__(self, config: RuGPT3XLConfig):
                      super().__init__(config)
                      self.model = RuGPT3XLModel(config)
                      self.vocab_size = config.vocab_size
                      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
                      self.post_init()
              
                  def get_input_embeddings(self):
                      return self.model.embed_tokens
              
                  def set_input_embeddings(self, value):
                      self.model.embed_tokens = value
              
                  def get_output_embeddings(self):
                      return self.lm_head
              
                  def set_output_embeddings(self, new_embeddings):
                      self.lm_head = new_embeddings
              
                  def get_decoder(self):
                      return self.model
              
                  def set_decoder(self, decoder):
                      self.model = decoder
              
                  def forward(
                      self,
                      input_ids: Optional[torch.LongTensor] = None,
                      attention_mask: Optional[torch.Tensor] = None,
                      position_ids: Optional[torch.LongTensor] = None,
                      past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
                      inputs_embeds: Optional[torch.FloatTensor] = None,
                      labels: Optional[torch.LongTensor] = None,
                      use_cache: Optional[bool] = None,
                      output_attentions: Optional[bool] = None,
                      output_hidden_states: Optional[bool] = None,
                      return_dict: Optional[bool] = None,
                      **kwargs,
                  ) -> Union[Tuple, CausalLMOutputWithPast]:
                      output_attentions = (
                          output_attentions
                          if output_attentions is not None
                          else self.config.output_attentions
                      )
                      output_hidden_states = (
                          output_hidden_states
                          if output_hidden_states is not None
                          else self.config.output_hidden_states
                      )
                      return_dict = (
                          return_dict if return_dict is not None else self.config.use_return_dict
                      )
              
                      outputs = self.model(
                          input_ids=input_ids,
                          attention_mask=attention_mask,
                          position_ids=position_ids,
                          past_key_values=past_key_values,
                          inputs_embeds=inputs_embeds,
                          use_cache=use_cache,
                          output_attentions=output_attentions,
                          output_hidden_states=output_hidden_states,
                          return_dict=return_dict,
                      )
              
                      hidden_states = outputs[0]
                      logits = self.lm_head(hidden_states).float()
              
                      loss = None
                      if labels is not None:
                          shift_logits = logits[..., :-1, :].contiguous()
                          shift_labels = labels[..., 1:].contiguous()
                          loss_fct = nn.CrossEntropyLoss()
                          shift_logits = shift_logits.view(-1, self.config.vocab_size)
                          shift_labels = shift_labels.view(-1).to(shift_logits.device)
                          loss = loss_fct(shift_logits, shift_labels)
              
                      if not return_dict:
                          output = (logits,) + outputs[1:]
                          return (loss,) + output if loss is not None else output
              
                      return CausalLMOutputWithPast(
                          loss=loss,
                          logits=logits,
                          past_key_values=outputs.past_key_values,
                          hidden_states=outputs.hidden_states,
                          attentions=outputs.attentions,
                      )
              
                  def prepare_inputs_for_generation(
                      self,
                      input_ids,
                      past_key_values=None,
                      attention_mask=None,
                      inputs_embeds=None,
                      **kwargs,
                  ):
                      if past_key_values is not None:
                          past_length = past_key_values.get_seq_length()
                          if (
                              attention_mask is not None
                              and attention_mask.shape[1] > input_ids.shape[1]
                          ):
                              input_ids = input_ids[
                                  :, -(attention_mask.shape[1] - past_length) :
                              ]
                          elif past_length < input_ids.shape[1]:
                              input_ids = input_ids[:, past_length:]
              
                      position_ids = kwargs.get("position_ids", None)
                      if attention_mask is not None and position_ids is None:
                          position_ids = attention_mask.long().cumsum(-1) - 1
                          position_ids.masked_fill_(attention_mask == 0, 1)
                      if position_ids is not None and past_key_values is not None:
                          position_ids = position_ids[:, -input_ids.shape[1] :]
              
                      if inputs_embeds is not None and past_key_values is None:
                          model_inputs = {"inputs_embeds": inputs_embeds}
                      else:
                          model_inputs = {"input_ids": input_ids}
              
                      model_inputs.update(
                          {
                              "position_ids": position_ids,
                              "past_key_values": past_key_values,
                              "use_cache": kwargs.get("use_cache"),
                              "attention_mask": attention_mask,
                          }
                      )
                      return model_inputs
              

              Теперь запускается нормально и можно обучить, и чтобы проверить работает ли вообще отсечение, в forward можно добавить отладочную информацию:

              Добавить в forward
                      next_decoder_cache = None
              
                      # === start debug ===
                      if seq_length > 1 and sparse_mask is not None:
                          total_len = seq_length + past_key_values_length
                          
                          min_val_dense = torch.finfo(causal_mask.dtype).min
                          min_val_sparse = torch.finfo(sparse_mask.dtype).min
                          
                          q_dim = causal_mask.shape[-2]
                          k_dim = causal_mask.shape[-1]
                          total_elements = q_dim * k_dim
                          
                          num_matrices_dense = causal_mask.numel() // total_elements
                          num_matrices_sparse = sparse_mask.numel() // total_elements
                          
                          dense_blocked = int((causal_mask == min_val_dense).sum().item() / num_matrices_dense)
                          sparse_blocked = int((sparse_mask == min_val_sparse).sum().item() / num_matrices_sparse)
                          extra_blocked = sparse_blocked - dense_blocked
                          
                          allowed_in_dense = total_elements - dense_blocked
                          sparse_penalty_pct = (extra_blocked / allowed_in_dense) * 100 if allowed_in_dense > 0 else 0.0
              
                          print(f"\n[SPARSE DEBUG] seq_len={seq_length}, total_len={total_len}")
                          print(f"Каузальная маска (Dense): shape={tuple(causal_mask.shape)}, blocked={dense_blocked}/{total_elements} ({(dense_blocked/total_elements)*100:.1f}%)")
                          print(f"Разреженная маска (Sparse): shape={tuple(sparse_mask.shape)}, blocked={sparse_blocked}/{total_elements} ({(sparse_blocked/total_elements)*100:.1f}%)")
                          print(f"Уровень разреженности (Sparse): +{sparse_penalty_pct:.1f}% (отсечено {extra_blocked} связей)")
                      # === end debug === 
              
                      for layer_idx, decoder_layer in enumerate(self.layers):
              

              Так как Sparse Attn тут настроен на 128 токенов, то если запрос меньше, то Sparse должна быть равна Dense и фактически не работать:

              Запрос длиннее 128, тут Sparse успешно отсекает и с виду нормально работает:

              Попробовал обучить LoRA на первом попавшемся датасете диалогов Den4ikAI/russian_dialogues, датасет на 2.5 млн строк или 40m токенов. Обучил только на первых 10000 за 5-10 минут.

              Обучилась успешно, и приобрела особый стиль общения, который присутствовал в датасете:

              Датасет, наверное, не очень удачный, но это хороший пример, что если есть base, то из базы можно выровнять (alignment) модель до различных состояний.


    1. efreelancer Автор
      28.03.2026 02:23

      Всё верно, моделька не расчитана на читчат, так как она лишь foundation (заготовка) и по хорошему её следует обучить на своём датасете под конкретную задачу.