Привет, Хабр! Меня зовут Арсений, я работаю ML-инженером в компании Вита и параллельно учусь на втором курсе магистратуры AI Talent Hub. В этой статье я хочу поделиться опытом разработки модели для распознавания русского рукописного текста.

Несмотря на всеобщую цифровизацию, огромное количество документов до сих пор заполняется от руки — медицинские карты, заявления, анкеты, почтовые адреса и многое другое. Автоматизация процесса распознавания таких документов может значительно ускорить их обработку и снизить количество ошибок при ручном вводе.

В статье я подробно расскажу о всех этапах создания модели:

  1. Какие данные использовал и где их взял;

  2. Какую архитектуру выбрал и почему;

  3. Как проходил процесс подготовки данных и обучения модели;

  4. Как организовал инференс.

Надеюсь, мой опыт будет полезен тем, кто столкнется с похожей задачей или просто интересуется данной тематикой.

Поехали!

Данные

В качестве данных для обучения использовался Cyrillic Handwriting Dataset с Kaggle. Это набор рукописных текстов на кириллице, специально созданный для задач OCR (Optical Character Recognition). Датасет содержит 73830 примеров и уже разделен на train и test выборки в соотношении 95% и 5% соответственно.

Особенности датасета:

  • Каждый файл представляет собой PNG-изображение с текстом в одну строку;

  • Длина текста не превышает 40 символов;

  • Тексты написаны разными людьми, что обеспечивает разнообразие почерков;

  • Есть как цветные изображения, так и черно-белые;

  • Каждое изображение сопровождается правильной расшифровкой написанного.

Рис. 1: Примеры из Cyrillic Handwriting Dataset.
Рис. 1: Примеры из Cyrillic Handwriting Dataset.

Модель

В качестве базовой архитектуры возьмем trocr-base-handwritten от Microsoft. Это трансформерная модель, которую зафайнтюнили на датасете IAM (примерно такие же картинки с текстом, как и в Cyrillic Handwriting, только на английском). На упомянутом наборе данных эта модель является SOTA

Рис. 2: Архитектура TrOCR из оригинальной статьи
Рис. 2: Архитектура TrOCR из оригинальной статьи

TrOCR — это энкодер-декодерная модель, состоящая из image transformer в качестве энкодера и text transformer в качестве декодера. Энкодер был инициализирован с весами BEiT, в то время как декодер — с весами RoBERTa. А вот веса cross-attention между ними были инициализированы уже случайно.

Изображения представляются модели в виде последовательности патчей фиксированного размера (разрешение 16x16 пикселей). Перед передачей в слои энкодера к последовательности добавляется позиционный эмбеддинг. Затем декодер текста авторегрессионно генерирует токены.

TrOCR статья: https://arxiv.org/abs/2109.10282

TrOCR документация: https://huggingface.co/transformers/master/model_doc/trocr.html

Подготовка данных и дообучение модели

Далее мы разберем как подготовить данные и зафайнтюнить выбранную модель под нашу задачу. Мы будем использовать класс VisionEncoderDecoderModel из библиотеки transformers, который можно использовать для объединения любого image Transformer encoder (например, ViT, BEiT) с любым text Transformer в качестве декодера (например, BERT, RoBERTa, GPT-2). Примером этого является TrOCR, поскольку он имеет архитектуру энкодер-декодера, как уже говорилось ранее.

Установка библиотек

Для начала нам потребуется установить следующие библиотеки:

  • transformers — для работы с моделью;

  • evaluate и jiwer — для расчетов метрики.

Это можно сделать командой:

!pip install -q transformers evaluate jiwer

Мы не будем использовать datasets от HuggingFace для предобработки данных, а воспользуемся старым добрым Dataset из torch.

Теперь проверим, что нам доступна CUDA:

import torch

torch.cuda.is_available()

Если все хорошо, то получим True.

В нашем случае использовалась видеокарточка H100, но подойдет и менее мощная.

Предобработка данных

Перед тем, как обучать модель, надо загрузить и подготовить данные. Давайте этим и займемся!

Сначала мы загрузим данные. Скачать их можно тут. Мы будем использовать только train часть из датасета, так как примеров там предостаточно. Для этого сделаем pandas dataframe из файла train.tsv.

import pandas as pd

train_val_df = pd.read_csv(
   "cyrillic-handwriting-dataset/train.tsv",
   sep="\t",
   header=None,
   names=["file_name", "text"],
)

train_val_df.head()

file_name

text

0

aa1.png

Молдова

1

aa1007.png

продолжила борьбу

2

