С момента своего появления в 2017 году в публикации Attention is All You Need [1] трансформеры стали доминирующим подходом в обработке естественного языка (NLP). В 2021 году в статье An Image is Worth 16x16 Words [2] трансформеры были успешно адаптированы для задач компьютерного зрения. С тех пор для компьютерного зрения было предложено множество архитектур на основе трансформеров.

В этой статье мы рассмотрим трансформер зрения (Vision Transformer, ViT) в том виде, в котором он был представлен в статье [2]. Она включает в себя открытый код ViT, а также концептуальные объяснения компонентов. Реализация ViT, рассмотренная в статье, выполнена с использованием пакета PyTorch.

Оглавление:

  1. Что такое визуальные трансформеры?

  2. Знакомство с моделью

    1. Токенизация изображений

    2. Обработка токенов

    3. Кодирующий блок 

    4. Модуль нейронной сети

    5. Обработка предсказаний

  3. Код полностью

  4. Заключение. Полезные ссылки

Что такое визуальные трансформеры?

Как было сказано в публикации Attention is All You Need, трансформеры — это архитектура машинного обучения, использующая механизм самовнимания в качестве основного компонента для обучения. Трансформеры быстро стали передовым методом решения задач обработки последовательных данных, таких как перевод с одного языка на другой.

В работе An Image is Worth 16x16 Words трансформер, предложенный в работе Attention is All You Need [1], был успешно модифицирован для решения задач классификации изображений, что привело к созданию Vision Transformer (ViT). В основе ViT лежит тот же механизм самовнимания, что и в трансформере из [1]. Однако в отличие от оригинальной архитектуры трансформеров для NLP, которая включает кодировщик и декодер, ViT использует только кодировщик. Выход кодировщика передаётся в выходной слой, который отвечает за финальное предсказание.

Недостатком ViT, как показано в [2], является его зависимость от больших наборов данных для достижения оптимальной производительности. Лучшие модели предобучены на проприетарном датасете JFT-300M. Модели, предобученные на меньшем, открытом датасете ImageNet-21k, демонстрируют результаты на уровне современных свёрточных моделей, таких как ResNet.

Модель Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet [3] пытается устранить необходимость в предобучении, предлагая метод преобразования входного изображения в токены. Подробнее об этом методе можно прочитать здесь. В этой статье мы рассмотрим архитектуру ViT и её реализацию, предложенную в работе [2].

Знакомство с моделью

Эта статья следует структуре модели, описанной в статье An Image is Worth 16x16 Words [2]. Однако код, описанный в статье, не находится в открытом доступе. Код из более поздней публикации Tokens-to-Token ViT [3] доступен на GitHub. Модель Tokens-to-Token ViT (T2T-ViT) добавляет модуль Tokens-to-Token (T2T) к стандартной основе ViT. Код в этой статье основан на компонентах ViT из репозитория Tokens-to-Token ViT [3] на GitHub. Модификации, внесённые в эту реализацию, включают, но не ограничиваются: поддержку несоразмерных изображений и удаление слоёв Dropout.

Ниже представлена диаграмма модели ViT.

Диаграмма модели ViT (изображение автора)
Диаграмма модели ViT (изображение автора)

Токенизация изображений

Первым шагом ViT является создание токенов из входного изображения. Трансформеры работают с последовательностью токенов; в обработке естественного языка (NLP) каждый токен, как правило, представляет собой слово. В случае компьютерного зрения токенизация включает разбиение изображения на патчи фиксированного размера.

ViT преобразует изображение в токены таким образом, что каждый токен представляет собой локальную область — или патч — изображения. Они описывают преобразование изображения высотой H, шириной W и каналами C в N маркеров с размером маркера P:

Каждый токен имеет длину P²∗C, где P² — это количество пикселей в патче, а C — количество каналов.

Давайте рассмотрим пример токенизации патча на примере пиксель-арта «Закат в горах» Луиса Зуно (@ansimuz) [4]. Оригинальная работа была обрезана и преобразована в одноканальное изображение. Это означает, что каждый пиксель представлен числом от нуля до единицы. Одноканальные изображения обычно отображаются в градациях серого, но мы будем отображать их в фиолетовой цветовой гамме для удобства восприятия.

Заметьте, что токенизация патчей отсутствует в коде, предоставленном в [3]; весь код в этом разделе написан автором.

mountains = np.load(os.path.join(figure_path, 'mountains.npy'))

H = mountains.shape[0]
W = mountains.shape[1]
print('Mountain at Dusk is H =', H, 'and W =', W, 'pixels.')
print('\n')

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.xticks(np.arange(-0.5, W+1, 10), labels=np.arange(0, W+1, 10))
plt.yticks(np.arange(-0.5, H+1, 10), labels=np.arange(0, H+1, 10))
plt.clim([0,1])
cbar_ax = fig.add_axes([0.95, .11, 0.05, 0.77])
plt.clim([0, 1])
plt.colorbar(cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'mountains.png'))
Изображение «Закат в горах» имеет разрешение H = 60 и W = 100 пикселей.
Вывод кода (изображение автора)
Вывод кода (изображение автора)

Это изображение имеет H=60 и W=100. Мы зададим P=20, так как оно делит H и W без остатка.


P = 20
N = int((H*W)/(P**2))
print('There will be', N, 'patches, each', P, 'by', str(P)+'.')
print('\n')

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.hlines(np.arange(P, H, P)-0.5, -0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5, -0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+1, 10), labels=np.arange(0, W+1, 10))
plt.yticks(np.arange(-0.5, H+1, 10), labels=np.arange(0, H+1, 10))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(1, N+1):
    plt.text(x_text[i-1], y_text[i-1], str(i), color='w', fontsize='xx-large', ha='center')
plt.text(x_text[2], y_text[2], str(3), color='k', fontsize='xx-large', ha='center');
#plt.savefig(os.path.join(figure_path, 'mountain_patches.png'), bbox_inches='tight'
Всего будет 15 патчей, каждый размером 20 на 20.
Вывод кода (изображение автора)
Вывод кода (изображение автора)

Преобразовав эти патчи в одномерные векторы, мы видим получившиеся токены. В качестве примера рассмотрим патч под номером 12, поскольку он содержит четыре разных оттенка.

print('Each patch will make a token of length', str(P**2)+'.')
print('\n')

patch12 = mountains[40:60, 20:40]
token12 = patch12.reshape(1, P**2)

fig = plt.figure(figsize=(10,1))
plt.imshow(token12, aspect=10, cmap='Purples_r')
plt.clim([0,1])
plt.xticks(np.arange(-0.5, 401, 50), labels=np.arange(0, 401, 50))
plt.yticks([]);
#plt.savefig(os.path.join(figure_path, 'mountain_token12.png'), bbox_inches='tight')
Каждый патч будет создавать токен длиной 400.
Вывод кода (изображение автора)
Вывод кода (изображение автора)

После извлечения токенов из изображения на них обычно применяется линейное отображение, которое изменяет их размерность. Это отображение реализовано с помощью обучаемого линейного слоя. Новая длина токенов в зависимости от контекста называется латентной размерностью, размерностью канала или длиной токена. После этого преобразования токены больше не содержат визуальной информации, которую можно было бы сопоставить с патчами исходного изображения.

Теперь, когда концепция токенизации патчей ясна, можем рассмотреть её реализацию в коде.

class Patch_Tokenization(nn.Module):
    def __init__(self,
                img_size: tuple[int, int, int]=(1, 1, 60, 100),
                patch_size: int=50,
                token_len: int=768):

        """ Patch Tokenization Module
            Args:
                img_size (tuple[int, int, int]): size of input (channels, height, width)
                patch_size (int): the side length of a square patch
                token_len (int): desired length of an output token
        """
        super().__init__()

        ## Defining Parameters
        self.img_size = img_size
        C, H, W = self.img_size
        self.patch_size = patch_size
        self.token_len = token_len
        assert H % self.patch_size == 0, 'Height of image must be evenly divisible by patch size.'
        assert W % self.patch_size == 0, 'Width of image must be evenly divisible by patch size.'
        self.num_tokens = (H / self.patch_size) * (W / self.patch_size)

        ## Defining Layers
        self.split = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size, padding=0)
        self.project = nn.Linear((self.patch_size**2)*C, token_len)

    def forward(self, x):
        x = self.split(x).transpose(1,0)
        x = self.project(x)
        return x

Обратите внимание на два оператора assert, которые проверяют, что размеры изображения делятся на размер патча без остатка. Фактическое разбиение на патчи реализовано с помощью слоя torch.nn.Unfold.

Мы запустим пример этого кода, используя обрезанную, одноканальную версию изображения «Закат в горах». Мы должны увидеть значения для количества токенов и начального размера токена, как было описано выше. Мы будем использовать token_len=768 в качестве проецируемой длины, что соответствует размеру базового варианта ViT.

Первая строка в блоке кода ниже преобразует изображение «Закат в горах» из массива NumPy в тензор Torch. Нам также нужно применить unsqueeze к тензору, чтобы создать размерность канала и размерность пакета. Как и выше, у нас один канал. Поскольку у нас одно изображение, размер батча (batchsize) равен 1.

x = torch.from_numpy(mountains).unsqueeze(0).unsqueeze(0).to(torch.float32)
token_len = 768
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of input channels:', x.shape[1], '\n\timage size:', (x.shape[2], x.shape[3]))

# Define the Module
patch_tokens = Patch_Tokenization(img_size=(x.shape[1], x.shape[2], x.shape[3]),
                                    patch_size = P,
                                    token_len = token_len)
Входные размеры: 
  размер батча: 1 
  количество входных каналов: 1
  размер изображения: (60, 100)

Теперь разделим изображение на токены.

x = patch_tokens.project(x)
print('After projection, dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
Размеры после токенизации патчей: 
  размер батча: 1 
  количество токенов: 15 
  длина токена: 400

Как мы видели в примере, имеется N=15 токенов, каждый из которых имеет длину 400. Наконец, мы проецируем токены до длины token_len

x = patch_tokens.project(x)
print('After projection, dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
Размеры после проекции: 
  размер батча: 1 
  количество токенов: 15 
  длина токена: 768

Теперь, когда у нас есть токены, мы готовы перейти к работе с ViT.

Обработка токенов

Следующие два шага ViT, предшествующие кодировке, мы обозначим как «обработка токенов». Компонент обработки токенов, который отвечает за подготовку данных для кодирующих блоков, показан на диаграмме ViT ниже.

Компоненты обработки токенов на диаграмме ViT (изображение автора)
Компоненты обработки токенов на диаграмме ViT (изображение автора)

Первым шагом является добавление пустого токена, называемого токеном предсказания (Prediction Token), к токенам изображения. Этот токен будет использоваться на выходе кодирующих блоков для создания предсказания. Изначально он пуст — эквивалентен нулю — чтобы он мог получать информацию от других токенов изображения.

Мы начнём со 175 токенов. Каждый токен имеет длину 768, что соответствует базовому варианту ViT. Мы используем размер батча 13, потому что это простое число и оно не будет путаться с другими параметрами.

# Define an Input
num_tokens = 175
token_len = 768
batch = 13
x = torch.rand(batch, num_tokens, token_len)
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])

# Append a Prediction Token
pred_token = torch.zeros(1, 1, token_len).expand(batch, -1, -1)
print('Prediction Token dimensions are\n\tbatchsize:', pred_token.shape[0], '\n\tnumber of tokens:', pred_token.shape[1], '\n\ttoken length:', pred_token.shape[2])

x = torch.cat((pred_token, x), dim=1)
print('Dimensions with Prediction Token are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
Входные размеры:
   размер батча: 13 
   количество токенов: 175 
   длина токена: 768
Размеры токена предсказания:
   размер батча: 13 
   количество токенов: 1 
   длина токена: 768
Размеры с токеном предсказания:
   размер батча: 13 
   количество токенов: 176 
   длина токена: 768

Теперь мы добавляем позиционное кодирование для наших токенов. Позиционное кодирование позволяет трансформеру понимать порядок токенов изображения. Обратите внимание, что позиционное кодирование добавляется к токенам с помощью сложения, а не конкатенации, что позволяет сохранить исходную размерность токенов. Подробности о реализации и различных вариантах позиционного кодирования — это отдельная сложная тема, которую можно рассмотреть в другой раз.

def get_sinusoid_encoding(num_tokens, token_len):
    """ Make Sinusoid Encoding Table

        Args:
            num_tokens (int): number of tokens
            token_len (int): length of a token
            
        Returns:
            (torch.FloatTensor) sinusoidal position encoding table
    """

    def get_position_angle_vec(i):
        return [i / np.power(10000, 2 * (j // 2) / token_len) for j in range(token_len)]

    sinusoid_table = np.array([get_position_angle_vec(i) for i in range(num_tokens)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) 

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

PE = get_sinusoid_encoding(num_tokens+1, token_len)
print('Position embedding dimensions are\n\tnumber of tokens:', PE.shape[1], '\n\ttoken length:', PE.shape[2])

x = x + PE
print('Dimensions with Position Embedding are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
Размеры позиционного эмбеддинга:
   количество токенов: 176 
   длина токена: 768
Размеры с позиционным эмбеддингом:
   размер батча: 13 
   количество токенов: 176 
   длина токена: 768

Теперь наши токены готовы для передачи в кодирующие блоки, где начинается основная фаза обучения модели.

Кодирующий блок 

Кодирующие блоки — это компоненты, где модель фактически обучается, обрабатывая токены изображения с помощью механизма внимания и нейронных сетей. Количество кодирующих блоков является гиперпараметром, задаваемым пользователем. Схема кодирующего блока приведена ниже.

Кодирующий блок (изображение автора)
Кодирующий блок (изображение автора)

Код для кодирующего блока приведён ниже.

class Encoding(nn.Module):

    def __init__(self,
       dim: int,
       num_heads: int=1,
       hidden_chan_mul: float=4.,
       qkv_bias: bool=False,
       qk_scale: NoneFloat=None,
       act_layer=nn.GELU, 
       norm_layer=nn.LayerNorm):
        
        """ Encoding Block

            Args:
                dim (int): size of a single token
                num_heads(int): number of attention heads in MSA
                hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component
                qkv_bias (bool): determines if the qkv layer learns an addative bias
                qk_scale (NoneFloat): value to scale the queries and keys by; 
                                    if None, queries and keys are scaled by ``head_dim ** -0.5``
                act_layer(nn.modules.activation): torch neural network layer class to use as activation
                norm_layer(nn.modules.normalization): torch neural network layer class to use as normalization
        """

        super().__init__()

        ## Define Layers
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim,
                            chan=dim,
                            num_heads=num_heads,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale)
        self.norm2 = norm_layer(dim)
        self.neuralnet = NeuralNet(in_chan=dim,
                                hidden_chan=int(dim*hidden_chan_mul),
                                out_chan=dim,
                                act_layer=act_layer)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.neuralnet(self.norm2(x))
        return x

Параметры num_heads, qkv_bias и qk_scale задают ключевые аспекты модуля внимания (Attention). Подробное рассмотрение механизма внимания для визуальных трансформеров оставим на другой раз.

Параметры hidden_chan_mul определяют размер скрытых слоёв нейронной сети, а act_layer задаёт функцию активации, которая может быть выбрана из модуля torch.nn.modules.activation. Мы рассмотрим модуль нейронной сети подробнее далее в статье.

Слой norm_layer задаёт тип нормализации и может быть выбран из любого слоя в torch.nn.modules.normalization.

Теперь мы рассмотрим каждый синий блок на диаграмме и сопровождающий его код. Мы будем использовать 176 токенов длиной 768. Размер пакета будет равен 13, так как это простое число и не вызовет путаницы с другими параметрами. Мы будем использовать 4 головы внимания, поскольку это позволяет разделить длину токена (768) на равные части для каждой «головы» в механизме внимания. Размерность каждой головы внимания вычисляется автоматически и не отображается напрямую в кодирующем блоке

# Define an Input
num_tokens = 176
token_len = 768
batch = 13
heads = 4
x = torch.rand(batch, num_tokens, token_len)
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])

# Define the Module
E = Encoding(dim=token_len, num_heads=heads, hidden_chan_mul=1.5, qkv_bias=False, qk_scale=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm)
E.eval();
Входные размеры:
   размер батча: 13
   количество токенов: 176
   длина токена: 768

Теперь мы пройдём через слой нормализации и модуль внимания. Модуль внимания в кодирующем блоке спроектирован так, чтобы не изменять длину токена. Это достигается использованием линейных проекций после механизма внимания. После модуля внимания мы реализуем наше первое разветвление (split connection).

y = E.norm1(x)
print('After norm, dimensions are\n\tbatchsize:', y.shape[0], '\n\tnumber of tokens:', y.shape[1], '\n\ttoken size:', y.shape[2])
y = E.attn(y)
print('After attention, dimensions are\n\tbatchsize:', y.shape[0], '\n\tnumber of tokens:', y.shape[1], '\n\ttoken size:', y.shape[2])
y = y + x
print('After split connection, dimensions are\n\tbatchsize:', y.shape[0], '\n\tnumber of tokens:', y.shape[1], '\n\ttoken size:', y.shape[2])
После нормализации размеры:
   размер батча: 13
   количество токенов: 176
   размер токена: 768
После слоя внимания размеры:
   размер батча: 13
   количество токенов: 176
   размер токена: 768
После разделения соединения размеры:
   размер батча: 13
   количество токенов: 176
   размер токена: 768

Теперь мы проходим ещё один слой нормализации, а затем модуль нейронной сети. Мы заканчиваем вторым разветвлением.

z = E.norm2(y)
print('After norm, dimensions are\n\tbatchsize:', z.shape[0], '\n\tnumber of tokens:', z.shape[1], '\n\ttoken size:', z.shape[2])
z = E.neuralnet(z)
print('After neural net, dimensions are\n\tbatchsize:', z.shape[0], '\n\tnumber of tokens:', z.shape[1], '\n\ttoken size:', z.shape[2])
z = z + y
print('After split connection, dimensions are\n\tbatchsize:', z.shape[0], '\n\tnumber of tokens:', z.shape[1], '\n\ttoken size:', z.shape[2])
После нормализации размеры:
   размер батча: 13
   количество токенов: 176
   размер токена: 768
После нейронной сети размеры:
   размер батча: 13
   количество токенов: 176
   размер токена: 768
После разделения соединения размеры:
   размер батча: 13
   количество токенов: 176
   размер токена: 768

Таким образом завершается обработка в одном кодирующем блоке. Поскольку длина и форма токенов остаются неизменными после обработки, модель может легко пропускать токены через несколько кодирующих блоков, что определяется гиперпараметром глубины.

Модуль нейронной сети

Модуль нейронной сети (Neural Network, NN) является подкомпонентом кодирующего блока. Модуль NN очень прост и состоит из полносвязного слоя, слоя активации, за которым следует ещё один полносвязный слой. Активационным слоем может быть любой слой из torch.nn.modules.activation, который передаётся на вход модулю. Модуль NN может быть настроен так, чтобы сохранять ту же форму на входе и выходе, хотя внутренние слои могут изменять размерность данных. Мы не будем подробно рассматривать этот код, так как нейронные сети являются распространённой темой в машинном обучении и не являются фокусом этой статьи. Однако ниже представлен код модуля NN.

class NeuralNet(nn.Module):
    def __init__(self,
       in_chan: int,
       hidden_chan: NoneFloat=None,
       out_chan: NoneFloat=None,
       act_layer = nn.GELU):
        """ Neural Network Module

            Args:
                in_chan (int): number of channels (features) at input
                hidden_chan (NoneFloat): number of channels (features) in the hidden layer;
                                        if None, number of channels in hidden layer is the same as the number of input channels
                out_chan (NoneFloat): number of channels (features) at output;
                                        if None, number of output channels is same as the number of input channels
                act_layer(nn.modules.activation): torch neural network layer class to use as activation
        """

        super().__init__()

        ## Define Number of Channels
        hidden_chan = hidden_chan or in_chan
        out_chan = out_chan or in_chan

        ## Define Layers
        self.fc1 = nn.Linear(in_chan, hidden_chan)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_chan, out_chan)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

Обработка предсказаний

После прохождения через кодирующие блоки, последним этапом для модели является выполнение предсказания. Компонент «обработка предсказания» на диаграмме ViT показан ниже.

Компоненты обработки предсказаний в диаграмме ViT (изображение автора)
Компоненты обработки предсказаний в диаграмме ViT (изображение автора)

Мы рассмотрим каждый этап этого процесса. Мы продолжим с 176 токенами длиной 768. Мы будем использовать размер пакета, равный 1, чтобы проиллюстрировать, как делается одно предсказание. Размер пакета больше 1 означал бы, что модель делает несколько предсказаний одновременно, вычисляя их параллельно.

# Define an Input
num_tokens = 176
token_len = 768
batch = 1
x = torch.rand(batch, num_tokens, token_len)
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
Входные размеры:
   размер батча: 1
   количество токенов: 176
   длина токена: 768

Сначала все токены проходят через слой нормализации.

norm = nn.LayerNorm(token_len)
x = norm(x)
print('After norm, dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken size:', x.shape[2])
После нормализации размеры:
   размер батча: 1
   количество токенов: 1001
   размер токена: 768

Затем мы отделим токен предсказания от остальных токенов. На протяжении всех кодирующих блоков токен предсказания накапливал информацию от остальных токенов и стал ненулевым. Мы будем использовать только этот токен предсказания, чтобы сделать окончательное предсказание.

pred_token = x[:, 0]
print('Length of prediction token:', pred_token.shape[-1])
Длина токена предсказания: 768

Наконец, токен предсказания пропускается через «голову» (head), чтобы сделать предсказание. Голова обычно представляет собой разновидность нейронной сети, и её структура варьируется в зависимости от модели. В статье An Image is Worth 16x16 Words [2] используется MLP (multilayer perceptron, многослойный перцептрон) с одним скрытым слоем во время предварительного обучения и линейный слой во время окончательной тонкой настройки модели. В модели Tokens-to-Token ViT [3] используется один линейный слой в качестве «головы». В данном примере используется один линейный слой.

Обратите внимание, что форма выходных данных «головы» зависит от задачи обучения. Для классификации это обычно вектор длиной, равной количеству классов, с использованием кодировки one-hot. Для задачи регрессии это может быть любое целое число прогнозируемых параметров. В этом примере выходное значение имеет размер 1, представляя собой одно числовое значение, предсказанное для задачи регрессии.

head = nn.Linear(token_len, 1)
pred = head(pred_token)
print('Length of prediction:', (pred.shape[0], pred.shape[1]))
print('Prediction:', float(pred))
Длина предсказания: (1, 1)
Предсказание: -0.5474240779876709

И это всё! Модель сделала предсказание!

Код полностью

Для создания полного ViT модуля мы используем модуль токенизации патчей, определённый выше, и модуль ViT Backbone. ViT Backbone определяется ниже и содержит компоненты обработки токенов, кодирующие блоки и компоненты обработки предсказаний.

class ViT_Backbone(nn.Module):
    def __init__(self,
                preds: int=1,
                token_len: int=768,
                num_heads: int=1,
                Encoding_hidden_chan_mul: float=4.,
                depth: int=12,
                qkv_bias=False,
                qk_scale=None,
                act_layer=nn.GELU,
                norm_layer=nn.LayerNorm):

        """ VisTransformer Backbone
            Args:
                preds (int): number of predictions to output
                token_len (int): length of a token
                num_heads(int): number of attention heads in MSA
                Encoding_hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component of the Encoding Module
                depth (int): number of encoding blocks in the model
                qkv_bias (bool): determines if the qkv layer learns an addative bias
                qk_scale (NoneFloat): value to scale the queries and keys by; 
                 if None, queries and keys are scaled by ``head_dim ** -0.5``
                act_layer(nn.modules.activation): torch neural network layer class to use as activation
                norm_layer(nn.modules.normalization): torch neural network layer class to use as normalization
        """

        super().__init__()

        ## Defining Parameters
        self.num_heads = num_heads
        self.Encoding_hidden_chan_mul = Encoding_hidden_chan_mul
        self.depth = depth

        ## Defining Token Processing Components
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.token_len))
        self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(num_tokens=self.num_tokens+1, token_len=self.token_len), requires_grad=False)

        ## Defining Encoding blocks
        self.blocks = nn.ModuleList([Encoding(dim = self.token_len, 
                                               num_heads = self.num_heads,
                                               hidden_chan_mul = self.Encoding_hidden_chan_mul,
                                               qkv_bias = qkv_bias,
                                               qk_scale = qk_scale,
                                               act_layer = act_layer,
                                               norm_layer = norm_layer)
             for i in range(self.depth)])

        ## Defining Prediction Processing
        self.norm = norm_layer(self.token_len)
        self.head = nn.Linear(self.token_len, preds)

        ## Make the class token sampled from a truncated normal distrobution 
        timm.layers.trunc_normal_(self.cls_token, std=.02)

    def forward(self, x):
        ## Assumes x is already tokenized

        ## Get Batch Size
        B = x.shape[0]
        ## Concatenate Class Token
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
        ## Add Positional Embedding
        x = x + self.pos_embed
        ## Run Through Encoding Blocks
        for blk in self.blocks:
            x = blk(x)
        ## Take Norm
        x = self.norm(x)
        ## Make Prediction on Class Token
        x = self.head(x[:, 0])
        return x

Из модуля ViT Backbone мы можем определить полную модель ViT.

class ViT_Model(nn.Module):
 def __init__(self,
    img_size: tuple[int, int, int]=(1, 400, 100),
    patch_size: int=50,
    token_len: int=768,
    preds: int=1,
    num_heads: int=1,
    Encoding_hidden_chan_mul: float=4.,
    depth: int=12,
    qkv_bias=False,
    qk_scale=None,
    act_layer=nn.GELU,
    norm_layer=nn.LayerNorm):

  """ VisTransformer Model

   Args:
    img_size (tuple[int, int, int]): size of input (channels, height, width)
    patch_size (int): the side length of a square patch
    token_len (int): desired length of an output token
    preds (int): number of predictions to output
    num_heads(int): number of attention heads in MSA
    Encoding_hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component of the Encoding Module
    depth (int): number of encoding blocks in the model
    qkv_bias (bool): determines if the qkv layer learns an addative bias
    qk_scale (NoneFloat): value to scale the queries and keys by; 
         if None, queries and keys are scaled by ``head_dim ** -0.5``
    act_layer(nn.modules.activation): torch neural network layer class to use as activation
    norm_layer(nn.modules.normalization): torch neural network layer class to use as normalization
  """
  super().__init__()

  ## Defining Parameters
  self.img_size = img_size
  C, H, W = self.img_size
  self.patch_size = patch_size
  self.token_len = token_len
  self.num_heads = num_heads
  self.Encoding_hidden_chan_mul = Encoding_hidden_chan_mul
  self.depth = depth

  ## Defining Patch Embedding Module
  self.patch_tokens = Patch_Tokenization(img_size,
           patch_size,
           token_len)

  ## Defining ViT Backbone
  self.backbone = ViT_Backbone(preds,
         self.token_len,
         self.num_heads,
         self.Encoding_hidden_chan_mul,
         self.depth,
         qkv_bias,
         qk_scale,
         act_layer,
         norm_layer)
  ## Initialize the Weights
  self.apply(self._init_weights)

 def _init_weights(self, m):
  """ Initialize the weights of the linear layers & the layernorms
  """
  ## For Linear Layers
  if isinstance(m, nn.Linear):
   ## Weights are initialized from a truncated normal distrobution
   timm.layers.trunc_normal_(m.weight, std=.02)
   if isinstance(m, nn.Linear) and m.bias is not None:
    ## If bias is present, bias is initialized at zero
    nn.init.constant_(m.bias, 0)
  ## For Layernorm Layers
  elif isinstance(m, nn.LayerNorm):
   ## Weights are initialized at one
   nn.init.constant_(m.weight, 1.0)
   ## Bias is initialized at zero
   nn.init.constant_(m.bias, 0)
   
 @torch.jit.ignore ##Tell pytorch to not compile as TorchScript
 def no_weight_decay(self):
  """ Used in Optimizer to ignore weight decay in the class token
  """
  return {'cls_token'}

 def forward(self, x):
  x = self.patch_tokens(x)
  x = self.backbone(x)
  return x

Параметры img_size, patch_size и token_len определяют размер изображения, размер патча и длину токена, соответственно, в модуле токенизации патчей. 

Параметры num_heads, Encoding_hidden_channel_mul, qkv_bias, qk_scale и act_layer определяют модули кодирующих блоков. Слой активации (act_layer) может быть любым слоем из torch.nn.modules.activation. Параметр depth определяет количество кодирующих блоков в модели.

Параметр norm_layer задаёт нормализацию как внутри, так и за пределами модулей кодирующих блоков. Его можно выбрать из любого слоя в torch.nn.modules.normalization.

Метод _init_weights взят из кода T2T-ViT [3]. Этот метод можно удалить, чтобы все обучаемые веса и смещения инициализировались случайным образом. В текущей реализации веса линейных слоев инициализируются усечённым нормальным распределением; смещения линейных слоев инициализируются нулями; веса слоев нормализации инициализируются единицей; смещения слоев нормализации — нулями.

Заключение

Теперь, обладая глубоким пониманием механики ViT, вы можете приступить к обучению своих моделей! Ниже приведён список ресурсов, где можно скачать код для моделей ViT. Некоторые из них позволяют вносить больше модификаций в модель, чем другие. 

  • GitHub репозиторий для этой серии статей

  • GitHub репозиторий для статьи An Image is Worth 16x16 Words — содержит предварительно обученные модели и код для тонкой настройки, но не содержит определений моделей

  • ViT, как реализовано в PyTorch Image Models (timm) timm.create_model('vit_base_patch16_224', pretrained=True)

  • Пакет vit-pytorch от Фила Ванга (Phil Wang)


В заключение напоминание о ближайших открытых уроках по компьютерному зрению и машинному обучению:

Комментарии (0)