Заключительная (но ещё не последняя) статья из цикла про диффузные модели, где мы наконец отбросим примитивную модель из полносвязных слоёв и напишем работающий генератор изображений c архитектурой Diffusion Transformer (DiT). Разберёмся зачем нарезать изображения на квадратики и увидим, что произойдёт с вашей генерацией, если проигнорировать главную "слабость" трансформеров - неспособность понимать порядок.
Очень кратко про трансформеры
Перед тем как наконец написать код нашей финальной модели неплохо бы для начала понять, что такое модель-трансформер и для для чего она нужна. Много есть в интернете и статей на эту тему, и не менее полусотни видео на YouTube разной степени запутанности, но я хочу рассмотреть архитектуру трансформера с более практической стороны - не углубляясь далеко в детали и интерпретации. Давайте начнём с того, что вспомним как выглядит поток информации в той модели, которая получилась у нас в прошлый раз. Та модель получала на вход тензор с изображением, превращала его в скрытое представление (одномерный вектор с 600 элементами) и, прогоняя это скрытое представление через несколько DenoisingBlock'ов, выдавала на выход тензор той же размерности, что был получен на вход. Если забыть про первое батч-измерение, то получается, что в момент прохода данных через модель вся информация была представлена "всего лишь" единственным вектором - тензором с шейпом (600).
Но ведь информацию можно представить не только в виде единственного вектора. Запись человеческого голоса или музыкальную композицию логичнее было бы представить в виде некой последовательности векторов - есть начало, конец и "направление". То же самое с текстом в LLM - токенизированный текст превращают в последовательность (sequence) векторов. В таких случаях тензор имеет шейп (N, C), где N - это количество элементов в последовательности, а C - размерность (величина) каждого элемента (вектора) в последовательности.
Тут важно сделать 2 замечания, касающихся информации, которая содержится в такой последовательности векторов.
Во-первых, само расположение векторов в последовательности - это важная информация, терять которую никак нельзя. Если переставить ноты в музыкальном отрывке или слова в предложении, то смысл всего этой последовательности изменится. Это значит, что какая бы модель эти данные не обрабатывала бы она должна каким-то образом учитывать позицию каждого вектора в последовательности.
Во-вторых, модель должна обработать всю информацию, которая содержится в последовательности - другими словами, не пропустить ни одного вектора. Да, знаю, звучит слишком очевидно, но давайте подумаем, как этого вообще можно сделать.
Представим, что наша задача - это создать модель, которая бы принимала на вход текст, превращённый в последовательность векторов, а на выходе бы выдавала скоринг - число в интервале [-1, 1], где минус единица означала бы резко негативное высказывание, ноль - нейтральное, а единица - позитивное. Sentiment analysis это ещё называют, классификатор текста, другими словами. И вот есть у нас на входе последовательность векторов - 3D тензор с шейпом (B, N, C). B - это измерение батча, но это мы уже знаем, а N, C - это матрица, где каждый столбец - это отдельный вектор в последовательности. Будем называть последовательность векторов, поступающую на вход модели матрица входных данных. Вот так бы выглядела матрица входных данных с шейпом (4, 7):
[] [] [] []
[] [] [] []
[] [] [] []
[] [] [] []
[] [] [] []
[] [] [] []
[] [] [] []
4 вектора по 7 значений (числа) в каждом.
Ладно, в каком виде данные идут на вход понятно. Давайте теперь подумаем как можно обработать данные в этой матрице так, чтобы получить интересующее нас число (степень позитивности/негативности). Авторы рекуррентных нейронных сетей (RNN) решили обрабатывать каждый вектор последовательно аккумулируя данные в промежуточное представление. Совсем как функция reduce():
скрытое_представление_0 - это пустое_скрытое_представление.
rnn(скрытое_представление_0, вектор_1) = скрытое_представление_1
rnn(скрытое_представление_1, вектор_2) = скрытое_представление_2
...
rnn(скрытое_представление_n-1, вектор_n) = скрытое_представление_n
С таким подходом информация, хранящаяся во всех векторах будет учтена в итоговом результате. Таким образом, последнее скрытое представление будет (в теории) содержать всю информацию, нужную для классификации (анализ тональности текста, например).
А вот авторы трансформеров решили подойти к задаче по-другому. Идея трансформеров в том, чтобы прогнать матрицу исходных данные через нейросеть (трансформер-блок) целиком и на выходе получить такую же матрицу, но слегка изменённую. Если выражаться точнее, провести исходную матрицу через цепочку таких трансформер-блоков трансформируя её. В результата после серии трансформаций у нас получится матрица того же размера, что и матрица входных данных, но содержащая всю необходимую для классификации информацию в, допустим, последнем векторе. Вы наверное спросите "Почему достаточно последнего вектора? Информация из всех остальных векторов, что, каким-то образом перетекла в него?" Ответа: да, именно для этого и предназначен трансформер блок - он даёт возможность векторам в последовательности обмениваться информацией (контекстуализироваться). Таким образом, условие "использовать данные всех векторов" обеспечивается тем, что, пройдя через цепочку трансформаций, каждый вектор "вберёт" в себя информацию всех остальных векторов в последовательности.
Не стоит сейчас сильно в это углубляться, поэтому запомните вот такие тезисы:
Нейросеть (трансформер-блок) получает на вход матрицу (последовательность векторов)
Внутри трансформер-блока происходит обмен информацией между векторами
На выходе у трансформер-блока последовательность такого же размера, но вектора обогатились информацией из всех остальных векторов в последовательности
Модель-трансформер состоит из цепочки таких трансформер-блоков
Возможно, я детальнее расскажу про работу трансформер-блока и механизм внимания, когда буду писать про то как современные Diffusion Transformers модели используют Relative Positional Bias и Rotary Positional Embeddings - для их работы надо вручную переписывать механизм внимания, поэтому придётся разбирать это всё достаточно глубоко. Но это в другой раз. Сейчас пора бы уже приступить к нашей финальной модели.
Создаём трансформер-блок
Возьмём за основу код, который получился у нас в предыдущей статье. Остановились мы вот на таком блоке:
class DenoiserBlock(nn.Module):
def __init__(self, hidden_dim, mlp_ratio, condition_dim):
super().__init__()
self.ln = nn.LayerNorm(hidden_dim)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * mlp_ratio),
nn.SiLU(),
nn.Linear(hidden_dim * mlp_ratio, hidden_dim),
)
self.modulator_mlp = nn.Sequential(
nn.Linear(condition_dim, condition_dim * 4),
nn.SiLU(),
nn.Linear(condition_dim * 4, hidden_dim * 3),
)
nn.init.zeros_(self.modulator_mlp[-1].weight)
nn.init.zeros_(self.modulator_mlp[-1].bias)
def forward(self, x, c):
scale, shift, gate = self.modulator_mlp(c).chunk(3, dim=1)
z = self.ln(x)
z = z * (1 + scale) + shift
z = self.mlp(z)
z = z * gate
return z
Переименуем это класс в TransformerBlock и начнём превращать его в трансформер-блок. Для начала ещё раз взглянем на знакомую схему:

Может выглядит запутанно, но, вообще-то, большую часть этой схемы мы уже реализовали! Те части, которые я отметил зелёным - это, по сути, наш текущий код DenoiserBlock'а. Pointwise Feedforward - это наше mlp. А MLP справа внизу - это то, что мы у нас зовётся modulator_mlp. Ладно, теперь понятно, что осталось сделать:
Добавить второй LayerNorm.
Сделать так, чтобы
modulator_mlpвозвращал 6 значений, а не 3 как сейчас.Добавить Multi-Head Self-Attention - единственная новинка и сердце трансформера.
Соединить все слои, расставив в нужных местах skip-connections (плюсы на диаграмме)
Начнём по порядку. Вместо одного LayerNorm
self.ln = nn.LayerNorm(hidden_dim)
Делаем два:
self.ln_1 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.ln_2 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
Тут elementwise_affine=False отключает у LayerNorm внутренние scale и shift (на самом деле, ещё в предыдущей статье стоило бы это сделать, но как-то совсем из головы вылетело). Scale и shift внутри LayerNorm нам не нужны, так как мы вручную это делаем через модуляцию.
Теперь настала очередь Модулятора:
Вот тут
self.modulator_mlp = nn.Sequential(
nn.Linear(condition_dim, condition_dim * 4),
nn.SiLU(),
nn.Linear(condition_dim * 4, hidden_dim * 3),
)
Меняем 3 на 6 (и всё)
self.modulator_mlp = nn.Sequential(
nn.Linear(condition_dim, condition_dim * 4),
nn.SiLU(),
nn.Linear(condition_dim * 4, hidden_dim * 6),
)
А ещё переименуем self.mlp в self.ffn
Добавляем MultiheadAttention:
self.attn = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
batch_first=True, # это обязательно!
)
Это новый для нас блок, поэтому остановимся на нём поподробнее. Во-первых, принимает он на выход 3D-тензор с шейпом (B, N, C). Тут B - размер батча, это понятно. Потом идёт матрица (N, C) - N векторов каждый размером C элементов. А на выход получается тензор точно такого же шейпа, просто теперь каждый выходной вектор содержит "выжимку" от всех других векторов в последовательности. Но это ненужные сейчас детали, главное - это уяснить, что 3D-тензор на вход и такой же 3D тензор на выходе. Параметр embed_dim - это ожидаемый размер вектора (C). Количество векторов в последовательности N указывать не надо, так как nn.MultiheadAttention работает с последовательностями любой длины. Про смысл num_heads распишу подробнее в другой раз, скажу лишь, что это должно быть такое число, чтобы num_heads * 32 == embed_dims. А вот batch_first обязательно надо выставить в True. Если этого не сделать, то nn.MultiheadAttention будет ожидать, что на вход ему будут подавать тензор с шейпом (N, B, C). Это всё тяжелое наследие RNN, но помнить об этом надо, а то я однажды (при написании этой статьи) забыл написать batch_first=True у меня вся тренировка пошла насмарку.
Ладно, всё готов для последнего шага: переписать метод forward. Начинаем с модуляции
# Шейп x - (B, N, C). Шейп c - (B, Cond)
def forward(self, x, c):
mod = self.modulator_mlp(c) # (B, C * 6)
mod = mod.unsqueeze(1) # (B, 1, C * 6)
scale_1, shift_1, gate_1, scale_2, shift_2, gate_2 = mod.chunk(6, dim=2)
Каждый scale, shift и gate имеет шейп (B, 1, C). Зачем нам дополнительное измерение посередине? Вспомните, что на вход у нас батч матриц - батч последовательностей векторов с шейпом (B, N, C). Наши модуляторы должны делать shift и scale каждого вектора в этих последовательностях. Но так как применятся они будут следующим образом:
x * scale или x + shift, то количество измерений у них должно совпадать, чтобы pytorch правильно произвёл broadcasting. Для этого и "лишнее" второе измерение.
Продолжаем по схеме выше:
def forward(self, x, c):
mod = self.modulator_mlp(c) # (B, C * 6)
mod = mod.unsqueeze(1) # (B, 1, C * 6)
scale_1, shift_1, gate_1, scale_2, shift_2, gate_2 = mod.chunk(6, dim=2)
z = self.ln1(x) # LayerNorm
z = z * (1 - scale_1) + shift_1 # Scale, Shift
z, _ = self.attn(z, z, z) # Multi-Head Self-Attention
z = z * gate_1 # На схеме Scale, а так это Gate
x_after_attn = x + z # Первый (+)
Результат очень похоже на то, что у нас уже было, только вместо MLP тут MultiheadAttention.
Давайте объясню про эту строку
z, _ = self.attn(z, z, z)
Зачем мы аж 3 раза отправляем наши данные в self.attn? Дело в том, что nn.MultiheadAttention может быть использована как для расчёта self-attention, когда векторы обмениваются информацией друг с другом, так и для cross-attention - когда векторы получают информацию из другой последовательности векторов. В таком случае вызов выглядел бы как-то так: self.attn(z, other_z, other_z). Не буду вдаваться в подробности что каждый аргумент значит, потому что не хочу поверхностно рассказывать про механизм внимания. Лучше дождитесь статьи про RoPE - там будет всё будет разобрано в деталях.
Помимо этого, заметили, что возвращается нам кортеж (tuple) из двух значений? Нам нужно только первое - это информация, которой векторы поделились друг с другом.
Добавляем последний кусок:
def forward(self, x, c):
mod = self.modulator_mlp(c) # (B, C * 6)
mod = mod.unsqueeze(1) # (B, 1, C * 6)
scale_1, shift_1, gate_1, scale_2, shift_2, gate_2 = mod.chunk(6, dim=2)
# Вектора обмениваются информацией.
z = self.ln1(x) # LayerNorm
z = z * (1 - scale_1) + shift_1 # Scale, Shift
z, _ = self.attn(z, z, z) # Multi-Head Self-Attention
z = z * gate_1 # На схеме Scale, а так это Gate
x_after_attn = x + z # Первый (+)
# Вектора "переваривают" полученную информацией.
z = self.ln2(x_after_attn) # LayerNorm
z = z * (1 - scale_2) + shift_2 # Scale, Shift
z = self.ffn(z) # Pointwise Feedforward
z = z * gate_2 # На схеме Scale, а на деле Gate
return x_after_attn + z # Второй (+)
Выстроенные друг за другом такие TransformerBlockи (слои) представляют собой вот такую цепочку трансформаций:
Вектора обмениваются информацией
|
Вектора осваивают новую информацию
|
Вектора обмениваются информацией
|
Вектора осваивают новую информацию
|
Вектора обмениваются информацией
|
Вектора осваивают новую информацию
|
...
В общем-то, эта цепочка и есть трансформер, а TransformerBlock теперь выглядит вот так:
Класс TransformerBlock
class TransformerBlock(nn.Module):
def __init__(self, hidden_dim, num_heads, mlp_ratio, condition_dim):
super().__init__()
self.ln1 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.ln2 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * mlp_ratio),
nn.SiLU(),
nn.Linear(hidden_dim * mlp_ratio, hidden_dim),
)
self.modulator_mlp = nn.Sequential(
nn.Linear(condition_dim, condition_dim * 4),
nn.SiLU(),
nn.Linear(condition_dim * 4, hidden_dim * 6),
)
nn.init.zeros_(self.modulator_mlp[-1].weight)
nn.init.zeros_(self.modulator_mlp[-1].bias)
def forward(self, x, c):
mod = self.modulator_mlp(c) # (B, hidden_dim * 6)
mod = mod.unsqueeze(1) # (B, 1, hidden_dim * 6)
scale_1, shift_1, gate_1, scale_2, shift_2, gate_2 = mod.chunk(6, dim=2)
z = self.ln1(x)
z = z * (1 - scale_1) + shift_1
z, _ = self.attn(z, z, z)
z = z * gate_1
x_after_attn = x + z
z = self.ln2(x_after_attn)
z = z * (1 - scale_2) + shift_2
z = self.ffn(z)
z = z * gate_2
return x_after_attn + z
Основная модель
Теперь надо переделать основной класс Denoiser. Сейчас он выглядит вот так:
class Denoiser(nn.Module):
def __init__(self, hidden_dims, num_blocks, condition_dim):
super().__init__()
self.input_encoder = nn.Sequential(
nn.Flatten(start_dim=1),
nn.Linear(SIZE * SIZE, hidden_dims),
)
self.class_embeddings = nn.Embedding(num_classes, condition_dim)
self.class_mlp = nn.Sequential(
nn.Linear(condition_dim, condition_dim * 4),
nn.SiLU(),
nn.Linear(condition_dim * 4, condition_dim),
)
block_list = [DenoiserBlock(hidden_dims, 4, condition_dim) for _ in range(num_blocks)]
self.blocks = nn.ModuleList(block_list)
self.output_decoder = nn.Sequential(
nn.Linear(hidden_dims, SIZE * SIZE),
nn.Unflatten(1, (1, SIZE, SIZE)),
)
self.time_linear = nn.Sequential(
nn.Linear(1, condition_dim),
nn.LayerNorm(condition_dim)
)
def forward(self, x, t, c):
hidden = self.input_encoder(x)
time_embedding = self.time_linear(t)
class_embedding = self.class_embeddings(c)
class_condition = self.class_mlp(class_embedding)
condition = time_embedding + class_condition
for block in self.blocks:
hidden = hidden + block(hidden, condition)
return self.output_decoder(hidden)
Первым делом добавляем в конструктор новый параметр num_heads:
def __init__(self, hidden_dims, num_heads, num_blocks, condition_dim):
И правим создание трансформер-блоков:
block_list = [TransformerBlock(hidden_dims, num_heads, 4, condition_dim) for _ in range(num_blocks)]
А вот с input_encoderом будет сложнее. Напоминаю, Diffusion Transformer работает с последовательностью векторов - тензор шейпом (B, N, C). А вот на вход модели будет подаваться тензор чёрно-белого изображений с шейпом (B, 1, 24, 24). Чтобы вообще запустить наш трансформер, входной тензор надо превратить в последовательность векторов. Для этого воспользуемся давно отработанным (ещё со времён Visual Transformers) приёмом.
Идея простая:
-
Разбить наше изображение на несколько кусков (патчей):

