Large language models are great, of course, but sometimes you need something small and fast.

Problem statement

We will distill a BERT model trained for binary classification. As the data I took an open corpus of Russian-language tweets. I drew on two articles: one on distilling knowledge from BERT into a BiLSTM, and one on distilling BERT itself. I'm not adding anything new; the goal is to tidy everything up into a step-by-step tutorial that is easy to follow. All the code is on github.

Plan

  1. Baseline 1: TF-IDF + RandomForest

  2. Baseline 2: BiLSTM

  3. Distilling BERT > BiLSTM

  4. Distilling BERT > tinyBERT

TF-IDF + RandomForest

Everything is standard: lowercasing, lemmatization, stop-word removal. The resulting vectors are classified with a RandomForest. We get an F1 just above 0.75.

How to train TF-IDF + RF
import re
import pandas as pd
from pymystem3 import Mystem

# get data
data = pd.read_csv('data.csv')

texts = list(data['comment'])
labels = list(map(int, data['toxic'].values))

# clean texts
texts = [re.sub('[^а-яё ]', ' ', str(t).lower()) for t in texts]
texts = [re.sub(r" +", " ", t).strip() for t in texts]

# lemmatize
mstm = Mystem()

normalized = [''.join(mstm.lemmatize(t)[:-1]) for t in texts]

# remove stopwords
with open('./stopwords.txt') as f:
    stopwords = [line.rstrip('\n') for line in f]

def drop_stop(text):
    tokens = text.split(' ')
    tokens = [t for t in tokens if t not in stopwords]
    return ' '.join(tokens)

normalized = [drop_stop(text) for text in normalized]

# new dataset
df = pd.DataFrame()
df['text'] = texts
df['norm'] = normalized
df['label'] = labels

# train-valid-test-split
from sklearn.model_selection import train_test_split

train, test = train_test_split(df, test_size=0.3, random_state=42)
valid, test = train_test_split(test, test_size=0.5, random_state=42)

# tf-idf
from sklearn.feature_extraction.text import TfidfVectorizer

model_tfidf = TfidfVectorizer(max_features=5000)

train_tfidf = model_tfidf.fit_transform(train['norm'].values)
valid_tfidf = model_tfidf.transform(valid['norm'].values)
test_tfidf = model_tfidf.transform(test['norm'].values)

# RF
from sklearn.ensemble import RandomForestClassifier

cls = RandomForestClassifier(random_state=42)
cls.fit(train_tfidf, train['label'].values)

# prediction
predictions = cls.predict(test_tfidf)

# score
from sklearn.metrics import f1_score

f1_score(test['label'].values, predictions)
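
For completeness, a quick usage sketch for the fitted pipeline: a new text goes through the same cleaning, lemmatization, and stop-word removal before vectorization. The helper name is mine.

# usage sketch: classify a new text with the fitted TF-IDF + RandomForest pipeline
def predict_rf(text):
    text = re.sub('[^а-яё ]', ' ', str(text).lower())
    text = re.sub(r" +", " ", text).strip()
    text = ''.join(mstm.lemmatize(text)[:-1])
    text = drop_stop(text)
    return cls.predict(model_tfidf.transform([text]))[0]

predict_rf('какой-то текст')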

BiLSTM

Let's try to improve the baseline with a neural network. Everything is standard: train a tokenizer, train a network. As the base architecture we take a BiLSTM. We get an F1 just above 0.79. A small gain, but a gain nonetheless.

How to train BiLSTM
# get data
import pandas as pd

train = pd.read_csv('train.csv')
valid = pd.read_csv('valid.csv')
test = pd.read_csv('test.csv')

# create tokenizer
from tokenizers import Tokenizer
from tokenizers import ByteLevelBPETokenizer
from tokenizers.pre_tokenizers import Whitespace

tokenizer = ByteLevelBPETokenizer()
tokenizer.pre_tokenizer = Whitespace()
tokenizer.enable_padding(pad_id=0, pad_token='<pad>')

texts_path = 'texts.txt'

with open(texts_path, 'w') as f:
    for text in list(train['text'].values):
        f.write("%s\n" % text)

tokenizer.train(
    files=[texts_path],
    vocab_size=5_000,
    min_frequency=2,
    special_tokens=['<pad>', '<unk>']
    )

# create dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
from tqdm import tqdm

class CustomDataset(Dataset):

    def __init__(self, tokens, labels, max_len):
        self.tokens = tokens
        self.labels = labels
        self.max_len = max_len


    def __len__(self):
        return len(self.tokens)


    def __getitem__(self, idx):
        label = self.labels[idx]
        label = torch.tensor(label)
        tokens = self.tokens[idx]
        out = torch.zeros(self.max_len, dtype=torch.long)
        out[:len(tokens)] = torch.tensor(tokens, dtype=torch.long)[:self.max_len]
        return out, label

max_len = 64
BATCH_SIZE = 16

train_labels = list(train['label'])
train_tokens = [tokenizer.encode(text).ids for text in list(train['text'])]
train_dataset = CustomDataset(train_tokens, train_labels, max_len)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)

valid_labels = list(valid['label'])
valid_tokens = [tokenizer.encode(text).ids for text in list(valid['text'])]
valid_dataset = CustomDataset(valid_tokens, valid_labels, max_len)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

test_labels = list(test['label'])
test_tokens = [tokenizer.encode(text).ids for text in list(test['text'])]
test_dataset = CustomDataset(test_tokens, test_labels, max_len)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# create BiLSTM
class LSTM_classifier(nn.Module):


    def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(dropout)        
        self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)
        self.batchnorm = nn.BatchNorm1d(linear_dim)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(linear_dim, n_classes)


    def forward(self, inputs):
        batch_size = inputs.size(0)
        embeddings = self.embedding_layer(inputs)
        lstm_out, (ht, ct) = self.lstm_layer(embeddings)
        out = ht.transpose(0, 1)
        out = out.reshape(batch_size, -1)
        out = self.fc_layer(out)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.dropout_layer(out)
        out = self.out_layer(out)
        out = torch.squeeze(out, 1)
        out = torch.sigmoid(out)
        return out

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

def eval_nn(model, data_loader):
    predicted = []
    labels = []
    model.eval()
    with torch.no_grad():
        for data in data_loader:
            x, y = data
            x = x.to(device)
            outputs = model(x)
            _, predict = torch.max(outputs.data, 1)
            predict = predict.cpu().detach().numpy().tolist()
            predicted += predict
            labels += y
        score = f1_score(labels, predicted, average='binary')
    return score

def train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=20):
    best_score = 0
    for epoch in range(epochs):
        model.train()
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            loss = loss_function(predict, labels)
            loss.backward()
            optimizer.step()
        score = eval_nn(model, test_loader)
        print(epoch, 'valid:', score)
        if score > best_score:
            torch.save(model.state_dict(),'lstm.pt')
            best_score = score
    return best_score

# fit NN
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = LSTM_classifier(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)   

model.apply(init_weights)

model.to(device)

optimizer = optim.AdamW(model.parameters())

loss_function = nn.CrossEntropyLoss().to(device)

train_nn(model, optimizer, loss_function, train_loader, valid_loader, device, epochs=20)

eval_nn(model, test_loader)
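
A quick inference sketch for the trained BiLSTM: padding and truncation mirror what CustomDataset does, and the helper name is mine.

# usage sketch: classify a single text with the trained BiLSTM
def predict_lstm(text, model, tokenizer, max_len=64):
    ids = tokenizer.encode(str(text)).ids[:max_len]
    x = torch.zeros(1, max_len, dtype=torch.long)
    x[0, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    model.eval()
    with torch.no_grad():
        out = model(x.to(device))
    return int(torch.argmax(out, dim=1).item())

predict_lstm('какой-то текст', model, tokenizer)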

Training BERT

Now let's train the teacher model. As the teacher I chose the hero of the aforementioned distillation article, rubert-tiny by @cointegrated. We get an F1 just above 0.91. I didn't spend much time tuning the training; the metric could probably be pushed higher, especially with a full-size BERT, but this is illustrative enough. How to train BERT for binary classification is covered in my previous article, or right here:

How to train BERT
import torch
from torch.utils.data import Dataset

class BertDataset(Dataset):

  def __init__(self, texts, targets, tokenizer, max_len=512):
    self.texts = texts
    self.targets = targets
    self.tokenizer = tokenizer
    self.max_len = max_len

  def __len__(self):
    return len(self.texts)

  def __getitem__(self, idx):
    text = str(self.texts[idx])
    target = self.targets[idx]

    encoding = self.tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=self.max_len,
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
        truncation=True
    )

    return {
      'text': text,
      'input_ids': encoding['input_ids'].flatten(),
      'attention_mask': encoding['attention_mask'].flatten(),
      'targets': torch.tensor(target, dtype=torch.long)
    }

from tqdm import tqdm
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import precision_recall_fscore_support


class BertClassifier:

    def __init__(self, path, n_classes=2):
        self.path = path
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.tokenizer = BertTokenizer.from_pretrained(path)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = 512
        self.out_features = self.model.bert.encoder.layer[1].output.dense.out_features
        self.model.classifier = torch.nn.Linear(self.out_features, n_classes)
        self.model.to(self.device)

    
    def preparation(self, X_train, y_train, epochs):
        # create datasets
        self.train_set = BertDataset(X_train, y_train, self.tokenizer)
        # create data loaders
        self.train_loader = DataLoader(self.train_set, batch_size=2, shuffle=True)
        # helpers initialization
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=2e-5,
            weight_decay=0.005,
            correct_bias=True
            )
        self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=500,
                num_training_steps=len(self.train_loader) * epochs
            )
        self.loss_fn = torch.nn.CrossEntropyLoss().to(self.device)


    def fit(self):
        self.model = self.model.train()
        losses = []
        correct_predictions = 0

        for data in tqdm(self.train_loader):
            input_ids = data["input_ids"].to(self.device)
            attention_mask = data["attention_mask"].to(self.device)
            targets = data["targets"].to(self.device)

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
                )

            preds = torch.argmax(outputs.logits, dim=1)
            loss = self.loss_fn(outputs.logits, targets)

            correct_predictions += torch.sum(preds == targets)

            losses.append(loss.item())
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        train_acc = correct_predictions.double() / len(self.train_set)
        train_loss = np.mean(losses)
        return train_acc, train_loss
    

    def train(self, X_train, y_train, X_valid, y_valid, X_test, y_test, epochs=1):
        print('*' * 10)
        print(f'Model: {self.path}')
        self.preparation(X_train, y_train, epochs)
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')
            train_acc, train_loss = self.fit()
            print(f'Train loss {train_loss} accuracy {train_acc}')
            predictions_valid = [self.predict(x) for x in X_valid]
            precision, recall, f1score = precision_recall_fscore_support(y_valid, predictions_valid, average='macro')[:3]
            print('Valid:')
            print(f'precision: {precision}, recall: {recall}, f1score: {f1score}')
            predictions_test = [self.predict(x) for x in X_test]
            precision, recall, f1score = precision_recall_fscore_support(y_test, predictions_test, average='macro')[:3]
            print('Test:')
            print(f'precision: {precision}, recall: {recall}, f1score: {f1score}')
        print('*' * 10)
    
    def predict(self, text):
        self.model = self.model.eval()
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        
        out = {
              'text': text,
              'input_ids': encoding['input_ids'].flatten(),
              'attention_mask': encoding['attention_mask'].flatten()
          }
        
        input_ids = out["input_ids"].to(self.device)
        attention_mask = out["attention_mask"].to(self.device)
        
        outputs = self.model(
            input_ids=input_ids.unsqueeze(0),
            attention_mask=attention_mask.unsqueeze(0)
        )
        
        prediction = torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]

        return prediction

import pandas as pd

train = pd.read_csv('train.csv')
valid = pd.read_csv('valid.csv')
test = pd.read_csv('test.csv')

classifier = BertClassifier(
    path='cointegrated/rubert-tiny',
    n_classes=2
)

classifier.train(
        X_train=list(train['text']),
        y_train=list(train['label']),
        X_valid=list(valid['text']),
        y_valid=list(valid['label']),
        X_test=list(test['text']),
        y_test=list(test['label']),
        epochs=1
)

path = './trainer'
classifier.model.save_pretrained(path)
classifier.tokenizer.save_pretrained(path)

Distilling BERT > BiLSTM

The main idea is for the BiLSTM student to approximate the output of the BERT teacher, so during training we use an MSE loss against the teacher's logits. It can be combined with training on the labels via CrossEntropyLoss. More details are in the article linked above. On my test data distillation added only a couple of points: an F1 just above 0.82.
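
In code the combined objective is just a weighted sum of the two terms; a minimal sketch (the same combination appears as loss_function in the full listing below):

import torch.nn as nn

# a = 1 -> train on the hard labels only, a = 0 -> pure distillation against the teacher's logits
def distill_loss(student_logits, teacher_logits, labels, a=0.5):
    ce = nn.CrossEntropyLoss()(student_logits, labels)     # supervised term on the labels
    mse = nn.MSELoss()(student_logits, teacher_logits)     # distillation term on the teacher's logits
    return a * ce + (1 - a) * mse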

Distillation code
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from tokenizers import Tokenizer
from tokenizers import ByteLevelBPETokenizer
from tokenizers.pre_tokenizers import Whitespace

from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup

from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import f1_score

import numpy as np
import pandas as pd

### data

train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')

### tokenizer: train

tokenizer = ByteLevelBPETokenizer()
tokenizer.pre_tokenizer = Whitespace()
tokenizer.enable_padding(pad_id=0, pad_token='<pad>')

texts_path = 'texts.txt'

with open(texts_path, 'w') as f:
    for text in list(train['text'].values):
        f.write("%s\n" % text)

tokenizer.train(
    files=[texts_path],
    vocab_size=5_000,
    min_frequency=2,
    special_tokens=['<pad>', '<unk>']
    )

### load BERT tokenizer

tokenizer_bert = BertTokenizer.from_pretrained('./rubert-tiny')

### dataset

class CustomDataset(Dataset):

    def __init__(self, tokens, labels, max_len):
        self.tokens = tokens
        self.labels = labels
        self.max_len = max_len


    def __len__(self):
        return len(self.tokens)


    def __getitem__(self, idx):
        label = self.labels[idx]
        label = torch.tensor(label)
        tokens = self.tokens[idx]
        out = torch.zeros(self.max_len, dtype=torch.long)
        out[:len(tokens)] = torch.tensor(tokens, dtype=torch.long)[:self.max_len]
        return out, label

max_len = 64
BATCH_SIZE = 16

train_labels = list(train['label'])
train_tokens = [tokenizer.encode(str(text)).ids for text in list(train['text'])]
train_dataset = CustomDataset(train_tokens, train_labels, max_len)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)

test_labels = list(test['label'])
test_tokens = [tokenizer.encode(str(text)).ids for text in list(test['text'])]
test_dataset = CustomDataset(test_tokens, test_labels, max_len)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

class LSTM_classifier(nn.Module):


    def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(dropout)        
        self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)
        self.batchnorm = nn.BatchNorm1d(linear_dim)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(linear_dim, n_classes)


    def forward(self, inputs):
        batch_size = inputs.size(0)
        embeddings = self.embedding_layer(inputs)
        lstm_out, (ht, ct) = self.lstm_layer(embeddings)
        out = ht.transpose(0, 1)
        out = out.reshape(batch_size, -1)
        out = self.fc_layer(out)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.dropout_layer(out)
        out = self.out_layer(out)
        out = torch.squeeze(out, 1)
        out = torch.sigmoid(out)
        return out

########

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

def eval_nn(model, data_loader):
    predicted = []
    labels = []
    model.eval()
    with torch.no_grad():
        for data in data_loader:
            x, y = data
            x = x.to(device)
            outputs = model(x)
            _, predict = torch.max(outputs.data, 1)
            predict = predict.cpu().detach().numpy().tolist()
            predicted += predict
            labels += y
        score = f1_score(labels, predicted, average='binary')
    return score

def train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=20):
    best_score = 0
    for epoch in range(epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            loss = loss_function(predict, labels)
            loss.backward()
            optimizer.step()
        score = eval_nn(model, test_loader)
        print(epoch, 'valid:', score)
        if score > best_score:
            torch.save(model.state_dict(), 'lstm.pt')
            best_score = score
    return best_score

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = LSTM_classifier(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)   

model.apply(init_weights);

model.to(device);

optimizer = optim.AdamW(model.parameters())

loss_function = nn.CrossEntropyLoss().to(device)

train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=3)

#####

class DistillDataset(Dataset):

    def __init__(self, texts, labels, tokenizer_bert, tokenizer_lstm, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer_bert = tokenizer_bert
        self.tokenizer_lstm = tokenizer_lstm
        self.max_len = max_len


    def __len__(self):
        return len(self.texts)


    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        label = torch.tensor(label)
        # lstm
        tokens_lstm = self.tokenizer_lstm.encode(str(text)).ids
        out_lstm = torch.zeros(self.max_len, dtype=torch.long)
        out_lstm[:len(tokens_lstm)] = torch.tensor(tokens_lstm, dtype=torch.long)[:self.max_len]
        # bert
        encoding = self.tokenizer_bert.encode_plus(
            str(text),
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        
        out_bert = {
              'input_ids': encoding['input_ids'].flatten(),
              'attention_mask': encoding['attention_mask'].flatten()
        }
        return out_lstm, out_bert, label

train_dataset_distill = DistillDataset(
    list(train['text']),
    list(train['label']),
    tokenizer_bert,
    tokenizer,
    max_len
)

train_loader_distill = DataLoader(train_dataset_distill, batch_size=BATCH_SIZE, shuffle=True)

### BERT-teacher model

class BertTrainer:

    def __init__(self, path_model, n_classes=2):
        self.model = BertForSequenceClassification.from_pretrained(path_model, num_labels=n_classes)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = 512
        self.model.to(self.device)
        self.model = self.model.eval()
    
    def predict(self, inputs):     
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        return outputs.logits

teacher = BertTrainer('./rubert-tiny')

### BiLSTM-student model

class CustomLSTM(nn.Module):


    def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(dropout)        
        self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)
        self.batchnorm = nn.BatchNorm1d(linear_dim)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(linear_dim, n_classes)


    def forward(self, inputs):
        batch_size = inputs.size(0)
        embeddings = self.embedding_layer(inputs)
        lstm_out, (ht, ct) = self.lstm_layer(embeddings)
        out = ht.transpose(0, 1)
        out = out.reshape(batch_size, -1)
        out = self.fc_layer(out)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.dropout_layer(out)
        out = self.out_layer(out)
#         out = torch.squeeze(out, 1)
#         out = torch.sigmoid(out)
        return out

def loss_function(output, teacher_prob, real_label, a=0.5):
    criterion_mse = torch.nn.MSELoss()
    criterion_ce = torch.nn.CrossEntropyLoss()
    return a * criterion_ce(output, real_label) + (1 - a) * criterion_mse(output, teacher_prob)

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

def eval_nn(model, data_loader):
    predicted = []
    labels = []
    model.eval()
    with torch.no_grad():
        for data in data_loader:
            x, y = data
            x = x.to(device)
            outputs = model(x)
            _, predict = torch.max(outputs.data, 1)
            predict = predict.cpu().detach().numpy().tolist()
            predicted += predict
            labels += y
        score = f1_score(labels, predicted, average='binary')
    return labels, predicted, score

def train_distill(model, teacher, optimizer, loss_function, distill_loader, train_loader, test_loader, device, epochs=30, alpha=0.5):
    best_score = 0
    score_list = []
    for epoch in range(epochs):
        model.train()
        for inputs, inputs_teacher, labels in distill_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            teacher_predict = teacher.predict(inputs_teacher)
            loss = loss_function(predict, teacher_predict, labels, alpha)
            loss.backward()
            optimizer.step()
        score_train = round(eval_nn(model, train_loader)[2], 3)
        score_test = round(eval_nn(model, test_loader)[2], 3)
        score_list.append((score_train, score_test))
        print(epoch, score_train, score_test)
        if score_test > best_score:     
            best_score = score_test
            best_model = model
    torch.save(best_model.state_dict(), f'./results/lstm_{best_score}.pt')
    return best_model, best_score, score_list

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.get_vocab_size()

score_alpha = []
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    # use the student head without the final sigmoid so its raw logits can be matched to the teacher's
    model = CustomLSTM(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)

    model.apply(init_weights)
    model.to(device)
    optimizer = optim.AdamW(model.parameters())
    _, _, score_list = train_distill(model, teacher, optimizer, loss_function, train_loader_distill, train_loader, test_loader, device, 30, alpha)
    score_alpha.append(score_list)

import matplotlib.pyplot as plt
import numpy as np

a_list = [0, 0.25, 0.5, 0.75, 1]  # same order as the training loop above

for i, score in enumerate(score_alpha):
    _, score_test = list(zip(*score))
    plt.plot(score_test, label=f'{a_list[i]}')
plt.grid(True)
plt.legend()
plt.show()

Distilling BERT > tinyBERT

The main idea, as in the previous section, is for the student to approximate the teacher's behavior. There are many options for what to match and how; I took just two:

  1. Matching the [CLS] token with an MSE loss.

  2. Distilling the token distributions with the Kullback-Leibler divergence.

Additionally, during training we solve the MLM task of predicting masked tokens. The model is shrunk by reducing the vocabulary, the number of attention heads, and the number and dimensionality of the hidden layers.

Training the final classifier is therefore split into two stages:

  1. Training the language model.

  2. Training the classification head.

I applied distillation only at the first stage; the classification head was then trained directly on the distilled model (see the sketch right below). I think an MSE variant like the one in the BiLSTM example could be added on top, but I left those experiments for later.
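
Stage 2 needs no new machinery: once the distilled checkpoint is saved (see the fragments below), the head is trained exactly like the teacher was. A minimal sketch, assuming the distilled model and its reduced-vocabulary tokenizer end up in ./bert_distill and reusing the BertClassifier wrapper from the teacher section:

# stage 2 sketch: fine-tune a classification head on top of the distilled LM
valid = pd.read_csv('valid.csv')
test = pd.read_csv('test.csv')

classifier_student = BertClassifier(
    path='./bert_distill',
    n_classes=2
)

classifier_student.train(
    X_train=list(train['text']),
    y_train=list(train['label']),
    X_valid=list(valid['text']),
    y_valid=list(valid['label']),
    X_test=list(test['text']),
    y_test=list(test['label']),
    epochs=1
)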

Key implementation points:

Vocabulary reduction:
from transformers import BertTokenizerFast, BertForPreTraining, BertModel, BertConfig
from collections import Counter
from tqdm.auto import tqdm, trange
import pandas as pd

train = pd.read_csv('train.csv')
X_train=list(train['text'])

tokenizer = BertTokenizerFast.from_pretrained('./rubert-tiny')

cnt = Counter()
for text in tqdm(X_train):
    cnt.update(tokenizer(str(text))['input_ids'])

resulting_vocab = {
    tokenizer.vocab[k] for k in tokenizer.special_tokens_map.values()
}

for k, v in cnt.items():
    if v > 5:
        resulting_vocab.add(k)

resulting_vocab = sorted(resulting_vocab)

tokenizer.save_pretrained('./bert_distill');

inv_voc = {idx: word for word, idx in tokenizer.vocab.items()}

with open('./bert_distill/vocab.txt', 'w', encoding='utf-8') as f:
    for idx in resulting_vocab:
        f.write(inv_voc[idx] + '\n')
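
Two small pieces that the fragments below rely on but that are easy to miss: the reduced-vocabulary tokenizer has to be reloaded, and the KL loss needs a mapping from student token ids back to teacher token ids. Since the new vocab.txt was written in sorted order of the old ids, that mapping is just resulting_vocab itself; both definitions below are my assumptions about the tokenizer_distill and vocab_mapping names used later.

from transformers import BertTokenizer

# reload a tokenizer that reads the rewritten vocab.txt directly
# (the slow BertTokenizer builds its vocabulary from vocab.txt, so it picks up the reduced vocab)
tokenizer_distill = BertTokenizer.from_pretrained('./bert_distill')

# student token id i corresponds to teacher token id resulting_vocab[i]
vocab_mapping = resulting_vocab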

Weight initialization
config = BertConfig(
    emb_size=256,
    hidden_size=256,
    intermediate_size=256,
    max_position_embeddings=512,
    num_attention_heads=8,
    num_hidden_layers=3,
    vocab_size=tokenizer_distill.vocab_size
)

model = BertForPreTraining(config)

model.save_pretrained('./bert_distill')

from transformers import BertModel
# load the teacher together with its pre-training heads (the MLM decoder weights are reused below)
teacher = BertForPreTraining.from_pretrained('./rubert-tiny')

tokenizer_teacher = BertTokenizerFast.from_pretrained('./rubert-tiny')

# copy input embeddings accordingly with resulting_vocab
model.bert.embeddings.word_embeddings.weight.data = teacher.bert.embeddings.word_embeddings.weight.data[resulting_vocab, :256].clone()
model.bert.embeddings.position_embeddings.weight.data = teacher.bert.embeddings.position_embeddings.weight.data[:, :256].clone()

# copy output embeddings
model.cls.predictions.decoder.weight.data = teacher.cls.predictions.decoder.weight.data[resulting_vocab, :256].clone()

MLM-loss
inputs = tokenizer_distill(texts, return_tensors='pt', padding=True, truncation=True, max_length=16)
inputs = preprocess_inputs(inputs, tokenizer_distill, data_collator)
outputs = model(**inputs, output_hidden_states=True)
loss += torch.nn.CrossEntropyLoss()(
        outputs.prediction_logits.view(-1, model.config.vocab_size),
        inputs['labels'].view(-1)
    )
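
The fragment above uses two helpers that are not shown: data_collator and preprocess_inputs. A minimal sketch of what they might look like, assuming the standard Hugging Face MLM collator with 15% masking:

from transformers import DataCollatorForLanguageModeling

# standard MLM collator: randomly masks 15% of tokens and builds the label tensor
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer_distill, mlm=True, mlm_probability=0.15)

def preprocess_inputs(inputs, tokenizer, collator):
    # apply masking to an already-tokenized batch: the collator returns masked
    # input_ids plus MLM labels (-100 everywhere except the masked positions);
    # the tokenizer argument is kept only to match the call signature above
    features = [{'input_ids': ids.tolist()} for ids in inputs['input_ids']]
    masked = collator(features)
    inputs['input_ids'] = masked['input_ids']
    inputs['labels'] = masked['labels']
    return inputs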

KL-loss
import torch.nn.functional as F

def loss_kl(inputs, outputs, model, teacher, vocab_mapping, temperature=1.0):
    new_inputs = torch.tensor(
        [[vocab_mapping[i] for i in row] for row in inputs['input_ids']]
    ).to(inputs['input_ids'].device)
    with torch.no_grad():
        teacher_out = teacher(
            input_ids=new_inputs, 
            token_type_ids=inputs['token_type_ids'],
            attention_mask=inputs['attention_mask']
        )
    # the whole batch, all tokens after the [cls], the whole dimension
    kd_loss = torch.nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(outputs.prediction_logits[:, 1:, :] / temperature, dim=1), 
        F.softmax(teacher_out.prediction_logits[:, 1:, vocab_mapping] / temperature, dim=1)
    ) / outputs.prediction_logits.shape[-1]
    return kd_loss

MSE-loss
input_teacher = {k: v for k, v in tokenizer_teacher(
        texts,
        return_tensors='pt',
        padding=True,
        max_length=16,
        truncation=True
    ).items()}

with torch.no_grad():
    out_teacher = teacher_mse(**input_teacher)

embeddings_teacher_norm = torch.nn.functional.normalize(out_teacher.pooler_output)

input_distill = {k: v for k, v in tokenizer_distill(
        texts,
        return_tensors='pt',
        padding=True,
        max_length=16,
        truncation=True
    ).items()}

out = model(**input_distill, output_hidden_states=True)
embeddings = model.bert.pooler(out.hidden_states[-1])
embeddings_norm = torch.nn.functional.normalize(adapter_emb(embeddings))
loss = torch.nn.MSELoss()(embeddings_norm, embeddings_teacher_norm)
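
Putting the fragments together, one training step sums the three terms and backpropagates through the student only. A rough sketch of the loop; the optimizer settings, batch size, and equal weighting of the terms are my assumptions, and teacher_mse and adapter_emb are the helpers assumed by the MSE fragment (the teacher reloaded as a plain BertModel, and a small nn.Linear mapping the student's 256-dim embeddings to the teacher's hidden size).

import torch.nn.functional as F
from torch.optim import AdamW

optimizer = AdamW(list(model.parameters()) + list(adapter_emb.parameters()), lr=1e-4)
batch_size = 32

model.train()
for start in range(0, len(X_train), batch_size):
    texts = [str(t) for t in X_train[start:start + batch_size]]

    # 1) MLM term: predict the masked tokens (MLM-loss fragment)
    inputs = tokenizer_distill(texts, return_tensors='pt', padding=True, truncation=True, max_length=16)
    inputs = preprocess_inputs(inputs, tokenizer_distill, data_collator)
    outputs = model(**inputs, output_hidden_states=True)
    loss = F.cross_entropy(
        outputs.prediction_logits.view(-1, model.config.vocab_size),
        inputs['labels'].view(-1)
    )

    # 2) KL term: match the teacher's token distributions on the same masked batch (KL-loss fragment)
    loss = loss + loss_kl(inputs, outputs, model, teacher, vocab_mapping)

    # 3) MSE term: match the normalized pooled [CLS] embedding (MSE-loss fragment)
    input_teacher = tokenizer_teacher(texts, return_tensors='pt', padding=True, truncation=True, max_length=16)
    with torch.no_grad():
        out_teacher = teacher_mse(**input_teacher)
    embeddings_teacher_norm = F.normalize(out_teacher.pooler_output)

    input_distill = tokenizer_distill(texts, return_tensors='pt', padding=True, truncation=True, max_length=16)
    out = model(**input_distill, output_hidden_states=True)
    embeddings = model.bert.pooler(out.hidden_states[-1])
    embeddings_norm = F.normalize(adapter_emb(embeddings))
    loss = loss + F.mse_loss(embeddings_norm, embeddings_teacher_norm)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()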

The final model came out at 16 MB with an F1 of 0.86. I trained it for 12 hours on a 2019 MacBook Air with an i5 and 8 GB of RAM. I think that with longer training the result would be even better.

The code and training data are available on github; comments, additions, and corrections are welcome.
