У нас в предзаказе появилась долгожданная книга о библиотеке PyTorch.
Поскольку весь необходимый базовый материал о PyTorch вы узнаете из этой книги, мы напоминаем о пользе процесса под названием «grokking» или «углубленное постижение» той темы, которую вы хотите усвоить. В сегодняшней публикации мы расскажем, как Кай Арулкумаран (Kai Arulkumaran) грокнул PyTorch (без картинок). Добро пожаловать под кат.
PyTorch – это гибкий фреймворк для глубокого обучения, обеспечивающий автоматическое различение объектов при помощи динамических нейронных сетей (то есть, сетей, использующих динамическое управление потоком, например, инструкции
if
и циклы while
). PyTorch поддерживает GPU-ускорение, распределенное обучение, различные виды оптимизации и еще множество других приятных возможностей. Здесь я изложил некоторые мысли о том, как, на мой взгляд, следует использовать PyTorch; здесь не охвачены все аспекты библиотеки и рекомендуемые практики, но, надеюсь, этот текст окажется вам полезен.Нейронные сети – это подкласс вычислительных графов. Вычислительные графы получают на вход данные, далее эти данные маршрутизируются (и могут преобразовываться) на узлах, где и происходит их обработка. В глубоком обучении нейроны (узлы) обычно преобразуют данные, применяя к ним параметры и дифференцируемые функции, так, чтобы параметры можно было оптимизировать для минимизации потерь методом градиентного спуска. В более широком смысле отмечу, что функции могут быть стохастическими, а граф – динамическим. Таким образом, тогда как нейронные сети хорошо вписываются в парадигму программирования потоков данных (dataflow programming), API PyTorch ориентирован на парадигму императивного программирования, а такой способ трактовки создаваемых программ гораздо более привычен. Именно поэтому код PyTorch проще читается, по нему проще судить об устройстве сложных программ, что, однако, не требует серьезно поступаться производительностью: на самом деле, PyTorch достаточно быстр и предусматривает множество оптимизаций, о которых вы, как конечный пользователь, можете совершенно не волноваться (однако, если они вам действительно интересны, можете копнуть поглубже и познакомиться с ними).
Остальная часть этой статьи является разбором официального примера на датасете MNIST. Здесь мы грокаем PyTorch, поэтому разбираться в статье рекомендую только после знакомства с официальными руководствами для начинающих. Для удобства код представлен в виде небольших фрагментов, снабженных комментариями, то есть, не распределен на отдельные функции/файлы, которые вы привыкли видеть в чистом модульном коде.
Импорты
import argparse
import os
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
Все это вполне стандартные импорты, за исключением модулей
torchvision
, особенно активно используемых для решения задач, связанных с компьютерным зрением.Настройка
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--save-interval', type=int, default=10, metavar='N',
help='how many batches to wait before checkpointing')
parser.add_argument('--resume', action='store_true', default=False,
help='resume training from checkpoint')
args = parser.parse_args()
use_cuda = torch.cuda.is_available() and not args.no_cuda
device = torch.device('cuda' if use_cuda else 'cpu')
torch.manual_seed(args.seed)
if use_cuda:
torch.cuda.manual_seed(args.seed)
argparse
– это стандартный способ обращения с аргументами командной строки в Python.Если нужно писать код, рассчитанный на работу на разных устройствах (пользуясь GPU-ускорением, когда оно доступно, но при его отсутствии откатываясь обратно к вычислениям на CPU), то выберите и сохраните подходящий
torch.device
, при помощи которого можно определить, где должны храниться тензоры. Подробнее о создании такого кода см. в официальной документации. Подход PyTorch – отдавать подбор устройств под контроль пользователя, что может показаться нежелательным в простых примерах. Однако, такой подход значительно упрощает работу, когда приходится иметь дело с тензорами, что а) удобно при отладке b) позволяет эффективно использовать устройства вручную.Для воспроизводимости экспериментов необходимо установить случайные начальные значения для всех компонентов, использующих случайную генерацию чисел (в том числе,
random
или numpy
, если и они у вас используются). Обратите внимание: cuDNN использует недетерминированные алгоритмы и по желанию отключается при помощи torch.backends.cudnn.enabled = False
.Данные
data_path = os.path.join(os.path.expanduser('~'), '.torch', 'datasets', 'mnist')
train_data = datasets.MNIST(data_path, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
test_data = datasets.MNIST(data_path, train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))]))
train_loader = DataLoader(train_data, batch_size=args.batch_size,
shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size,
num_workers=4, pin_memory=True)
Поскольку модели
torchvision
сохраняются под ~/.torch/models/
, я предпочитаю хранить датасеты torchvision
под ~/.torch/datasets
. Это мое авторское соглашение, но им очень удобно пользоваться в проектах, разрабатываемых на базе MNIST, CIFAR-10, т.д. В целом, датасеты следует хранить отдельно от кода, если вы собираетесь переиспользовать несколько датасетов.torchvision.transforms
содержит множество удобных вариантов преобразований для отдельных изображений, например, обрезку и нормализацию.В
DataLoader
есть множество опций, но, кроме batch_size
и shuffle
, также следует иметь в виду num_workers
и pin_memory
, они помогают повысить эффективность. num_workers > 0
использует субпроцессы для асинхронной загрузки данных, а не блокирует под это главный процесс. Типичный пример использования – загрузка данных (например, изображений) с диска и, возможно, их преобразования; все это может делаться параллельно, вместе с сетевой обработкой данных. Степень обработки, возможно, потребуется настроить, чтобы a) минимизировать количество работников и, следовательно, объем использования CPU и RAM (каждый работник загружает отдельную порцию, а не отдельные образцы, входящие в порцию) b) минимизировать длительность ожидания данных в сети. pin_memory
использует закрепленную память (pinned memory) (в противовес подкачиваемой) для ускорения любых операций переноса данных из RAM в GPU (и ничего не делает с кодом, относящимся только к CPU).Модель
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
model = Net().to(device)
optimiser = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
if args.resume:
model.load_state_dict(torch.load('model.pth'))
optimiser.load_state_dict(torch.load('optimiser.pth'))
Сетевая инициализация обычно распространяется на переменные членов, слои, в которых содержатся обучаемые параметры и, может быть, на отдельные обучаемые параметры и необучаемые буферы. Затем при прямом проходе они используются в сочетании с функциями из
F
, чисто функциональными, не содержащими параметров. Некоторым нравится работать с чисто функциональными сетями (напр., держать параметры и использовать F.conv2d
вместо nn.Conv2d
) или сети, целиком состоящие из слоев (напр., nn.ReLU
вместо F.relu
)..to(device)
– удобный способ отправлять параметры устройства (и буферы) на GPU, если в качестве device
задан GPU, так как в противном случае (если в качестве device задан CPU) ничего делаться не будет. Важно перенести параметры устройства на соответствующее устройство, прежде, чем передавать их оптимизатору; в противном случае оптимизатор не сможет правильно отслеживать параметры!И нейронные сети (
nn.Module
), и оптимизаторы (optim.Optimizer
) умеют сохранять и загружать свое внутреннее состояние, и делать это рекомендуется с помощью .load_state_dict(state_dict)
– перезагрузить состояние обоих бывает нужно, чтобы возобновить обучение на основе ранее сохраненных словарей состояний. Сохранение всего объекта целиком может быть чревато ошибками. Если вы сохранили тензоры на GPU и хотите загрузить их на CPU или другой GPU, то проще всего загружать их непосредственно на CPU при помощи опции map_location
, напр., torch.load('model.pth'
, map_location='cpu'
).Вот еще некоторые моменты, не показанные здесь, но заслуживающие упоминания, связаны с тем, что при прямом проходе можно использовать поток управления (напр., выполнение инструкции
if
может зависеть от переменной члена или от самих данных. Кроме того, совершенно допустимо посреди процесса выводить (print
) тензоры, что значительно упрощает отладку. Наконец, при прямом проходе может использоваться множество аргументов. Проиллюстрирую этот момент коротким листингом, не привязанным ни к какой конкретной идее:def forward(self, x, hx, drop=False):
hx2 = self.rnn(x, hx)
print(hx.mean().item(), hx.var().item())
if hx.max.item() > 10 or self.can_drop and drop:
return hx
else:
return hx2
Обучение
model.train()
train_losses = []
for i, (data, target) in enumerate(train_loader):
data = data.to(device=device, non_blocking=True)
target = target.to(device=device, non_blocking=True)
optimiser.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
train_losses.append(loss.item())
optimiser.step()
if i % 10 == 0:
print(i, loss.item())
torch.save(model.state_dict(), 'model.pth')
torch.save(optimiser.state_dict(), 'optimiser.pth')
torch.save(train_losses, 'train_losses.pth')
Сетевые модули по умолчанию ставятся в режим обучения – что в определенной степени отражается на работе модулей, больше всего – на прореживании и пакетной нормализации. Так или иначе, лучше задавать такие вещи вручную при помощи
.train()
, который просачивает флаг «training» до всех дочерних модулей. Здесь метод
.to()
не только принимает устройство, но и устанавливает non_blocking=True
, обеспечивая таким образом асинхронное копирование данных на GPU из закрепленной памяти, позволяя CPU сохранять работоспособность при переносе данных; в противном случае non_blocking=True
попросту не вариант.Прежде чем собрать новый набор градиентов при помощи
loss.backward()
и выполнить обратное распространение при помощи optimiser.step()
, необходимо вручную обнулить градиенты оптимизируемых параметров при помощи optimiser.zero_grad()
. По умолчанию PyTorch накапливает градиенты, что очень удобно, если у вас не хватает ресурсов, чтобы вычислить все нужные вам градиенты за один проход.PyTorch использует «магнитофонную» систему автоматических градиентов – собирает информацию о том, какие операции и в каком порядке производились над тензорами, а затем воспроизводит их в обратном направлении, чтобы выполнить дифференциацию в обратном порядке (reverse-mode differentiation). Вот почему он такой супер-гибкий и допускает произвольные вычислительные графы. Если ни один из этих тензоров не требует градиентов (приходится установить
requires_grad=True
, создавая тензор для этой цели), то никакой граф не сохраняется! Однако, у сетей обычно есть параметры, требующие градиентов, поэтому любые вычисления, выполняемые на основе вывода сети, будут сохраняться в графе. Итак, если вы хотите сохранять данные, результирующие после этого шага, то понадобится вручную отключить градиенты или (более распространенный подход), сохранить эту информацию как число Python (при помощи .item()
в скаляре PyTorch) или массив numpy
. Подробнее об autograd
рассказано в официальной документации.Один из способов сократить вычислительный граф — пользоваться
.detach()
, когда проходится скрытое состояние при обучении RNN с усеченной версией backpropagation-through-time. Это также удобно при дифференциации потерь, когда один из компонентов является выводом другой сети, но эта другая сеть не должна оптимизироваться относительно потерь. В качестве примера приведу обучение дискриминативной части на материале вывода генерирующей при работе с GAN, либо обучение политики в алгоритме актор-критик с использованием целевой функции в качестве базовой (напр. A2C). Еще один прием, предотвращающий вычисление градиентов, эффективный при обучении GAN (обучение генерирующей части на материале дискриминативной) и типичный при тонкой настройке – циклический перебор параметров сети, при котором задано param.requires_grad = False
.Важно не только регистрировать результаты в консоли/файле логов, но и ставить контрольные точки в параметрах модели (и состоянии оптимизатора) просто на всякий случай. Также можно пользоваться
torch.save()
для сохранения обычных Python-объектов, либо воспользоваться другим стандартным решением – встроенным pickle
.Тестирование
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for data, target in test_loader:
data = data.to(device=device, non_blocking=True)
target = target.to(device=device, non_blocking=True)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_data)
acc = correct / len(test_data)
print(acc, test_loss)
В ответ на
.train()
сети нужно явно переводить в режим оценки (evaluation mode) при помощи .eval()
.Как упоминалось выше, при использовании сети обычно составляется вычислительный граф. Чтобы этого не происходило, пользуйтесь менеджером контекста
no_grad
при помощи with torch.no_grad()
.Еще немного
Это дополнительный раздел, в который я вынес еще несколько полезных отступлений.
Вот официальная документация, поясняющая работу с памятью.
Ошибки CUDA? Исправлять их тяжко, и обычно они связаны с логическими неувязками, по которым на CPU выводятся более вразумительные сообщения об ошибках, чем на GPU. Лучше всего, если, планируя работать с GPU, вы сможете быстро переключаться между CPU и GPU. Более общий совет по разработке – организовать код так, чтобы его можно было быстро проверить перед запуском полноценного задания. Например, подготовьте небольшой или синтетический датасет, прогоните одну эпоху train + test, т.д. Если дело в ошибке CUDA, либо вы совсем никак не можете переключиться на CPU, установите CUDA_LAUNCH_BLOCKING=1. Так запуски ядра CUDA станут синхронными, и вы станете получать более точные сообщения об ошибках.
Замечание о
torch.multiprocessing
или просто об одновременном запуске множества сценариев PyTorch. Поскольку PyTorch использует многопоточные библиотеки BLAS для ускорения вычислений линейной алгебры на CPU, обычно при этом задействовано несколько ядер. Если вы хотите делать несколько вещей одновременно, с использованием многопоточной обработки или нескольких сценариев, может быть целесообразно вручную сократить их количество, установив для переменной окружения OMP_NUM_THREADS
значение 1 или другое невысокое значение. Таким образом снижается вероятность пробуксовки процессора. В официальной документации есть и другие замечания по поводу многопоточной обработки. Комментарии (15)
dim2r
13.10.2019 21:18C первого взгляда как то нелогично работает loss. Непонятно, где он привязывается к нейросети. Такое ощущение, что как-то висит в воздухе и вдруг ты должен вызвать loss.backward(). Ожидал что-то типа model.backward(loss);
Plesser
14.10.2019 09:31Есть такое дело, я когда с этим столкнулся тоже сначала выпал в осадок. Но потом когда осознаешь что target и predict это намного больше чем просто набор данных, все встает на свои места :)
dim2r
14.10.2019 11:13кроме этой нестыковки, все остальное вызвает восторг,
особенно, когда loss можно умножать на -1, что крайне нужно в обучении с подкреплениемPlesser
14.10.2019 12:33Я бы не сказал что это не состыковка, с точки зрения программирования как раз то все ок. Передаем не сами данные а ссылку на объект и наш loss уже с этим объектом работает. Тут дело привычки
Akon32
14.10.2019 11:26Сетевая инициализация обычно распространяется на переменные членов, слои, в которых содержатся обучаемые параметры и, может быть, на отдельные обучаемые параметры и необучаемые буферы.
Какой-то странный перевод фразы
Network initialisation typically includes member variables, layers which contain trainable parameters, and maybe separate trainable parameters and non-trainable buffers.
Я бы перевёл как "Инициализация сети обычно содержит её (сети) переменные, слои, в которых есть обучаемые параметры, и, может быть, отдельные обучаемые параметры и необучаемые буферы". Или даже "при инициализации объекта сети создаются переменные-члены, слои,… ".
Также упоминаются какие-то "сетевые модули", которые, наверно, всё-таки "модули сети".
OneType
Что за дебильное выражение «грокаем»? Неужели в великом и могучем нет его аналогов?
WhiteBlackGoose
ShadowDweller
Почему-то вспомнился когдатошний WoW с его Пандарией и местные приматы-хозены с их «бортельным шнуропсом». Те тоже «грокать» любят.
Shtucer
Хайнлайна не любишь?
YuriM1983
Хайнлайн имеет какое-то отношение к русскому языку? Сепулируешь на Хайнлайна?
vikarti
В данном случае аналоги то может и есть но они не несут того же смысла. Потому что выражение — прямая отсылка на одну (видимо не читанную) книгу Хайнлайна и там оно — подано как заимствование и по тексту — переводить… нельзя без потери большей части смысла. В русском переводе — to grok именно что переведено как «грокать».
b1oki
Или это отсылка на отсылку в названии другой книги издательства «Грокаем алгоритмы. Иллюстрированное пособие для программистов и любопытствующих», автор Бхаргава Адитья.