Из каждого патча (в нашем случае это маленькая матрица 2x2) сформировать вектор длиной 4. Превратить двумерный тензор в одномерный, другими словами. А потом объединить эти вектора в последовательность.
С помощью nn.Linear(4, hidden_dim) спроецировать каждый вектор пикселей в вектор скрытого представления. И получится у нас как раз нужный нам вектор (B, N, C), где N - это количество патчей (в нашем случае (24 / 2) * (24 / 2) == 144), а C - размерность вектора скрытого представления hiddent_dim.
Попробуем это реализовать в коде. Для "нарезки" тензора изображения на прямоугольные патчи в PyTorch уже есть готовый класс nn.Unfold. А работает он вот так:
# Код для примера
unfold = nn.Unfold(kernel_size=2, stride=2)
# Шейп переменной x - (B, 1, 24, 24)
y = unfold(x)
# Шейп переменной y - (B, 1*4, 24/2 * 24/2) == (B, 4, 144)
another_unfold = nn.Unfold(kernel_size=3, stride=3)
z = y = another_unfold(x)
# Шейп переменной z - (B, 1*9, 24/3 * 24/3) == (B, 9, 64)
Здесь nn.Unfold(kernel_size=2, stride=2) нарежет входной тензор изображения на патчи 2x2 и объединит два пространственных измерения в одно. Правда, вернёт он нам тензор с шейпом (B, pixel_vector_size, N), где последнее измерение - это длина последовательности. А нам надо, чтобы длина последовательности была вторым измерением. Для этого есть функция permute:
# x.shape == (B, 4, 144)
y = x.permute(0, 2, 1) # меняем местами второе и третье измерения
# x.shape == (B, 144, 4)
После этого останется лишь проецировать вектор длиной 4 на вектор длиной hidden_dim. Тут всё просто через обычный nn.Linear(4, hidden_dim). Вот так это всё будет выглядеть в нашем классе:
Вместо
self.input_encoder = nn.Sequential(
nn.Flatten(start_dim=1),
nn.Linear(SIZE * SIZE, hidden_dims),
)
Делаем
self.input_patcher = nn.Unfold(kernel_size=2, stride=2)
self.input_projector = nn.Linear(4, hidden_dims)
А внутри метода forward заменяем
hidden = self.input_encoder(x)
на
input_patches = self.input_patcher(x)
input_seq = input_patches.permute(0, 2, 1)
hidden = self.input_projector(input_seq)
С input_encoderом разобрались, теперь надо также адаптировать ouput_encoder. Он должен будет превращать последовательность векторов обратно в одноканальное (чёрно-белое) изображение. Хорошо, что в PyTorch уже есть класс "обратный" nn.Unfold - это nn.Fold:
# Код для примера
fold = nn.Fold(output_size=(SIZE, SIZE), kernel_size=2, stride=2)
# x.shape == (B, hidden_dims, 144)
y = fold(x)
# y.shape == (B, hidden_dims/4, SIZE, SIZE)
Только тут одна загвоздка - nn.Fold хоть и сливает патчи в тензор изображения, число каналов при этом будет четверть от hidden_dims. А нам надо чтобы число каналов было 1. А так как hidden_dims явно побольше 4-х будет, то нам надо либо уменьшить число каналов до операции Fold, либо сделать это уже после Fold'a, обработав тензор изображения через свёрточную сеть. На самом деле так и поступим:
nn.Conv2d(in_channels=hidden_dims/4, out_channels=1, kernel_size=3, padding=1)
Про свёрточные сети и nn.Conv2d прочитайте где-нибудь отдельно. Хоть в документации PyTorch. Если я ещё их начну объяснять, то и так уже большая статья окончательно все берега потеряет. А так эта операция просто сделает из тензора с шейпом (B, hidden_dims/4, 24, 24) тензор одноканального изображения (B, 1, 24, 24) - именно то, что нам и нужно.
Меняем код. Теперь вместо
self.output_decoder = nn.Sequential(
nn.Linear(hidden_dims, SIZE * SIZE),
nn.Unflatten(1, (1, SIZE, SIZE)),
)
Пишем
self.output_decoder = nn.Sequential(
nn.Fold(output_size=(SIZE, SIZE), kernel_size=2, stride=2),
nn.Conv2d(in_channels=hidden_dims//4, out_channels=1, kernel_size=3, padding=1),
)
А в самом конце метода forward
return self.output_decoder(hidden)
Заменяем на
output_seq = hidden.permute(0, 2, 1)
return self.output_decoder(output_seq)
Как теперь выглядит наш Denoiser класс
class Denoiser(nn.Module):
def __init__(self, hidden_dims, num_heads, num_blocks, condition_dim):
super().__init__()
self.input_patcher = nn.Unfold(kernel_size=2, stride=2)
self.input_projector = nn.Linear(4, hidden_dims)
self.class_embeddings = nn.Embedding(num_classes, condition_dim)
self.class_mlp = nn.Sequential(
nn.Linear(condition_dim, condition_dim * 4),
nn.SiLU(),
nn.Linear(condition_dim * 4, condition_dim),
)
block_list = [TransformerBlock(hidden_dims, num_heads, 4, condition_dim) for _ in range(num_blocks)]
self.blocks = nn.ModuleList(block_list)
self.output_decoder = nn.Sequential(
nn.Fold(output_size=(SIZE, SIZE), kernel_size=2, stride=2),
nn.Conv2d(in_channels=hidden_dims//4, out_channels=1, kernel_size=3, padding=1),
)
self.time_linear = nn.Sequential(
nn.Linear(1, condition_dim),
nn.LayerNorm(condition_dim)
)
def forward(self, x, t, c):
input_patches = self.input_patcher(x)
input_seq = input_patches.permute(0, 2, 1)
hidden = self.input_projector(input_seq)
time_embedding = self.time_linear(t)
class_embedding = self.class_embeddings(c)
class_condition = self.class_mlp(class_embedding)
condition = time_embedding + class_condition
for block in self.blocks:
# теперь мы просто передаём результат работы одного блока
# на вход другому блоку
hidden = block(hidden, condition)
output_seq = hidden.permute(0, 2, 1)
return self.output_decoder(output_seq)
Можно приступить к обучению, но сначала надо внести несколько изменений в код тренировки.
Там где мы определяем subset надо вместо
subset = Subset(dataset, torch.randperm(4096 * 2))
написать
subset = dataset
Будем использовать для обучения не срез из 9 тысяч сэмплов, а весь датасет из 112 тысяч. Потому что Diffusion Transformer слишком уж хорошо обучается по сравнению с предыдущей нашей моделью и запросто может осилить обучающий набор такого размера.
Теперь инициализация модели:
model = Denoiser(hidden_dims=16*8, num_heads=8, num_blocks=8, condition_dim=32)
Я тут немного перегнул с количеством голов, можно обойтись и меньшим количеством. А длина вектора получается 128, это означает, что наше скрытое представление - это матрица (144, 128) - последовательность из 144 векторов каждый размеров 128. Заметьте, что размер скрытого представления в DiT модели намного больше - 144 x 128 == 18432 элементов в матрице, против вектора 600 элементами из нашей предыдущей модели, при том что сама DiT-модель меньше раза в 3 и работает лучше, но это я забегаю вперёд.
Осталось лишь задать гиперпараметры:
BATCH_SIZE = 128
LR = 6e-4
DEVICE = 'cuda'
EPOCHS = 10
EPOCHS поменьше, потому что обучающий набор теперь больше. Всё готово к тренировке, запускаем и..

Вот так да! Не работает.
Дело в том, что написанный нами трансформер полностью игнорирует критически важную часть информации, которая содержится во входной матрице данных - информацию о расположении векторов в последовательности. Сейчас поясню на простом примере. Допустим, у нас есть вот такая последовательность из 4-х векторов:
[1] [8] [2] [0]
[1] [6] [3] [1]
[5] [8] [6] [5]
Вспоминаем, что наш трансформер - это просто цепочки чередующихся MHA (Multi-Head Attention) и FNN (Feedforward Network) слоёв, где FNN работает на уровне отдельных векторов. Что я имею ввиду? Для матрицы с шейпом (4, 3) как здесь FNN слой выглядел бы вот так:
nn.Sequential(
nn.Linear(3, 12),
nn.SiLU(),
nn.Linear(12, 3),
)
Этот слой обрабатывает каждый из 4-х векторов независимо друг от друга, даже параллельно я скажу. Как и задумано, ведь для обмена информацией между векторами у нас есть MHA слои. И вот тут то и вылезает наша проблема - механизм Attention сам по себе никак не учитывает порядок векторов в последовательности. Другими словами, если бы мы поменяли местами вектора в последовательности, то результат работы MHA-слоя для каждого вектора не изменился бы:
до MHA до MHA
[1] [8] [2] [0] [8] [2] [0] [1]
[1] [6] [3] [1] [6] [3] [1] [1]
[5] [8] [6] [5] [8] [6] [5] [5]
после MHA после MHA
[0] [1] [0] [0] [1] [0] [0] [0]
[1] [0] [1] [0] [0] [1] [0] [1]
[2] [0] [1] [2] [0] [1] [2] [2]
А раз перестановка векторов местами не влияет на то, какую информацию получит каждый вектор после прохода через MHA-слой (а это единственное место, где отдельный вектор может как-то что-то узнать о том, что помимо него в последовательности вообще есть другие несущие информацию вектора), то для нашего трансформера все вот эти входные изображению будут "выглядеть" совершенно идентично:

Из-за этого всё обучение модели идёт насмарку. Но выход есть!
На самом деле, есть несколько способов внедрить в модель информацию о позиции векторов в последовательности. Вот только чтобы реализовать наиболее продвинутые "техники" такие как Relative Positional Bias или ставший уже повсеместными RoPE надо переопределять механизм внимания. Эти темы мы разберём в следующих статьях, а сейчас решим проблему самым простым способом - через Absolute Positional Embeddings.
Работать это будет следующим образом: как только мы формируем матрицу входных данных (последовательность векторов) из изображения, поданного методу forward, надо к этой матрице прибавить матрицу (тензор) такого же размера. В этой матрице и будет содержаться информация о том, какую позицию занимает вектор в последовательности. И так как эта матрица всё время будет одна и так же, модель, а точнее трансформер-блоки научатся "вычленять" эту информацию из переданной им последовательности векторов. Словами не так понятно выходит, поэтому смотрим на код:
# Сразу после
hidden = self.input_projector(input_seq)
# Прибавляем матрицу (pos_embeddings)
hidden = hidden + self.pos_embeddings * self.pos_embeddings_scale
А self.pos_embeddings и self.pos_embeddings_scale определяем в конструкторе:
self.pos_embeddings = nn.Parameter(data=torch.randn(144, hidden_dims))
self.pos_embeddings_scale = nn.Parameter(torch.zeros(1))
torch.randn(144, hidden_dims) создаст тензор совпадающий по форме с нашей последовательностью векторов, тут понятно. А nn.Parameter нужен для того, чтобы во время обучения наша DiT модель могла адаптировать (выучить) матрицу positional embeddings. Можно было бы оставить pos_embeddings как статичный тензор, но тогда бы torch.randn не подошёл бы. Идём дальше. Уже догадались для чего нужен self.pos_embeddings_scale? Это тоже обучающийся тензор из одного единственного значения, а инициализируется он нулём для того, чтобы в начале обучения исключить влияние self.pos_embeddings на hidden. Если этого не сделать, то hidden буквально потонет в шуме и первые шаги тренировки будут очень нестабильными. Нет, конечно, потом все градиенты выровняются, но какой-то ресурс тренировки будет потерян. В общем, это просто способ немного ускорить тренировку.
Ладно, теперь всё готово для следующей попытки. И на этот раз получилось гораздо лучше:

Уже на 6-й эпохе модель обучилась всем 47 классам.
А теперь сравним, что выдаст нам модель из предыдущей статьи, если тоже тренировать её на всём датасете:
Результат полносвязной модели

Вот здесь финальная версия кода
Что дальше?
Вы думаете, что написали мы модель и всё на этом? Как раз наоборот, сейчас на руках у нас прототип Diffusion Transformer'a - идеальный "модельный организм", на котором можно продолжать экспериментировать добавляя фичи "взрослых" моделей.
Первое, что приходит в голову - это заменить имеющийся Absolute Positional Embeddings чем-то более современным, но сразу же бросаться имплементировать популярный Rotary Positional Embeddings (RoPE) будет слишком сложно, поэтому стоит начать с чего-то попроще, а конкретно с Relative Positional Bias, на примере которого разобрать и сам механизм внимания (сердце трансформера) и наглядно показать, каким образом можно модифицировать этот механизм, чтобы модель воспринимала информацию о взаимном расположении векторов в последовательности.
В следующих статьях я подробно разберу механизм внимания и расскажу о современных способах внедрить в модель информацию о взаимном расположении векторов, а также попробуем создать Autoregressive Transformer, который будет генерить нам изображения токен за токеном совсем как GPT. Продолжение следует.
Бонус. Анимация инференса в 40 шагов:
