Обо мне

Привет, меня зовут Василий Техин. В первой статье мы разобрали ResNet, во второй — ViT. Теперь погрузимся в мир генерации изображений с Diffusion Transformer (DiT) — сердцем Stable Diffusion 3.


Пролог: От распознавания к созданию

Представьте нейросеть как художника. Раньше она только анализировала картины ("Это Ван Гог!"). Теперь она создаёт шедевры в стиле Ван Гога и не только!

Изображения из статьи
Изображения из статьи

Ключевые этапы работы DiT:

  1. Обучение:

    • Сжимаем изображение в латентное пространство через VAE (256х256х3 → 32х32х4)

    • Добавляем шум за 1000 шагов (чтобы модель училась удалять шум постепенно)

    • DiT учится предсказывать шум на каждом шаге

  2. Генерация (инференс):

    • Начинаем с чистого шума

    • Постепенно удаляем шум за 1000 шагов

    • Декодируем результат через VAE


Пайплайн обучения и генерации

1. Подготовка данных (VAE)

VAE (Variational Autoencoder) сжимает изображение:

# Для изображения 256x256:
original = (3, 256, 256) → latent = (4, 32, 32)  # Сжатие в 64 раза

Зачем? DiT работает с 32×32×4 латентными векторами — экономия вычислений!

2. Прямой процесс (добавление шума)

Процесс зашумления
Процесс зашумления

1000 шагов постепенного зашумления по формуле:

def forward_diffusion(z0, t, T=1000):
    alpha_t = cos((t/T + 0.008) / 1.008 * π/2)**2  
    noise = torch.randn_like(z0)  # Случайный шум
    z_t = sqrt(alpha_t) * z0 + sqrt(1-alpha_t) * noise  # Зашумленная версия
    return z_t, noise

Где:

  • z0 — исходный латентный вектор изображения

  • t — текущий шаг (1-1000)

  • noise — добавленный шум

3. Обратный процесс (обучение DiT)

Ключевые шаги обучения:

  1. Выбираем случайный шаг t (1-1000)

  2. Зашумляем латентный вектор: z_t, real_noise = forward_diffusion(z0, t)

  3. Подаем в DiT: pred_noise = DiT(z_t, t, text_embed) и получаем предсказанный шум

  4. Считаем MSE-лосс: loss = (real_noise - pred_noise).square().mean()

  5. Обновляем веса через backpropagation

Обратите внимание: DiT учится предсказывать оригинальный шум, а не изображение!

4. Генерация изображений (инференс)

Пошаговый процесс для Stable Diffusion 3(отличается от DiT из оригинальной статьи тем, что подается эмбеддинг текста вместо метки класса):

def generate(prompt, steps=1000):
    # 1. Текстовый эмбеддинг
    text_embed = text_encoder(prompt)  # [1, 768]
    
    # 2. Начальный шум
    z = torch.randn(1, 4, 32, 32)  # z_T
    
    # 3. Итеративное удаление шума
    for t in range(steps, 0, -1):
        # a) Предсказание шума DiT
        pred_noise = DiT(z, t, text_embed)
        
        # b) Classifier-Free Guidance (CFG) - усиление текстового влияния
        if cfg_scale > 1.0:
            uncond_embed = text_encoder("")  # Пустой промпт
            uncond_noise = DiT(z, t, uncond_embed)
            pred_noise = uncond_noise + cfg_scale * (pred_noise - uncond_noise)
        
        # c) Формула обратного шага (DDIM)
        alpha_t = cos((t/steps + 0.008)/1.008 * π/2)**2
        alpha_prev = cos(((t-1)/steps + 0.008)/1.008 * π/2)**2
        z = (z - (1 - alpha_t)/sqrt(1 - alpha_t) * pred_noise) / sqrt(alpha_t)
        z += sqrt(1 - alpha_prev) * torch.randn_like(z)  # Стохастичность
        
    # 4. Декодирование через VAE
    return VAE.decode(z)  # [1, 3, 256, 256]

DiT в деталях: Отличия от ViT

1. Patchify: Работа с латентами

Мы нарезаем на патчи не оригинальное изображение, а латентный вектор

# Для латента 32x32x4 с патчами 2x2:
self.patch_embed = nn.Conv2d(4, dim, kernel_size=2, stride=2)
# → [batch, 256, dim]  (16*16=256 патчей)

Сравнение с ViT: ViT работает с пикселями, DiT — с латентными векторами.

2. Classifier-Free Guidance (CFG)

Механизм усиления текста мы хотим, чтобы изображение из шума соответсвовало тексту, который мы передали:

pred_noise = uncond_noise + guidance_scale * (text_noise - uncond_noise)

Где:

  • uncond_noise — предсказание для пустого промпта

  • text_noise — предсказание для целевого промпта

  • guidance_scale (7-10) — сила влияния текста

3. Cross-Attention Block

В SD3 (не в оригинальном DiT):

