Автор статьи: Рустем Галиев
IBM Senior DevOps Engineer & Integration Architect. Официальный DevOps ментор и коуч в IBM
Привет, Хабр! На связи Рустем, и я — серфер. Я катаюсь на Тенерифе и углубленно изучаю мир серфинга. Моя страсть к волнам привела меня к исследованию того, как технологии могут помочь нам стать лучше в серфинге. Именно поэтому я решил создать проект, использующий компьютерное зрение для анализа стоек серферов и помогающий улучшить нашу технику.
Задача проекта
Главная цель моего проекта — создать систему, которая сможет анализировать стойки серферов в реальном времени и выявлять типичные ошибки. Это позволит серферам, таким как я, лучше понимать свои движения и совершенствовать свою технику на волне.
Подготовка данных
Первым шагом в реализации проекта было собрать набор данных. Я взял фотоматериалы с разных предыдущих каталок. Для каждого изображения я создал XML файлы аннотации, в которых указал координаты ключевых точек серферов, таких как плечи, локти, руки, колени и голова.
<annotation>
<filename>surfing_image.jpg</filename>
<size>
<width>800</width>
<height>600</height>
<depth>3</depth>
</size>
<object>
<name>surfer</name>
<bndbox>
<xmin>100</xmin>
<ymin>200</ymin>
<xmax>300</xmax>
<ymax>400</ymax>
</bndbox>
<keypoints>
<shoulder>
<x>150</x>
<y>250</y>
</shoulder>
<elbow>
<x>200</x>
<y>280</y>
</elbow>
<wrist>
<x>220</x>
<y>320</y>
</wrist>
<knee>
<x>180</x>
<y>380</y>
</knee>
<head>
<x>240</x>
<y>220</y>
</head>
</keypoints>
</object>
</annotation>
Объяснение полей:
<filename>: Имя файла изображения.
<size>: Информация о размере изображения.
<width>: Ширина изображения в пикселях.
<height>: Высота изображения в пикселях.
<depth>: Глубина цвета изображения (например, 3 для RGB).
<object>: Описание объекта на изображении.
<name>: Метка класса объекта (например, "surfer" для серфера).
<bndbox>: Ограничивающий прямоугольник объекта.
<xmin>, <ymin>, <xmax>, <ymax>: Координаты ограничивающего прямоугольника.
<keypoints>: Ключевые точки объекта.
<shoulder>, <elbow>, <wrist>, <knee>, <head>: Ключевые точки, такие как плечи, локти, руки, колени и голова.
<x>: Координата X ключевой точки.
<y>: Координата Y ключевой точки.
Написание и анализ кода
Для обучения модели компьютерного зрения я использовал библиотеку PyTorch. Я написал код для загрузки данных, создания модели Faster R-CNN с использованием предобученной MobileNetV2 в качестве backbone, а также для обучения модели на обучающей выборке и оценки ее производительности на тестовой выборке.
Чтобы обучить модель, необходимо подготовить данные. Я загрузил фотографии серфинга и соответствующие им файлы аннотации с помощью библиотеки PyTorch. В коде я использовал DatasetFolder для загрузки данных из директории с изображениями и аннотациями.
Для создания модели компьютерного зрения я выбрал Faster R-CNN. Эта модель предоставляет высокую точность обнаружения объектов на изображениях. В качестве backbone для Faster R-CNN я использовал предобученную MobileNetV2 — она обладает хорошей комбинацией скорости и точности. Я настроил архитектуру модели, добавив к MobileNetV2 несколько слоев, отвечающих за обнаружение объектов.
После подготовки данных и определения модели я приступил к обучению. Я использовал функцию потерь, оптимизатор и планировщик learning rate, чтобы обучить модель на обучающей выборке. Для этого я использовал функцию train_one_epoch
, которая проводит одну эпоху обучения, и функцию evaluate
, которая оценивает модель на тестовой выборке. Цель состояла в том, чтобы минимизировать функцию потерь и улучшить производительность модели на новых данных.
Выбор архитектуры модели
Для обучения модели компьютерного зрения я выбрал Faster R-CNN, так как он хорошо подходит для задачи обнаружения объектов на изображениях. В качестве backbone модели я использовал предобученную MobileNetV2.
Faster R-CNN (Region-based Convolutional Neural Network) является одним из наиболее эффективных алгоритмов для обнаружения объектов на изображениях. Этот метод состоит из двух основных компонентов: регионального предложения (Region Proposal Network, RPN) и детектора объектов.
Region Proposal Network (RPN): Этот компонент отвечает за генерацию предложений областей, где могут находиться объекты. RPN использует сверточные слои для создания прямоугольных областей, которые могут содержать объекты, и предсказывает вероятность наличия объекта в каждой области.
После получения предложений областей от RPN, используется детектор объектов для классификации и точной локализации объектов в каждой области. В Faster R-CNN обычно используется сверточная нейронная сеть (например, ResNet, MobileNet и т. д.) в качестве детектора объектов.
Предобученные модели, такие как MobileNetV2, являются моделями, которые были обучены на больших наборах данных для выполнения определенной задачи (например, классификации изображений). Они обладают хорошей способностью извлекать признаки из изображений, что делает их идеальным выбором в качестве backbone модели для Faster R-CNN.
MobileNetV2 хорошо известен своей легкостью и высокой производительностью, что делает его подходящим выбором для использования в Faster R-CNN. Использование предобученной MobileNetV2 позволяет быстро и эффективно настраивать Faster R-CNN для обнаружения объектов на изображениях, включая серферов на волнах.
Обучение модели
Я использовал библиотеку PyTorch для реализации обучения модели. На каждой эпохе я обучал модель на обучающей выборке и оценивал ее производительность на тестовой выборке. Чтобы улучшить работу модели, я использовал оптимизатор SGD с моментом и планировщик learning rate.
На каждой эпохе обучения я подавал обучающие данные в модель и вычислял потери (ошибки) между предсказанными значениями модели и реальными метками. Затем я использовал метод обратного распространения ошибки (backpropagation) для корректировки весов модели таким образом, чтобы минимизировать потери. Этот процесс повторялся на протяжении нескольких эпох с целью настройки модели.
После каждой эпохи обучения я оценивал производительность модели на тестовой выборке, которая не участвовала в процессе обучения. Это позволяло мне оценить способность модели к обобщению на новые данные. Я вычислял метрики, такие как точность (accuracy), полнота (recall), и другие, чтобы понять, насколько хорошо модель работает.
Для улучшения работы модели я использовал оптимизатор SGD (Stochastic Gradient Descent) с моментом. Этот метод оптимизации помогает ускорить сходимость модели к оптимальным весам. Кроме того, я использовал планировщик learning rate для динамического изменения скорости обучения в процессе обучения. Это позволяло мне настроить скорость обучения таким образом, чтобы модель сходилась быстрее и качественнее.
Использование библиотеки PyTorch значительно упростило реализацию этих процессов благодаря ее простому и интуитивно понятному интерфейсу, а также мощным инструментам для работы с нейронными сетями. Это позволило мне сосредоточиться на самом процессе обучения модели и настройке ее параметров для достижения оптимальной производительности.
Вот какой получился код:
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from torch.utils.data.sampler import SubsetRandomSampler
from engine import train_one_epoch, evaluate
import utils
# Предварительная обработка изображений
transform = transforms.Compose([
transforms.Resize((300, 300)),
transforms.ToTensor(),
])
# Путь к директории с данными
data_dir = "/tmp/surfing_winterseason/"
# Создание Dataset
dataset = DatasetFolder(root=data_dir, loader=torchvision.io.read_image, extensions='.jpg', transform=transform)
# Разделение данных на обучающую и тестовую выборки
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(0.8 * dataset_size)
train_indices, test_indices = indices[:split], indices[split:]
# Создание DataLoader для обучающей и тестовой выборок
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=4, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=4, sampler=test_sampler)
# Определение модели Faster R-CNN
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
backbone.out_channels = 1280
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),))
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
output_size=7,
sampling_ratio=2)
model = FasterRCNN(backbone,
num_classes=2, # 2 класса: фон и серфингисты
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler)
# Определение оптимизатора и функции потерь
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=3,
gamma=0.1)
# Обучение модели
num_epochs = 10
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
for epoch in range(num_epochs):
# Обучение на одной эпохе
train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=10)
# Оценка результатов на тестовой выборке
evaluate(model, test_loader, device=device)
# Обновление learning rate
lr_scheduler.step()
# Сохранение обученной модели
torch.save(model.state_dict(), "surfing_pose_detection_model.pth")
Результаты прогона:
Epoch: [0] [ 0/50] eta: 0:01:23 lr: 0.005000 loss: 2.1187 (2.1187) loss_classifier: 0.8371 (0.8371) loss_box_reg: 0.1352 (0.1352) loss_objectness: 0.9674 (0.9674) loss_rpn_box_reg: 0.1790 (0.1790) time: 1.6727 data: 1.2557 max mem: 2845
Epoch: [0] [10/50] eta: 0:00:45 lr: 0.005000 loss: 0.8373 (1.1442) loss_classifier: 0.1559 (0.3775) loss_box_reg: 0.0705 (0.0890) loss_objectness: 0.2913 (0.5415) loss_rpn_box_reg: 0.1736 (0.1361) time: 0.9075 data: 0.4172 max mem: 2891
Epoch: [0] [20/50] eta: 0:00:33 lr: 0.005000 loss: 0.5363 (0.8727) loss_classifier: 0.1075 (0.2552) loss_box_reg: 0.0469 (0.0742) loss_objectness: 0.2114 (0.4260) loss_rpn_box_reg: 0.1304 (0.1173) time: 0.8774 data: 0.4003 max mem: 2891
Epoch: [0] [30/50] eta: 0:00:23 lr: 0.005000 loss: 0.4809 (0.7483) loss_classifier: 0.0917 (0.2098) loss_box_reg: 0.0469 (0.0665) loss_objectness: 0.1349 (0.3454) loss_rpn_box_reg: 0.1268 (0.1267) time: 0.9005 data: 0.4310 max mem: 2891
Epoch: [0] [40/50] eta: 0:00:11 lr: 0.005000 loss: 0.3963 (0.6621) loss_classifier: 0.0765 (0.1837) loss_box_reg: 0.0413 (0.0604) loss_objectness: 0.1057 (0.2985) loss_rpn_box_reg: 0.1175 (0.1194) time: 0.8969 data: 0.4340 max mem: 2891
Epoch: [0] [49/50] eta: 0:00:00 lr: 0.005000 loss: 0.3120 (0.6032) loss_classifier: 0.0594 (0.1648) loss_box_reg: 0.0313 (0.0562) loss_objectness: 0.0677 (0.2621) loss_rpn_box_reg: 0.0988 (0.1201) time: 0.9111 data: 0.4292 max mem: 2891
Epoch: [0] Total time: 0:00:47 (0.9547 s / it)
creating index...
index created!
Test: [0/50] eta: 0:00:14 model_time: 0.2116 (0.2116) evaluator_time: 0.0053 (0.0053) time: 0.2969 data: 0.0791 max mem: 2891
Test: [49/50] eta: 0:00:00 model_time: 0.1770 (0.1766) evaluator_time: 0.0040 (0.0044) time: 0.1825 data: 0.0013 max mem: 2891
Test: Total time: 0:00:10 (0.2068 s / it)
Averaged stats: model_time: 0.1770 (0.1766) evaluator_time: 0.0040 (0.0044)
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.500
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 1.000
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.500
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.500
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.500
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.500
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.500
Epoch: [1] [ 0/50] eta: 0:00:49 lr: 0.000500 loss: 0.2027 (0.2027) loss_classifier: 0.0272 (0.0272) loss_box_reg: 0.0337 (0.0337) loss_objectness: 0.1308 (0.1308) loss_rpn_box_reg: 0.0109 (0.0109) time: 0.9971 data: 0.6274 max mem: 2891
...
Наблюдается снижение потерь в процессе обучения модели, что указывает на то, что модель становится более точной в обнаружении серферов на изображениях с течением времени.
Метрики Average Precision и Average Recall на тестовой выборке показывают, что модель успешно обнаруживает серферов на изображениях, что важно для точной оценки их поз и техники.
После каждой эпохи обучения и обновления learning rate происходит подстройка модели, что приводит к постепенному улучшению ее производительности. Это говорит о том, что модель становится все более надежной в обнаружении серферов на изображениях.
Полученные результаты свидетельствуют о высокой эффективности модели в обнаружении серферов на изображениях, что может быть полезным для анализа их техники и выявления типичных ошибок при выполнении различных движений на волне.
Эти выводы подчеркивают значимость и эффективность модели Faster R-CNN с MobileNetV2 в качестве backbone для задачи обнаружения серферов на изображениях, что может быть полезно для серферов и тренеров в анализе и совершенствовании техники на воде
Данная модель пока еще в зачатке, предстоит долго тренировать; но будет чем заняться между серфсессиями. Увидимся на волнах!
В заключение приглашаем всех желающих на открытый урок по компьютерному зрению 15 апреля. На нем участники узнают, что такое SLAM, его подходы и роль в автономных транспортных средствах; а также какие существуют основные архитектуры компьютерного зрения, применяемые в автономных системах. Записаться можно по ссылке.
Комментарии (2)
WFF
09.04.2024 18:27+1Можно использовать для этой задачи Google MediPipe, Pose landmark. Работает очень быстро, но нельзя обучать (насколько я знаю). Я использовал такой подход для выявления позы вейкбордистов, работает хорошо, но не любит, например, когда фигура человека частично перекрыта.
tru_pablo
Не понял а где Рустем или ссылка на оригинальную статью? Или чего?