Библиотек будет использовано минимум.

import random
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.utils.data as tdutils
from torch import nn, optim

Определение dataset'а

Обучение любой нейронки начинается именно с этого. Для целей поиграться датасет можно взять синтетический. Например, нагенерить из арифметических выражений. Если в выражении "2+2=4" случайным образом заменить один символ, получится достаточно простая, но нетривиальная задача по коррекции ошибок. Чтобы задать датасет в пригодном для использовании в pytorch виде, нужно создать класс наследник IterableDataset и переопределить метод __iter__

OPS = '+-*/%'
DIGITS = '0123456789'
CHARS = ' ' + DIGITS + OPS + '='
OPS_METHODS = {
    '+': lambda v1, v2: v1 + v2,
    '-': lambda v1, v2: v1 - v2,
    '*': lambda v1, v2: v1 * v2,
    '/': lambda v1, v2: 0 if v2 == 0 else v1 // v2,
    '%': lambda v1, v2: v1 % v2

}

class SampleSet(tdutils.IterableDataset):
    def __init__(self, val_min=0, val_max=99):
        self.val_min, self.val_max = val_min, val_max
        assert val_min > 0
        max_res = val_max * val_max
        self.str_size = len(f'{val_max}*{val_max}={max_res}')
        
    def __iter__(self):
        while True:
            yield self.make_sample()
            
    def to_tensor(self, str_value):
        res = torch.zeros([self.str_size], dtype=torch.uint8)
        converted = torch.tensor([
            CHARS.index(char) for char in str_value
        ])
        res[0:len(converted)] = converted
        return res
        
            
    def make_sample(self):
        val1 = random.randint(self.val_min, self.val_max)
        val2 = random.randint(self.val_min, self.val_max)
        op = OPS[random.randint(0, len(OPS) - 1)]
        res = OPS_METHODS[op](val1, val2)

        original = f'{val1}{op}{val2}={res}'
        lst = list(original)
        lst[random.randint(0, len(original) - 1)] = CHARS[random.randint(0, len(CHARS)-1)]
        replaced = ''.join(lst)
                        
        return {
            'task': self.to_tensor(replaced),
            'answer': self.to_tensor(original)
        }
      
      
# код для проверки
_sample = SampleSet(1, 99).make_sample()
_sample
# output
{
  'task': tensor([ 4, 10, 11,  6,  6, 16,  2,  2,  5,  0], dtype=torch.uint8),
  'answer': tensor([ 6, 10, 11,  6,  6, 16,  2,  2,  5,  0], dtype=torch.uint8)
}

Полученный датасет уже можно скармливать в DataLoader

_dataset = tdutils.DataLoader(
    dataset=SampleSet(1, 9),
    batch_size=8
)
_dataset_iter = iter(_dataset)
_batch = next(_dataset_iter)
_batch

# output
{'task': tensor([[ 7,  2,  7, 16,  2,  0],
         [ 8, 13,  2, 16,  8,  0],
         [ 4, 12,  4, 16, 15,  0],
         [ 7, 13,  7, 18,  4,  7],
         [ 2, 15,  2, 16,  2,  0],
         [ 2, 15,  9, 16,  2,  0],
         [ 6, 13,  3,  5,  2,  1],
         [10, 13,  3, 14,  2,  9]], dtype=torch.uint8),
 'answer': tensor([[ 7, 14,  7, 16,  2,  0],
         [ 8, 13,  2, 16,  8,  0],
         [ 4, 12,  4, 16,  1,  0],
         [ 7, 13,  7, 16,  4,  7],
         [ 2, 15,  4, 16,  2,  0],
         [ 2, 15,  5, 16,  2,  0],
         [ 6, 13,  3, 16,  2,  1],
         [10, 13,  3, 16,  2,  9]], dtype=torch.uint8)}

А еще для работы с датасетом не помешают функции отображения

def tensor_to_str(tensor):    
    res = ''.join([
        CHARS[val] for val in tensor
    ])
    return res.strip(' ')
    
    
def show_sample(dct):
    task = tensor_to_str(dct['task'])
    answer = tensor_to_str(dct['answer'])
    return f'{task}->{answer}'

show_sample(_sample)
# output
'39+55=114->59+55=114'

Embeddings

Вся магия трансформера начинается с перевода входной последовательности в векторное представление. Сделать это можно встроенным в pytorch модулем nn.Embedding. Он поддерживает внутри себя словарик (тензор размером d_chars * d_models)

class Embed(nn.Module):
    def __init__(self, d_chars, d_model):
        super().__init__()                
        self.embedding = nn.Embedding(d_chars, d_model)
        
    def forward(self, batch):
        return self.embedding(batch['task'].long())   
    
_embed = Embed(len(CHARS), 32)(_batch)
_embed.shape
# output
torch.Size([8, 6, 32])

Размерность пространства d_model это гиперпараметр, который нужно будет подбирать вне процедуры градиентного спуска. Размерность данных не меняется на протяжении всего пути их прохождения через модель и на выходе понадобится обратное преобразование. Его впринципе можно сделать на основе того же словарика, но проще просто выучить отдельным полносвязным слоем

class DecodeEmbed(nn.Module):
    def __init__(self, d_chars, d_model):
        super().__init__()
        self.decode = nn.Linear(d_model, d_chars)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, embed):        
        return torch.argmax(self.softmax(self.decode(embed)), dim=-1)
    
DecodeEmbed(len(CHARS), 32)(_embed).shape
# output
torch.Size([8, 6])

Attention

Внимание это механизм, позволяющий сетке направить пристальный взгляд на какой-то из входных элемент обрабатываемой последовательности. Eсли более формально, то каждому элементу в последовательности приписывается некоторый ключ K и значение V. Дальше векторным запросом Q можно запросить нужную информацию. Существует много способов как конкретно это сделать. В трансформе за основу взят DotProduct Attention. В нем в качестве весов, с которыми нужно складывать значение, берется softmax произведения Q и K

\mathrm{softmax}(Q \cdot K^T) \cdot V

Как на пальцах это работает. Путь информация о том, что именно лежит в символах последовательности закодирована в двухмерном пространстве.

_k = torch.tensor([
    [1, 1], [-1, 1], [0.01, 0.02]
]).float()
_q = torch.tensor([
    [-1, 1], [1, 1], [0, 1]
]).float() * 10

_v = torch.tensor([
    [0, 1, 2, 3],
    [4, 5, 6, 7],
    [8, 9, 10, 11]
]).float()

Вектора K первого и второго символа ортогональны. В третьем лежит какой-то шум. Вектор _q имеет похожую структуру

torch.matmul(_q, _k.T)
# output
tensor([[ 0.0000, 20.0000,  0.1000],
        [20.0000,  0.0000,  0.3000],
        [10.0000, 10.0000,  0.2000]])

После матричного умножения в видно, что в первый символ нужно записать содержание второго, а во второй первый. Запрос в третьей позиции лежит посередине между первым и вторым символом. В матрице он представлен равными весами для первого и второго. Третий символ (последний столбец) не вносит заметно вклада в итоговый результат. Картину еще более сглаживает применение softmax'а

_atw = nn.Softmax(dim=-1)(torch.matmul(_q, _k.T))
(_atw * 100).numpy().astype(int)
#output 
array([[  0, 100,   0],
       [100,   0,   0],
       [ 49,  49,   0]])

Результат предсказуем

_att_res = torch.matmul(_atw, _v)
_att_res.cpu().numpy().astype(int)
# output
array([[4, 5, 6, 7],
       [0, 1, 2, 3],
       [2, 3, 4, 5]])

В ответе первые и вторые строки переставлены, а в последней строчке лежит их среднее