aa101.png

разработанные

3

aa1012.png

Плачи

4

aa1013.png

Гимны богам

В таблице присутствует одна строка с пустым text, из-за которой может поломаться все обучение, поэтому дропнем ее.

train_val_df = train_val_df.dropna()

Следующим шагом разделим данные на трейн + валидацию в пропорции 80/20 через train_test_split из sklearn.

from sklearn.model_selection import train_test_split

train_df, eval_df = train_test_split(train_val_df, test_size=0.2)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
eval_df.reset_index(drop=True, inplace=True)

Каждый элемент итогового датасета должен возвращать:

  1. pixel_values (значения пикселей исходных изображений) — это будет input для модели;

  2. labels — input_ids соответствующего текста на изображении.

Мы будем использовать TrOCRProcessor для приведения данных в нужные форматы. TrOCRProcessor — это обертка над ViTFeatureExtractor и RobertaTokenizer. Первый будет использоваться для ресайзинга и нормализации изображений. Второй - для кодирования и декодирования текста в/из input_ids.

Штош, давайте напишем класс нашего датасета. Для этого нам нужно наследоваться от torch.utils.data.Dataset и определить 3 метода.

  1. Метод __init__ для инициализации экземпляра класса. В качестве входных параметров будем предавать следующее:

    1. root_dir — путь к директории, где хранятся изображения;

    2. df — pd.DataFrame, содержащий столбцы file_name и text;

    3. processor —  предобученный TrOCRProcessor;

    4. max_target_length — максимальная длина для текстовых меток (labels).

  2. Метод __len__, который будет возвращать количество элементов в датасете.

  3. Метод __getitem__:

    • Принимает на вход индекс idx;

    • Извлекает имя файла, изображения и соответствующий текст из df по индексу idx;

    • Загружает изображение из root_dir и конвертирует его в формат RGB;

    • Обрабатывает изображение с помощью processor, который возвращает тензоры, представляющие пиксели изображения;

    • Кодирует текст в input_ids с помощью токенизатора, который является частью processor, метки заполняются pad-токенами до max_target_length;

    • Заменяет токены pad-токены на -100, чтобы они игнорировались функцией потерь во время обучения;

    • Возвращает словарь encoding, содержащий pixel_values и labels.

import torch
from torch.utils.data import Dataset
from PIL import Image


class CHDataset(Dataset):
   def __init__(self, root_dir, df, processor, max_target_length=128):
       self.root_dir = root_dir
       self.df = df
       self.processor = processor
       self.max_target_length = max_target_length

   def __len__(self):
       return len(self.df)

   def __getitem__(self, idx):
       # get file name + text
       file_name = self.df["file_name"][idx]
       text = self.df["text"][idx]
       # prepare image (i.e. resize + normalize)
       image = Image.open(self.root_dir + file_name).convert("RGB")
       pixel_values = self.processor(image, return_tensors="pt").pixel_values
       # add labels (input_ids) by encoding the text
       labels = self.processor.tokenizer(
           text, padding="max_length", max_length=self.max_target_length
       ).input_ids
       # important: make sure that PAD tokens are ignored by the loss function
       labels = [
           label if label != self.processor.tokenizer.pad_token_id else -100
           for label in labels
       ]

       encoding = {
           "pixel_values": pixel_values.squeeze(),
           "labels": torch.tensor(labels),
       }
       return encoding

Давайте инициализируем наборы данных для обучения и оценки:

from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

train_dataset = CHDataset(
   root_dir="cyrillic-handwriting-dataset/train/",
   df=train_df,
   processor=processor,
)
eval_dataset = CHDataset(
   root_dir="cyrillic-handwriting-dataset/train/",
   df=eval_df,
   processor=processor,
)

Посмотрим на количество примеров в обеих подвыборках:

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

Number of training examples: 57827
Number of validation examples: 14457

И поймем какие размерности у их элементов:

encoding = train_dataset[0]
for k, v in encoding.items():
   print(k, v.shape)

pixel_values torch.Size([3, 384, 384])
labels torch.Size([128])

Получили, что pixel_values — это тензор размером [3, 384, 384]. Здесь 3 обозначает количество цветовых каналов в формате RGB, а 384x384 — это размеры изображения после изменения его размера (ресайзинга). А labels — тензор размером [128] с учетом pad-токенов (128, потому что указали такой max_target_length).

Данные готовы и наконец-то можно приступать к дообучению.

Файнтюнинг модели