class CrossAttentionBlock(nn.Module):
    def forward(self, x, text_emb):
        # Проекция текста
        q = self.wq(x)  # [batch, tokens, dim]
        k = self.wk(text_emb)  # [batch, text_tokens, dim]
        v = self.wv(text_emb)
        
        # Attention
        attn = softmax(q @ k.transpose(-2,-1) / sqrt(dim))
        return attn @ v  # Текст-условные признаки

Зачем? Точнее связывает текст и визуальные патчи.

4. In-Context Conditioning

Механизм в DiT-XL:

  • Ввод текстовых токенов как патчей

  • Пример: [IMAGE_PATCH1, TEXT_TOKEN1, IMAGE_PATCH2, ...]

  • Позволяет смешивать текст и изображение на входе

5. AdaLN-Zero

Улучшение в DiT-2:

  • Инициализация параметров γ в AdaLN нулями

  • Первые шаги обучения: AdaLN = Identity Function

  • Стабилизирует раннее обучение


Разберём на простом примере

Оригинальный DiT (класс-условный)

Uncurated 512 × 512 DiT-XL/2 samples. Classifier-free guidance scale = 2.0 Class label = “panda” (388)
Uncurated 512 × 512 DiT-XL/2 samples. Classifier-free guidance scale = 2.0 Class label = “panda” (388)
# Генерация "собаки" (класс 207)
class_label = 207
z = torch.randn(1, 4, 32, 32)  # Начальный шум

for t in range(1000, 0, -1):
    pred_noise = DiT(z, t, class_label)  # Прямой вызов
    z = update_step(z, pred_noise, t)  # Обновление латента

Особенности:

  • Простой ввод класса вместо текста

  • Нет CFG и Cross-Attention

Stable Diffusion 3 (текст-условный)

prompt = "пудель в розовой шапочке"
text_embed = text_encoder(prompt)  # [1, 768]

for t in range(1000, 0, -1):
    # 1. Основное предсказание
    pred_noise = DiT(z, t, text_embed)
    
    # 2. Classifier-Free Guidance (усиление текста)
    uncond_noise = DiT(z, t, text_encoder(""))
    pred_noise = uncond_noise + 7.5 * (pred_noise - uncond_noise)
    
    # 3. Cross-Attention в некоторых блоках
    # (см. архитектуру ниже)

Нововведения SD3(относительно DiT):

  • Текст через T5 вместо классов

  • CFG с масштабом 7.5 для точного следования промпту


Оценка качества: Метрики

1. FID (Fréchet Inception Distance)

Как работает:

  1. Берем 50k реальных и 50k сгенерированных изображений

  2. Пропускаем через Inception-v3 (получаем признаки)

  3. Считаем "расстояние" между распределениями:

FID = ||μ_real - μ_gen||^2 + Tr(Σ_real + Σ_gen - 2(Σ_real Σ_gen)^{1/2})

Интерпретация:

  • FID = 0 — идеальное совпадение

  • FID < 5 — фотореалистичные изображения

  • DiT-XL: FID = 2.27 (ImageNet 256x256)

2. IS (Inception Score)

IS = exp(E_x[KL(p(y|x) || p(y))])

Где:

  • p(y|x) — распределение классов для изображения

  • p(y) — общее распределение классов

  • Высокий IS = разнообразные и узнаваемые изображения


Почему DiT — это будущее Stable Diffusion?

✅ Преимущества перед U-Net:

Параметр

U-Net (SD 2.1)

DiT (SD 3)

Качество (FID)

3.85

2.27

Масштабируемость

Ограничена

Линейный рост

Разрешение

768x768

1024x1024

Текстовая привязка

Средняя

Точная

❌ Ограничения:

  1. Ресурсы: Обучение DiT-XL требует 500,000 GPU-hours

  2. Память: Генерация 1024px требует 48GB VRAM


Философский итог

DiT объединяет три революции ИИ:

  1. Сжатие данных (VAE)

  2. Трансформеры (ViT)

  3. Диффузионные процессы


Проверь себя

  1. Почему DiT работает с 32x32, а не 256x256?

  2. Как Classifier-Free Guidance улучшает генерацию?


Резюме

Diffusion Transformer (DiT):

  • Работает в латентном пространстве VAE (32x32x4)

  • Заменяет U-Net на трансформер с AdaLN

  • Оригинал: класс-условная генерация

Stable Diffusion 3:

  • Текст через текст энкодер и Cross-Attention

  • Classifier-Free Guidance для точности

  • Поддержка 1024px изображений

  • FID 2.27 — новый стандарт качества

Ссылки:

  1. Оригинальная статья DiT

  2. Stable Diffusion 3

  3. CFG в диффузионных моделях

Комментарии (5)


  1. Licemery
    02.07.2025 17:01

    Чего ж вы с провальным сд2 сравниваете, там и сд1.5 лучше выйдет при большом желании. Сравните с сдхл.

    > Генерация 1024px требует 48GB VRAM

    А как я на 12 гигах генерю?