Некоторое время назад компания Google DeepMind представила Gemini Diffusion — экспериментальную языковую модель, генерирующую текст методом диффузии. В отличие от традиционных моделей, написанных в стиле GPT и генерирующих слово за словом,  Gemini создаёт текст целыми блоками, пошагово уточняя случайный шум.

Я прочитал статью «Large Language Diffusion Models» — и с удивлением узнал, что дискретная диффузия языка представляет собой просто обобщение метода генерации пропущенного токена (MLM), практикуемого уже с 2018 года. Я сразу подумал: «А можно ли тонко настроить BERT-подобную модель так, чтобы приспособить её к генерации текста?» Из чистого любопытства решил наскоро набросать проверку концепции.

Примечание: уже после того, как написал эту статью, я наткнулся на исследование DiffusionBERT, где сделано практически то же самое, что и у меня, но проект гораздо тщательнее протестирован. Посмотрите этот пост, если тема вас заинтересовала.

Краткая история трансформеров

Первая версия архитектуры трансформеров, представленная в 2017 году, являла собой модель на основе энкодера и декодера. В 2018 году исследователи осознали, что энкодерную и декодерную составляющие модели можно разделить (что привело к появлению BERT и GPT), после чего были два обособленных семейства моделей:

  • Чисто энкодерные модели (BERT-подобные, двунаправленные)

Целью обучения энкодерных моделей является генерация пропущенного токена (MLM). В каждой входной последовательности случайным образом маскируется некоторое подмножество токенов, после чего энкодер учится реконструировать недостающие токены (заполнять пробелы). Модель сразу видит целостный (частично скрытый) контекст и изучает двунаправленные представления. Такая архитектура преуспевала при решении задач, в которых требовалось представлять целые предложения (или абзацы) — например, при классификации и извлечении.  

  • Чисто декодерные модели (GPT-подобные, авторегрессионные)

Целью обучения декодерных моделей является прогнозирование следующего токена. На каждой позиции t нужно предсказать токен, который будет находиться на позиции t+1 с учётом всех токенов, предшествующих t и служащих контекстом. Для предсказания будущих значений используется только та информация, которая находится слева от рассматриваемого токена (процесс однонаправленный). Такая архитектура прекрасно показала себя в генеративных задачах, где порождается токен за токеном — например, в генерации с открытым финалом, резюмировании и переводе.

Изначально предполагалось, что BERT естественным образом подходит для решения таких задач как классификация. В свою очередь, GPT-подобные модели снискали популярность не сразу (поскольку возможности их ранних версий были довольно скромными). В конце концов, генеративные возможности авторегрессионных (декодерных) моделей сильно выросли. Общее направление обучения «спрогнозировать следующий токен» предполагает значительно более пространство для практического применения, чем при работе с энкодерными моделями.

Дискретные диффузионные языковые модели

Популяризация диффузионных моделей началась с генерации изображений. Работая в этой области, диффузионные модели постепенно напитывают изображение гауссовским шумом (прямой процесс), а затем обучают нейронную сеть вычищать его в ходе последовательных итераций (обратный процесс). Вот как можно в самом общем виде резюмировать непрерывную диффузию при работе с изображениями:  

  • Прямой процесс: начинаем с чистого образа x₀, затем на каждом шаге небольшими порциями добавляем шум (обычно гауссовский), до тех пор, пока не получим почти чистый шум.

  • Обратный процесс: обучаем модель (зачастую U‐Net) для прогнозирования шума на каждом шаге, постепенно восстанавливая исходное изображение, выполняя дискретные шаги денойзинга.

Применительно к языку эта идея означает, что сначала нужно добавить шум к тексту, а затем поэтапно удалять этот шум. Проще всего это сделать, обрабатывая шум на основе маскировки:

  • Прямой процесс (маскировка):

    • На шаге t = 0 имеем совершенно неповреждённую текстовую последовательность.

    • На каждом последующем шаге t > 0 случайным образом заменяем некоторую долю токенов специальным токеном <MASK>, придерживаясь заранее спланированного порядка (напр., постепенно повышая пропорцию скрытых токенов от 0% до 100%).

    • К наступлению последнего этапа T вся последовательность может быть замаскирована (все токены являются <MASK>).

  • Обратный процесс (денойзинг):

    • Обучаем модель (зачастую это стандартный трансформерный энкодер) прогнозировать ID оригинальных токенов при наличии частично замаскированной последовательности на этапе t.

    • Ситуация такова, как будто мы моделируем язык с применением различных уровней маскировки. На первых шагах маской скрыты лишь отдельные токены (прогнозировать легко), а на последних — уже многие (прогнозировать сложнее).

    • Объединяя по цепочке прогнозы от шагов с сильной маскировкой вплоть до шагов с нулевой, можно восстановить (или сгенерировать) полную последовательность.

В этом фреймворке для дискретной диффузии текста модель изучает вероятности с привязкой к распределению данных. Она оптимизирует вызванные денойзингом потери в масштабе всех проделанных шагов, а не ограничивается отдельной целью по генерации пропущенных токенов и фиксированной вероятностью маскировки.

Как видим, цель такого маскировочного моделирования языка при использовании модели BERT точно такая же, как и при диффузии текста, но только для некоторого подмножества уровней маскировки. Вводя переменные уровни маскировки (от 0 до 1) и планируя последовательность шагов денойзинга по расписанию, можно преобразовать подход BERT, при котором язык моделируется через маскировку токенов, в полноценную генеративную процедуру.

Диффузия в стиле RoBERTa

В 2019 году вышла модель RoBERTa. В сущности, это всего лишь улучшенная версия исходной модели BERT с более качественными гиперпараметрами, расширенным множеством учебных данных и упрощённой целью обучения (только генерация пропущенных токенов, без прогнозирования следующего предложения).

Здесь воспользуемся библиотеками transformers и dataset из HuggingFace, чтобы подтянуть исходные веса RoBERTa, токенизатор и класс Trainer. Всё это позволит нам без труда подогнать модель на множестве данных WikiText. Основной код (полный код выложен здесь) выглядит так:

# Загружаем и токенизируем множество данных, создаём экземпляр модели 
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = RobertaForMaskedLM.from_pretrained("roberta-base")

# Создаём аргументы обучения и экземпляр класса Trainer 
training_args = TrainingArguments(
    output_dir="finetuned-roberta-diffusion",
    overwrite_output_dir=True,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    save_strategy="epoch",
    save_total_limit=1,
    logging_steps=200,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=diffusion_collator, # наша собственная реализация
    tokenizer=tokenizer,
)

# Обучаем и сохраняем
trainer.train()
trainer.save_model("finetuned-roberta-diffusion")

В настоящий момент имеем 10 шагов диффузии и случайным образом выбираем процент p из mask_probs (1.0, 0.9, 0.9, …, 0.1), после чего маскируем именно такой процент токенов в каждой порции из датасета. Написанная нами функция diffusion_collator(вот её код) выбирает одну вероятность маскировки p из mask_probs на порцию и именно с вероятностью p задаёт для токенов значение <MASK>.

Чтобы можно было подготовить почву для генерации «промпта», мы на данном этапе вообще не маскируем первые 16 токенов. Это означает, что в ходе обучения на каждом этапе первые 16 токенов всегда будут даваться в качестве контекста для дальнейшей генерации.

В упрощённом виде код diffusion_collator имеет вид:

  def diffusion_collator(examples):
      batch = tokenizer.pad(examples, return_tensors="pt")

      # Случайным образом выбираем вероятность маскировки для данной порции
      mask_prob = random.choice([1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])

      # Никогда не маскируем первые PREFIX_LEN токенов (они сохраняются для контекста)
      maskable_positions = batch.input_ids[:, PREFIX_LEN:]

      # Создаём случайную маску для выбранного значения вероятности
      mask = torch.rand(maskable_positions.shape) < mask_prob

      # Применяем маскировку
      batch.input_ids[:, PREFIX_LEN:][mask] = tokenizer.mask_token_id
      batch.labels = batch.input_ids.clone()

      return batch

Чтобы выполнить логический вывод (инференс), для начала возьмём на вход тензор размером 256 (поскольку мы генерируем блоки по 256 токенов в каждом). Первые 16 позиций заняты id тех токенов, которые входят в промпт, а оставшиеся 240 токенов — это просто <MASK>. На каждом шаге работаем по запланированному расписанию денойзинга, затем генерируем прогноз, после чего вновь перемаскируем всю последовательность. Процесс устроен так:

Step 0: [PREFIX] <mask> <mask> <mask> <mask> <mask> ...     (100% masked)
Step 1: [PREFIX] will <mask> over <mask> control ...        (90% masked)
Step 2: [PREFIX] will begin <mask> greater control ...      (80% masked)
...
Step 10: [PREFIX] will begin to assert greater control ...  (0% masked - DONE)

Упрощённый код генерации выглядит так:

# Генерируем текст методом итеративного денойзинга
for step, mask_prob in enumerate(mask_probs):
    # Прямой проход: замаскированные токены
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = outputs.logits  # форма: (1, MAX_LEN, vocab_size)

    # Для каждой замаскированной позиции делаем выборку из распределения, отфильтрованного по top-k/top-p 
    for pos in range(PREFIX_LEN, MAX_LEN):
        if input_ids[0, pos] == tokenizer.mask_token_id:
            logits = predictions[0, pos, :]
            # Применяем фильтры top-k и top-p 
            filtered_logits = top_k_top_p_filtering(logits, top_k=TOP_K, top_p=TOP_P)
            probs = F.softmax(filtered_logits, dim=-1)
            # Выбираем токен
            sampled_token = torch.multinomial(probs, 1)
            input_ids[0, pos] = sampled_token

    # Перемаскируем порцию токенов, не относящихся к префиксу, и так далее на каждой итерации 
    if mask_prob > 0:
        mask_indices = torch.rand(MAX_LEN - PREFIX_LEN) < mask_prob
        input_ids[0, PREFIX_LEN:][mask_indices] = tokenizer.mask_token_id

Вот пример вывода, сгенерированного тонко настроенной моделью после того, как она обучалась на H200 в течение 30 минут (первая строка — исходный промпт):

Following their victory in the French and Indian War, Britain began to assert
greater...

...dominion over Europe beginning about the early 19th. There conflict took
place on the island, between British and Irish Ireland. British officials 
administered British Ireland, a Celtic empire under the control of the Irish 
nationalist authorities, defined as a dominion of Britain. As the newly Fortic 
states acquired independent and powerful status, many former English colonies
played their part in this new, British-controlled colonial system. Following
this period the Non-Parliamentaryist Party won its influence in Britain in 
1890, led by the support of settlers from the Irish colonies. Looking inwards, 
Sinclair, Lewis questioned, and debated the need to describe " The New Britain "
Перевод DeepSeek

После своей победы в Франко-индейской войне Великобритания начала утверждать большее…

…господство над Европой, начиная примерно с начала XIX века. На острове произошел конфликт между британской и ирландской Ирландией. Британские чиновники управляли Британской Ирландией, кельтской империей, находящейся под контролем ирландских националистических властей, определяемой как доминион Британии. Поскольку новые Фортические государства приобрели независимый и мощный статус, многие бывшие английские колонии сыграли свою роль в этой новой, контролируемой Британией колониальной системе. После этого периода Непарламентаристская партия завоевала свое влияние в Британии в 1890 году при поддержке поселенцев из ирландских колоний. Обращаясь внутрь, Синклер Льюис ставил под вопрос и обсуждал необходимость описания «Новой Британии».

Вывод выглядит удивительно складно! Большинство странностей связаны просто с дефектами форматирования WikiText.

Ниже приведено сравнение нашей диффузионной модели с GPT-2:

Как видим, вывод GPT-2 получается чуть более согласованным и выдаётся немного быстрее (~9 секунд против ~13), но я приятно удивлён, насколько хорошо сработала моя простая реализация. Проверка концепции определённо удалась. Если применить такие новые подходы как AR-диффузия и диффузия с пропуском шагов (Skip-Step Diffusion), и также оптимизировать саму реализацию, то и качество, и скорость работы можно радикально улучшить.

Заключение

Мы убедились, что такие языковые модели как RoBERTa, работающие по принципу маскировки языка, исходно проектировались для «заполнения пробелов», но вполне могут быть переоборудованы в полноценные генеративные движки. Для этого нужно интерпретировать маскировку с переменной частотой как дискретный диффузионный процесс. Постепенно искажая текст вводом токенов <MASK> и обучая модель итеративному денойзингу такой информации при возрастающей интенсивности маскировки, мы фактически преобразуем стандартную цель генерации пропущенных токенов в процедуру пошаговой генерации.

Даже если не изменять архитектуру, тонко настроенная модель RoBERTa может генерировать вполне складный текст, если чуть-чуть подправить цель обучения. Тем самым подтверждается идея, что BERT-подобные модели — это просто диффузионные модели для работы с текстом, обученные при фиксированной частоте маскировки.

Комментарии (1)


  1. LinkToOS
    24.10.2025 11:34

    Целью обучения декодерных моделей является прогнозирование следующего токена. На каждой позиции t нужно предсказать токен, который будет находиться на позиции t+1 с учётом всех токенов, предшествующих t и служащих контекстом. Для предсказания будущих значений используется только та информация, которая находится слева от рассматриваемого токена (процесс однонаправленный).

    Контекстом здесь являются только предыдущие сгенерированные токены, или промп тоже?