Обо мне
Привет, меня зовут Василий Техин. В первой статье мы разобрали ResNet, во второй — ViT. Теперь погрузимся в мир генерации изображений с Diffusion Transformer (DiT) — сердцем Stable Diffusion 3.
Пролог: От распознавания к созданию
Представьте нейросеть как художника. Раньше она только анализировала картины ("Это Ван Гог!"). Теперь она создаёт шедевры в стиле Ван Гога и не только!

Ключевые этапы работы DiT:
- 
Обучение: - Сжимаем изображение в латентное пространство через VAE (256х256х3 → 32х32х4) 
- Добавляем шум за 1000 шагов (чтобы модель училась удалять шум постепенно) 
- DiT учится предсказывать шум на каждом шаге 
 
- 
Генерация (инференс): - Начинаем с чистого шума 
- Постепенно удаляем шум за 1000 шагов 
- Декодируем результат через VAE 
 
Пайплайн обучения и генерации
1. Подготовка данных (VAE)
VAE (Variational Autoencoder) сжимает изображение:
# Для изображения 256x256:
original = (3, 256, 256) → latent = (4, 32, 32)  # Сжатие в 64 раза
Зачем? DiT работает с 32×32×4 латентными векторами — экономия вычислений!
2. Прямой процесс (добавление шума)

1000 шагов постепенного зашумления по формуле:
def forward_diffusion(z0, t, T=1000):
    alpha_t = cos((t/T + 0.008) / 1.008 * π/2)**2  
    noise = torch.randn_like(z0)  # Случайный шум
    z_t = sqrt(alpha_t) * z0 + sqrt(1-alpha_t) * noise  # Зашумленная версия
    return z_t, noise
Где:
- z0— исходный латентный вектор изображения
- t— текущий шаг (1-1000)
- noise— добавленный шум
3. Обратный процесс (обучение DiT)
Ключевые шаги обучения:
- Выбираем случайный шаг - t(1-1000)
- Зашумляем латентный вектор: - z_t, real_noise = forward_diffusion(z0, t)
- Подаем в DiT: - pred_noise = DiT(z_t, t, text_embed)и получаем предсказанный шум
- Считаем MSE-лосс: - loss = (real_noise - pred_noise).square().mean()
- Обновляем веса через backpropagation 
Обратите внимание: DiT учится предсказывать оригинальный шум, а не изображение!
4. Генерация изображений (инференс)
Пошаговый процесс для Stable Diffusion 3(отличается от DiT из оригинальной статьи тем, что подается эмбеддинг текста вместо метки класса):
def generate(prompt, steps=1000):
    # 1. Текстовый эмбеддинг
    text_embed = text_encoder(prompt)  # [1, 768]
    
    # 2. Начальный шум
    z = torch.randn(1, 4, 32, 32)  # z_T
    
    # 3. Итеративное удаление шума
    for t in range(steps, 0, -1):
        # a) Предсказание шума DiT
        pred_noise = DiT(z, t, text_embed)
        
        # b) Classifier-Free Guidance (CFG) - усиление текстового влияния
        if cfg_scale > 1.0:
            uncond_embed = text_encoder("")  # Пустой промпт
            uncond_noise = DiT(z, t, uncond_embed)
            pred_noise = uncond_noise + cfg_scale * (pred_noise - uncond_noise)
        
        # c) Формула обратного шага (DDIM)
        alpha_t = cos((t/steps + 0.008)/1.008 * π/2)**2
        alpha_prev = cos(((t-1)/steps + 0.008)/1.008 * π/2)**2
        z = (z - (1 - alpha_t)/sqrt(1 - alpha_t) * pred_noise) / sqrt(alpha_t)
        z += sqrt(1 - alpha_prev) * torch.randn_like(z)  # Стохастичность
        
    # 4. Декодирование через VAE
    return VAE.decode(z)  # [1, 3, 256, 256]
DiT в деталях: Отличия от ViT
1. Patchify: Работа с латентами
Мы нарезаем на патчи не оригинальное изображение, а латентный вектор
# Для латента 32x32x4 с патчами 2x2:
self.patch_embed = nn.Conv2d(4, dim, kernel_size=2, stride=2)
# → [batch, 256, dim]  (16*16=256 патчей)
Сравнение с ViT: ViT работает с пикселями, DiT — с латентными векторами.
2. Classifier-Free Guidance (CFG)
Механизм усиления текста мы хотим, чтобы изображение из шума соответсвовало тексту, который мы передали:
pred_noise = uncond_noise + guidance_scale * (text_noise - uncond_noise)
Где:
- uncond_noise— предсказание для пустого промпта
- text_noise— предсказание для целевого промпта
- guidance_scale(7-10) — сила влияния текста
3. Cross-Attention Block
В SD3 (не в оригинальном DiT):
class CrossAttentionBlock(nn.Module):
    def forward(self, x, text_emb):
        # Проекция текста
        q = self.wq(x)  # [batch, tokens, dim]
        k = self.wk(text_emb)  # [batch, text_tokens, dim]
        v = self.wv(text_emb)
        
        # Attention
        attn = softmax(q @ k.transpose(-2,-1) / sqrt(dim))
        return attn @ v  # Текст-условные признаки