Но это еще не все. В трансформере используется multihead attention. Оно состоит из dot product головок пониженной размерности. Кол-во головок это еще один гиперпараметр. Выводы отдельных голов конкатенируются и выучиваемым преобразованием трансформируются в исходное пространство. А чтобы сетке было проще обучаться в каждой из головок добавленно деление на корень из размерности. Просто, чтобы дисперсия на выходе была такая же, как и на входе.

class Attention(nn.Module):
    def forward(self, q, k, v):
        sel = torch.matmul(q, k.transpose(-1, -2))
        weights = nn.Softmax(dim=-1)(sel / math.sqrt(k.shape[-1]))
        return torch.matmul(weights, v)        
        
class ProjectedAttention(nn.Module):
    def __init__(self, d_qk1, d_qk2, d_v1, d_v2):
        super().__init__()
        self.keys = nn.Linear(d_qk1, d_qk2)
        self.queries = nn.Linear(d_qk1, d_qk2)
        self.values = nn.Linear(d_v1, d_v2)
        self.att = Attention()
        
    def forward(self, q, k, v):
        return self.att(
            self.queries(q),
            self.keys(k),
            self.values(v)
        )
                

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h):
        super().__init__()
        self.heads = nn.ModuleList([
            ProjectedAttention(
                d_model, d_model // h, 
                d_model, d_model // h
            ) 
            for _ in range(h)
        ])           
        self.final = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v):
        head_res = torch.cat([
            head(q, k, v)
            for head in self.heads
        ], dim=-1)
        return self.final(head_res)
        
_model = MultiHeadAttention(d_model=32, h=2)
_multi = _model(_embed, _embed, _embed)
_multi.shape
# output
torch.Size([8, 6, 32])

Positional Encoding

Трансформер не имеет никакой другой связи с информацией, записанной в соседних токенах, кроме механизма внимания и, чтобы снабдить его возможностью запрашивать содержимое соседних ячеек, к эмбендингам добавляется позиционное кодирование. Как это делается проще понять в комплексной нотации: вещественное пространство четной размерности v_len можно представить как комплексное размерностью в двое меньше. Для позиции pos в компоненте l прибавляется вектор

e^{\left( \mathrm{pos} \cdot M^{l} \cdot \sqrt{-1} \right)}

где M -- некоторое маленькое число. При переходе к вещественным числам там возникает sin и cos. При таком кодировании обозначение позиций со смещением в ту или другую сторону могут быть получены из текущего линейным преобразованием.

e^{\left( (\mathrm{pos}+ k) \cdot M^{l} \cdot \sqrt{-1} \right)}=e^{\left( \mathrm{pos} \cdot M^{l} \cdot \sqrt{-1} \right)}\cdot e^{\left( k \cdot M^{l} \cdot \sqrt{-1} \right)}
M = 1/10000

def pos_tensor(seq_len, v_len):
    power = 2 * torch.arange(v_len // 2).float() / v_len
    arg = torch.outer(
        torch.arange(seq_len),
        M ** power
        
    )
    res = torch.cat([torch.sin(arg), torch.cos(arg)], dim=-1)
    return res
    
class PositionInfo(nn.Module):
    def forward(self, data):
        pos_info = pos_tensor(data.shape[-2], data.shape[-1]).to(data.device)
        return data + pos_info
    
PositionInfo()(_embed).shape

Encoders и Decoders

Трансформер состоит из Encoder и Decoder блоков. Encoder это просто multihead self attention слой с последующим feed-forward слоем. Feed forward состоит из двух полносвязных слоев с одинаковыми для всех позиций весами (свертка с ядром 1). Осложнена эта картина skip connection'ами и пакетной нормализацией.

class BigBatch(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        
    def forward(self, batch):
        first_size = batch.shape[0]
        second_size = batch.shape[1]
        big_batch = batch.reshape(
            first_size * second_size, *tuple(batch.shape[2:])
        )
        big_res = self.net(big_batch)
        return big_res.reshape(
            first_size, second_size, *tuple(big_res.shape[1:])
        )
        
class Encoder(nn.Module):
    def __init__(self, d_model, h):
        super().__init__()
        self.self_att = MultiHeadAttention(d_model, h)
        self.norm1 = BigBatch(nn.BatchNorm1d(d_model))
        self.feed_forward = nn.Sequential(
            BigBatch(nn.Linear(d_model, d_model)),
            nn.ReLU(),
            BigBatch(nn.Linear(d_model, d_model)),
        )
        self.norm2 = BigBatch(nn.BatchNorm1d(d_model))

        
        
    def forward(self, data):
        res1 = self.self_att(data, data, data)
        res1r = self.norm1(data + res1)
        res2 = self.feed_forward(res1r)
        res2r = self.norm2(res1r + res2)
        return res2r
        
_encode = Encoder(32, 2)(_embed)
_encode.shape

# output
torch.Size([8, 6, 32])

Decoder чуть сложнее. Кроме self attention'а у него есть слой внимания в выхлопу encoder'а

class Decoder(nn.Module):
    def __init__(self, d_model, h):
        super().__init__()
        self.self_att = MultiHeadAttention(d_model, h)
        self.norm1 = BigBatch(nn.BatchNorm1d(d_model))
        
        self.src_att = MultiHeadAttention(d_model, h)
        self.norm2 = BigBatch(nn.BatchNorm1d(d_model))
        
        self.feed_forward = nn.Sequential(
            BigBatch(nn.Linear(d_model, d_model)),
            nn.ReLU(),
            BigBatch(nn.Linear(d_model, d_model)),
        )
        self.norm3 = BigBatch(nn.BatchNorm1d(d_model))
        
        
    def forward(self, src, tgt):
        res1r = self.self_att(tgt, tgt, tgt)
        res1 = self.norm1(res1r + tgt)
        res2r = self.src_att(res1, res1, src)
        res2 = self.norm2(res1 + res2r)
        res3r = self.feed_forward(res2)
        res3 = self.norm3(res2 + res3r)
        return res3
        
_decode = Decoder(32, 2)(_embed, _embed)
_decode.shape
# output
torch.Size([8, 6, 32])

Все в сборе

Модель состоит из embedding'ов, несколько слоев Encoder'а, Decoder'а и финального выходного слоя.

class Model(nn.Module):
    def __init__(self, d_chars, d_model, h, n_layers=2):
        super().__init__()
        self.embed = Embed(d_chars, d_model)
        self.pos = PositionInfo()
        self.encoders = nn.ModuleList([
            Encoder(d_model, h) for _ in range(n_layers)
        ])
        self.decoders = nn.ModuleList([
            Decoder(d_model, h) for _ in range(n_layers)
        ])
        self.embed_decoder = nn.Linear(d_model, d_chars)
        
        
    def forward(self, batch):
        enc_out = self.pos(self.embed(batch))
        for layer in self.encoders:
            enc_out = layer(enc_out)
        dec_out = enc_out
        for layer in self.decoders:
            dec_out = layer(enc_out, dec_out)
        char_out = self.embed_decoder(dec_out)
        return char_out
        
    
_model_out = Model(len(CHARS), 32, 2, 2)(_batch)
_model_out.shape
# output
torch.Size([8, 6, 19])

Маленькая оговорка: все, да не все. Здесь нет dropout'ов и mask'ed attention'а. В оригинале декодеру позволено запрашивать только предыдущие значения, чтобы сделать модель авторегрессионной. Здесь же, для такой простой задачи, не очень понятно, зачем это может понадобиться.

Процедура обучения

Первое что нужно сделать: описать функцию потерь. Для такой задачи можно взять cross entropy loss для отдельных символов.

def mean_batch_loss(answer, model_out):
    return nn.CrossEntropyLoss()(
        model_out.permute(0, 2, 1),
        answer.long()
    )
    
mean_batch_loss(_batch['answer'], _model_out)
# output
tensor(2.7526, grad_fn=<NllLoss2DBackward>)

Чтобы не путаться, все параметры обучения можно загнать в отдельный объект.

DEVICE = 'cuda:0'

class Context():
    def __init__(self):
        self.train_epoch = 10000
        self.val_samples = 10000
        self.device = DEVICE
        self.batch_loss = 0
        self.batch_discont = 0.8
        self.history = []
        
_ctx = Context()
_ctx.model = Model(len(CHARS), 32, 2, 2).to(_ctx.device)
_ctx.opt = optim.SGD(_ctx.model.parameters(), lr=0.01)
_ctx.epoch_size = 1000
_ctx.val_samples = 1000

Учится будем на GPU и нам пригодится функция, отправляющая туда данные

def to_device(device, data):
    if isinstance(data, torch.Tensor):
        return data.to(device)
    elif isinstance(data, dict):
        return {
            key: to_device(device, value)
            for key, value in data.items()
        }
    
to_device(DEVICE, _batch)
# output
{'task': tensor([[ 7,  2,  7, 16,  2,  0],
         [ 8, 13,  2, 16,  8,  0],
         [ 4, 12,  4, 16, 15,  0],
         [ 7, 13,  7, 18,  4,  7],
         [ 2, 15,  2, 16,  2,  0],
         [ 2, 15,  9, 16,  2,  0],
         [ 6, 13,  3,  5,  2,  1],
         [10, 13,  3, 14,  2,  9]], device='cuda:0', dtype=torch.uint8),
 'answer': tensor([[ 7, 14,  7, 16,  2,  0],
         [ 8, 13,  2, 16,  8,  0],
         [ 4, 12,  4, 16,  1,  0],
         [ 7, 13,  7, 16,  4,  7],
         [ 2, 15,  4, 16,  2,  0],
         [ 2, 15,  5, 16,  2,  0],
         [ 6, 13,  3, 16,  2,  1],
         [10, 13,  3, 16,  2,  9]], device='cuda:0', dtype=torch.uint8)}

Основа основ -- процедура скармливания отдельного batch'а в сетку

def feed_batch(ctx, batch):
    ctx.opt.zero_grad()
    on_device = to_device(ctx.device, batch)
    model_out = ctx.model(on_device)
    loss = mean_batch_loss(on_device['answer'], model_out)
    loss.backward()
    ctx.opt.step()
    ctx.batch_loss = (
        (1 - ctx.batch_discont) * ctx.batch_loss
        + ctx.batch_discont * loss.detach().cpu().numpy()
    )
 
feed_batch(_ctx, _batch)

Всю нудную работу по расчету градиентов pytorch берет на себя. Для каждого тензора (если явно не указать обратное) он поддерживает значение текущего накопленного градиента а также функции, делающие back propagation. Нам же остается сбросить его перед началом обработки пакета (zero_grad), и вызвать loss.backward(), opt.step() после применения модели.

Процедура обучения обычно долгая. На текущее значение лоса полезно поглядывать. Вдруг он ушел в вверх или застопорился. Для этого можно хранить средний batch_loss c экспоненциальным backoff'ом.

Но им одним дело не ограничивается. Хорошо бы периодически честно пересчитывать метрики. В качестве метрик можно взять точность по символам и точность по семплам. Функция feed_epoch служим именно этой цели. Она скармливает в сетку некоторое кол-во семплов, считает метрики и записывает их для истории.

def calc_metrics(ctx, dataset_iter):
    ctx.model.eval()
    counters = {
        'batch': 0,
        'batch_loss': 0.0,
        'batch_char_acc': 0.0,
        'sample_acc': 0.0
    }
    with torch.no_grad():
        with tqdm(total=ctx.val_samples, leave=False) as pbar:
            num_samples = 0
            while num_samples < ctx.val_samples:
                batch = next(dataset_iter)
                batch_size = len(batch['task'])
                on_device = to_device(ctx.device, batch)
                
                pred = ctx.model(on_device)
                counters['batch'] += 1
                counters['batch_loss'] += mean_batch_loss(
                    on_device['answer'], 
                    pred
                ).cpu().numpy()
                char_pred = torch.argmax(pred, dim=-1)
                correct = char_pred == on_device['answer']
                
                
                counters['batch_char_acc'] += correct.float().mean().cpu().numpy()
                counters['sample_acc'] += correct.min(dim=-1).values.float().mean().cpu().numpy()
                
                num_samples += batch_size
                pbar.update(batch_size)
                
    return {
        'loss': counters['batch_loss'] / counters['batch'],
        'char_acc': counters['batch_char_acc'] / counters['batch'], 
        'sample_acc': counters['sample_acc'] / counters['batch']
    }
                    
calc_metrics(_ctx, _dataset_iter)
def feed_epoch(ctx, dataset_iter):
    ctx.model.train()
    num_samples = 0
    with tqdm(total=ctx.train_epoch, leave=False) as pbar:
        while num_samples < ctx.train_epoch:
            batch = next(dataset_iter)
            feed_batch(ctx, batch)
        
            batch_size = len(batch['task'])
            num_samples += batch_size
            pbar.update(batch_size)                           
            pbar.set_postfix(batch_loss = ctx.batch_loss)
    ctx.history.append(calc_metrics(ctx, dataset_iter))
    
            

_ctx.train_epoch = 10000            
feed_epoch(_ctx, _dataset_iter)

Чистовое обучение

Теперь все готово, чтобы все взять и обучить. Размер батча и learning rate подобраны для Тесла V100. Для карточки по-скромнее размер пакета нужно брать по-меньше (чтобы по памяти не вылететь), и learning rate, соответственно, тоже (чтобы градиент нешибко шатало).

dataset = tdutils.DataLoader(
    dataset=SampleSet(1, 99999999),
    batch_size=1024 * 2,
    num_workers=8
)
dataset_iter = iter(dataset)
ctx = Context()
ctx.train_epoch = 10000
ctx.val_samples = 5000
ctx.model = Model(len(CHARS), 256, h=8, n_layers=3).to(ctx.device)
ctx.opt = optim.SGD(ctx.model.parameters(), lr=1.0)
ctx.seq_len = dataset.dataset.str_size
ctx.history = []
%%time
_num_epoches = 100
for _ in tqdm(range(_num_epoches), leave=False):
    feed_epoch(ctx, dataset_iter)
# output
CPU times: user 4min 34s, sys: 1min 47s, total: 6min 22s
Wall time: 6min 20s

Вывод истории

Самая простая модель, которую можно взять для baseline'а -- тупое копирование.

def plot_history(ctx):
    plt.subplot(1, 3, 1)
    loss = [ record['loss'] for record in ctx.history]    
    plt.plot(loss)
    plt.title('loss')
    
    plt.subplot(1, 3, 2)
    char_acc = [ record['char_acc'] for record in ctx.history]
    plt.plot(char_acc)
    baseline = (ctx.seq_len - 1) / ctx.seq_len + 1/ctx.seq_len * 1/len(CHARS)
    plt.plot([0, len(char_acc) - 1], [baseline, baseline], c='g')
    plt.plot([0, len(char_acc) - 1], [1, 1], c='g')

    plt.title('char_acc')
    
    plt.subplot(1, 3, 3)
    baseline = (1/len(CHARS))
    sample_acc = [record['sample_acc'] for record in ctx.history]
    plt.plot(sample_acc)
    plt.plot([0, len(sample_acc) - 1], [baseline, baseline], c='g')
    plt.title('sample_acc')
    
plt.figure(figsize=(15,5))
plot_history(ctx)
Вывод истории обучения
Вывод истории обучения

Модель учится :)