Здесь мы инициализируем модель TrOCR с её предобученными весами. Обратите внимание, что веса головы языкового моделирования инициализированы из претрейна, так как модель уже была обучена генерировать текст на этапе предобучения. Подробности можно найти в статье.

from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

Важно установить несколько атрибутов, а именно:

  • атрибуты, необходимые для создания decoder_input_ids из labels (модель автоматически создаст decoder_input_ids, сдвинув labels на одну позицию вправо и добавив decoder_start_token_id в начало, а также заменив идентификаторы, равные -100, на pad_token_id)

  • размер словаря модели (для языковой модели, расположенной поверх декодера)

  • параметры beam-search, которые используются при генерации текста.

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

Далее определим некоторые гиперпараметры обучения, создав экземпляр training_args. Важно отметить, что существует множество других параметров, полный список которых можно найти в документации. Например, вы можете задать размер батча для обучения/оценки, решить, использовать ли mixed precision training, установить частоту сохранения модели и т.д.

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
   num_train_epochs=5,
   predict_with_generate=True,
   eval_strategy="epoch",
   save_strategy="epoch",
   per_device_train_batch_size=64,
   per_device_eval_batch_size=64,
   fp16=True,
   output_dir="./trained_models",
)

Мы будем оценивать модель по Character Error Rate (CER) (подробнее см. здесь).

import evaluate

cer_metric = evaluate.load("cer")

Функция compute_metrics принимает на вход EvalPrediction (который является NamedTuple) и должна возвращать словарь. На этапе оценки модель вернёт EvalPrediction, который состоит из двух элементов:

  • predictions: предсказания, сделанные моделью.

  • label_ids: фактические истинные метки.

def compute_metrics(pred):
   labels_ids = pred.label_ids
   pred_ids = pred.predictions

   pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
   labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
   label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

   cer = cer_metric.compute(predictions=pred_str, references=label_str)

   return {"cer": cer}

Давайте начнем обучение! Мы также предоставляем default_data_collator для Trainer, который используется для объединения примеров в батчи.

Следует учесть, что оценка занимает довольно много времени, так как мы используем beam search для декодирования, что требует нескольких прямых проходов для каждого примера.

from transformers import default_data_collator

# instantiate trainer
trainer = Seq2SeqTrainer(
   model=model,
   processing_class=processor.tokenizer,
   args=training_args,
   compute_metrics=compute_metrics,
   train_dataset=train_dataset,
   eval_dataset=eval_dataset,
   data_collator=default_data_collator,
)
trainer.train()

После обучения мы получили вот такие вот результаты:

  • Training Loss: 0.026100;

  • Validation Loss: 0.120961;

  • CER: 0.048542.

Вполне неплохой результат.
Осталось только сохранить модель, например, локально для дальнейшего использования.

processor.save_pretrained("model/TrOCRInferenceModel/weights")
model.save_pretrained("model/TrOCRInferenceModel/weights")

Инференс

После обучения мы можем легко загрузить модель, используя метод .from_pretrained(output_dir), так как на предыдущем шаге мы сохранили ее.

Загрузим модель и изображение из тестовой выборки:

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests

image = Image.open(
   "cyrillic-handwriting-dataset/test/test4.png"
).convert("RGB")

processor = TrOCRProcessor.from_pretrained(
   "model/TrOCRInferenceModel/weights"
)
model = VisionEncoderDecoderModel.from_pretrained(
   "model/TrOCRInferenceModel/weights"
)

Посмотрим, что там за текст на картинке.

Тут написано слово «Определим». Конечно не с первого раза, но разобрать буквы можно. Правда в данном случае буква «п» больше похоже на «й». Посмотрим как с этим справится наша модель.

pixel_values = processor(images=image, return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
generated_text

«Определим»

Супер! Модель распознала слово без ошибок.

Вывод

В этой статье мы разобрали как можно дообучить модель TrOCR под задачу распознавания русского рукописного текста и какие данные для этого могут использоваться. Модель получилась довольно точной и может использоваться в реальных кейсах.

P.S.: Свою дообученную модель я разместил по ссылке. Буду только рад, если она кому-то будет полезна.

Материал подготовил магистрант 2 курса AI Talent Hub, Арсений Казанцев.

Комментарии (2)


  1. Squoworode
    12.12.2024 06:18

    А где классические шиншиллы или "лишили лилии"?


  1. ENick
    12.12.2024 06:18

    "Здесь 3 обозначает количество цветовых каналов в формате RGB" - а почему не grey, в 3-и раза меньше считать/обучаться

    "После обучения мы получили вот такие вот результаты: " - без демонстрации графиков обучения это мало о чём говорит