Зачем? Точнее связывает текст и визуальные патчи.
4. In-Context Conditioning
Механизм в DiT-XL:
- Ввод текстовых токенов как патчей 
- Пример: - [IMAGE_PATCH1, TEXT_TOKEN1, IMAGE_PATCH2, ...]
- Позволяет смешивать текст и изображение на входе 
5. AdaLN-Zero
Улучшение в DiT-2:
- Инициализация параметров γ в AdaLN нулями 
- Первые шаги обучения: AdaLN = Identity Function 
- Стабилизирует раннее обучение 
Разберём на простом примере
Оригинальный DiT (класс-условный)

# Генерация "собаки" (класс 207)
class_label = 207
z = torch.randn(1, 4, 32, 32)  # Начальный шум
for t in range(1000, 0, -1):
    pred_noise = DiT(z, t, class_label)  # Прямой вызов
    z = update_step(z, pred_noise, t)  # Обновление латента
Особенности:
- Простой ввод класса вместо текста 
- Нет CFG и Cross-Attention 
Stable Diffusion 3 (текст-условный)
prompt = "пудель в розовой шапочке"
text_embed = text_encoder(prompt)  # [1, 768]
for t in range(1000, 0, -1):
    # 1. Основное предсказание
    pred_noise = DiT(z, t, text_embed)
    
    # 2. Classifier-Free Guidance (усиление текста)
    uncond_noise = DiT(z, t, text_encoder(""))
    pred_noise = uncond_noise + 7.5 * (pred_noise - uncond_noise)
    
    # 3. Cross-Attention в некоторых блоках
    # (см. архитектуру ниже)
Нововведения SD3(относительно DiT):
- Текст через T5 вместо классов 
- CFG с масштабом 7.5 для точного следования промпту 
Оценка качества: Метрики
1. FID (Fréchet Inception Distance)
Как работает:
- Берем 50k реальных и 50k сгенерированных изображений 
- Пропускаем через Inception-v3 (получаем признаки) 
- Считаем "расстояние" между распределениями: 
FID = ||μ_real - μ_gen||^2 + Tr(Σ_real + Σ_gen - 2(Σ_real Σ_gen)^{1/2})
Интерпретация:
- FID = 0 — идеальное совпадение 
- FID < 5 — фотореалистичные изображения 
- DiT-XL: FID = 2.27 (ImageNet 256x256) 
2. IS (Inception Score)
IS = exp(E_x[KL(p(y|x) || p(y))])
Где:
- p(y|x)— распределение классов для изображения
- p(y)— общее распределение классов
- Высокий IS = разнообразные и узнаваемые изображения 
Почему DiT — это будущее Stable Diffusion?
✅ Преимущества перед U-Net:
| Параметр | U-Net (SD 2.1) | DiT (SD 3) | 
|---|---|---|
| Качество (FID) | 3.85 | 2.27 | 
| Масштабируемость | Ограничена | Линейный рост | 
| Разрешение | 768x768 | 1024x1024 | 
| Текстовая привязка | Средняя | Точная | 
❌ Ограничения:
- Ресурсы: Обучение DiT-XL требует 500,000 GPU-hours 
- Память: Генерация 1024px требует 48GB VRAM 
Философский итог
DiT объединяет три революции ИИ:
- Сжатие данных (VAE) 
- Трансформеры (ViT) 
- Диффузионные процессы 
Проверь себя
- Почему DiT работает с 32x32, а не 256x256? 
- Как Classifier-Free Guidance улучшает генерацию? 
Резюме
Diffusion Transformer (DiT):
- Работает в латентном пространстве VAE (32x32x4) 
- Заменяет U-Net на трансформер с AdaLN 
- Оригинал: класс-условная генерация 
Stable Diffusion 3:
- Текст через текст энкодер и Cross-Attention 
- Classifier-Free Guidance для точности 
- Поддержка 1024px изображений 
- FID 2.27 — новый стандарт качества 
Ссылки:
 
           
 
Licemery
Чего ж вы с провальным сд2 сравниваете, там и сд1.5 лучше выйдет при большом желании. Сравните с сдхл.
> Генерация 1024px требует 48GB VRAM
А как я на 12 гигах генерю?