Мы уже писали о том, что предложили новую модель квантования нейронных сетей, позволяющую ускорить их на 40% на центральных процессорах, а также о том, как она устроена вот тут.
Сегодня мы расскажем о том, как мы в Smart Engines обучали 4.6-битные сети. Основная проблема обучения квантованных моделей в том, что градиентные методы модифицируют веса непрерывным образом, а у квантованных сетей они дискретные.
Кроме того, распространен сценарий, когда сначала обучается вещественная сеть, а затем ее хочется отквантовать, чтобы ускорить систему. Часто обучающие данные при этом уже недоступны или доступны не в полном объеме.
Поэтому методы квантования делятся на 2 большие группы:
Post Training Quantization (PTQ): когда мы квантуем уже обученную сеть. Как правило, они минимизируют ошибку между выходами каждого слоя вещественной и квантованной сетей. Такие методы хорошо работают на больших моделях, в которых есть избыточные веса, а в случае компактных сетей для центральных процессоров отквантоваться без снижения точности обычно не получается.
Quantization Aware Training (QAT): когда мы уже во время обучения применяем методы, повышающие качество квантованной модели.
В Smart Engines все модели уже компактные, а обучающие данные никуда не деваются, так что нам в первую очередь интересны QAT методы или их комбинации с PTQ методами. Расскажем о них поподробнее.
Метод обучения
Самый очевидный способ обучить квантованную сеть – игнорировать квантование во время обратного прохода по сети, а модифицированные веса округлить. Тогда прямой проход будет уже квантованным и соответствовать “рабочему” режиму сети. Так работает Straight Through Estimator (STE), один из старейших и популярных методов обучения квантованных сетей. На удивление, он неплохо работает, если взять уже обученную 8-битную модель и ее постепенно квантовать, то есть 256 значений для сети оказывается не так уж и мало. Однако округление на каждой итерации делает процесс обучения “шумным” и может привести к невозможности обучить сеть вообще, что и наблюдается в случае меньших разрядностей.
Поэтому даже для 8-битных моделей есть методы получше, например, инкрементное обучение.
Идея инкрементного обучения или понейронной конвертации заключается в том, что мы постепенно преобразуем часть весов каждого слоя в квантованный вид. Уже квантованная часть сети немного дообучается (например, с помощью STE) и “замораживается” (веса фиксируются и больше не меняются). Вещественная часть сети дообучается, чтобы адекватно обрабатывать изменившиеся входные значения от квантованной части. Этот подход прост, эффективен в отношении 4- и 8-битных сетей, но требует времени.
AdaQuant – PTQ метод, работающий послойно. Его идея которого заключается в том, что мы используем небольшой калибровочный датасет, чтобы определить динамические диапазоны активаций, а также коэффициент масштабирования и точку нуля так, чтобы минимизировать среднеквадратичную ошибку между выходом квантованного и вещественного слоев.
В данной работе мы использовали комбинацию этих подходов. Сначала параметры каждого слоя отквантовали с помощью Adaquant, а затем применили инкрементное обучение, а именно прошли от начала к концу и еще немного дообучили каждый из квантованных слоев (остальные “замораживались”) с помощью стохастического градиентного спуска с моментом 0.9 и скоростью обучения (learning rate) .
Далее мы рассмотрели три задачи компьютерного зрения, решаемые с помощью разных нейросетевых архитектур, и применили в них 4.6-битные сети. Кстати, наш код экспериментов есть на GitHub: https://github.com/SmartEngines/QNN_training_4.6bit
CIFAR-10
Это давно набившая оскомину небольшая выборка с цветными изображениями объектов 10 классов размера 32 на 32 (см. примеры на рис. 1). В качестве аугментации использовали отражения вдоль вертикальной оси, вырезание случайных регионов со сдвигами и случайные повороты на +-9 градусов с дополнением изображений на 4 пикселя с краев.
На ней мы обучили несколько простых сверточных моделей. Их архитектуры приведены в таблице 1.
Обозначения:
conv(c, f, k, [p]) – сверточный слой с c-канальным входом, f фильтрами размера k х k и паддингом p (в обоих направлениях, по умолчанию 0),
pool(n) – 2D max-pooling с окном n на n,
bn – batch normalization,
HardTanh – активация HardTanh(x) = min(1, max(-1, x)),
relu6 – активация ReLU6(x) = min(max(0, x),6),
tanh – гиперболический тангенс,
fc(n) – полносвязный слой с n выходами.
Таблица 1. Архитектуры сверточных сетей для CIFAR-10.
CNN6 |
CNN7 |
CNN8 |
CNN9 |
CNN10 |
conv(3, 4, 1) HardTanh |
conv(3, 8, 1) HardTanh |
conv(3, 8, 1) HardTanh |
conv(3, 8, 1) HardTanh |
conv(3, 8, 1) HardTanh |
conv(4, 8, 5) bn+relu6 pool(2) |
conv(8, 8, 3) bn+relu6 conv(8, 12, 3) bn+relu6 pool(2) |
conv(8, 8, 3) bn+relu6 conv(8, 12, 3) bn+relu6 pool(2) |
conv(8, 8, 3) bn+relu6 conv(8, 12, 3) bn+relu6 pool(2) |
conv(8, 16, 3, 1) bn+relu6 conv(16, 32, 3, 1) bn+relu6 pool(2) |
conv(8, 16, 3) bn+relu6 pool(2) |
conv(12, 16, 3) bn+relu6 pool(2) |
conv(12, 24, 3) bn+relu6 pool(2) |
conv(12, 12, 3, 1) bn+relu6 conv(12, 24, 3) bn+relu6 pool(2) |
conv(32, 32, 3, 1) bn+relu6 conv(32, 64, 3, 1) bn+relu6 pool(2) |
conv(16, 32, 3) bn+relu6 pool(2) |
conv(16, 32, 3) bn+relu6 pool(2) |
conv(24, 24, 3) bn+relu6 conv(24, 40, 3) bn+relu6 |
conv(24, 24, 3) bn+relu6 conv(24, 48, 3) bn+relu6 |
conv(64, 64, 3) bn+relu6 conv(64, 64, 3) bn+relu6 conv(64, 128, 3) bn+relu6 |
fc(64) tanh |
fc(64) tanh |
fc(64) tanh |
fc(96) tanh |
fc(256) tanh |
fc(10) |
fc(10) |
fc(10) |
fc(10) |
fc(10) |
Trainable parameters | ||||
15.6k |
16.9k |
29.1k |
40.7k |
315.6k |
Сначала мы обучили вещественные версии этих сетей 250 эпох с оптимизатором AdamW с параметрами по умолчанию, кроме коэффициента убывания весов (weight decay), который был установлен в , и скорости обучения (learning rate), которая была задана как и уменьшалась вдвое каждые 50 эпох. Размер батча был 100.
Дальше слои батч нормализации были интегрированы в свертки и мы приступили к квантованию. Первый и последний слой мы не квантовали, так как это улучшает результаты обучения и практически не снижает вычислительную эффективность. Результаты приведены в таблице 2. Напомним, что у нашей схемы квантования есть параметры и , которые обозначают количество дискретов для активаций и весов соответственно, мы перебрали все возможные их комбинации. Эксперименты повторили 5 раз с разной начальной инициализацией сети и посчитали среднюю точность и погрешность.
Таблица 2. Точность классификации на выборке CIFAR-10
Quantization |
Accuracy, % |
|||||
CNN6 |
CNN7 |
CNN8 |
CNN9 |
CNN10 |
||
5 |
127 |
63.5 ± 0.3 |
69.6 ± 0.1 |
70.6 ± 0.3 |
71.7 ± 0.3 |
81.3 ± 0.3 |
7 |
85 |
68.4 ± 0.3 |
73.4 ± 0.1 |
74.7 ± 0.2 |
76.2 ± 0.2 |
85.4 ± 0.1 |
9 |
63 |
70.9 ± 0.1 |
75.0 ± 0.2 |
76.4 ± 0.2 |
78.0 ± 0.2 |
86.9 ± 0.2 |
11 |
51 |
71.8 ± 0.3 |
75.7 ± 0.1 |
77.4 ± 0.1 |
79.0 ± 0.2 |
87.6 ± 0.1 |
13 |
43 |
72.7 ± 0.2 |
76.2 ± 0.1 |
77.8 ± 0.1 |
79.5 ± 0.1 |
88.0 ± 0.1 |
15 |
37 |
73.1 ± 0.2 |
76.4 ± 0.1 |
78.1 ± 0.1 |
79.8 ± 0.2 |
88.2 ± 0.2 |
17 |
31 |
73.0 ± 0.1 |
76.6 ± 0.2 |
78.1 ± 0.3 |
79.8 ± 0.2 |
88.2 ± 0.1 |
19 |
29 |
73.1 ± 0.2 |
76.4 ± 0.2 |
78.5 ± 0.1 |
80.0 ± 0.3 |
88.5 ± 0.3 |
21 |
25 |
73.4 ± 0.2 |
76.7 ± 0.3 |
78.3 ± 0.2 |
79.9 ± 0.1 |
88.4 ± 0.2 |
23 |
23 |
73.3 ± 0.3 |
76.5 ± 0.1 |
78.2 ± 0.2 |
79.9 ± 0.3 |
88.4 ± 0.2 |
25 |
21 |
73.1 ± 0.1 |
76.6 ± 0.2 |
78.2 ± 0.2 |
79.9 ± 0.2 |
88.5 ± 0.2 |
29 |
19 |
73.0 ± 0.1 |
76.3 ± 0.1 |
78.3 ± 0.2 |
79.9 ± 0.2 |
88.3 ± 0.2 |
31 |
17 |
73.1 ± 0.2 |
76.1 ± 0.1 |
78.0 ± 0.3 |
79.7 ± 0.2 |
88.2 ± 0.1 |
37 |
15 |
72.8 ± 0.2 |
75.5 ± 0.4 |
77.7 ± 0.3 |
79.4 ± 0.2 |
87.9 ± 0.3 |
43 |
13 |
72.0 ± 0.4 |
74.8 ± 0.2 |
77.5 ± 0.2 |
79.0 ± 0.3 |
87.9 ± 0.1 |
51 |
11 |
70.9 ± 0.3 |
74.0 ± 0.1 |
76.0 ± 0.3 |
78.1 ± 0.2 |
87.5 ± 0.1 |
63 |
9 |
69.0 ± 0.4 |
71.7 ± 0.3 |
74.3 ± 0.5 |
76.7 ± 0.4 |
86.3 ± 0.1 |
85 |
7 |
65.9 ± 0.5 |
67.7 ± 1.0 |
70.6 ± 0.4 |
73.4 ± 0.7 |
84.5 ± 0.3 |
127 |
5 |
47.5 ± 0.4 |
55.2 ± 0.6 |
58.2 ± 1.1 |
67.5 ± 0.4 |
74.9 ± 2.3 |
4 бита |
72.0 ± 0.2 |
75.4 ± 0.2 |
77.3 ± 0.2 |
79.3 ± 0.3 |
87.7 ± 0.2 |
|
8 бит |
74.7 ± 0.1 |
77.6 ± 0.1 |
79.4 ± 0.1 |
80.8 ± 0.1 |
89.2 ± 0.1 |
|
float32 |
74.95 |
77.83 |
79.66 |
81.4 |
89.07 |
Можно видеть, что 8-битные модели демонстрируют почти ту же точность, что и вещественные, а вот 4-битные им сильно уступают. Для 4.6-битных моделей оказалось, что лучше всего распределять дискреты между активациями и весами равномерно: наилучшие результаты показывают схемы квантования от (15, 37) до (31, 17). Отметим, что точность классификации в этом диапазоне заметно лучше, чем у 4-битных моделей. Тем не менее, 4.6-битные сети все же немного уступили 8-битным и вещественным моделям, что говорит о том, что метод обучения можно придумать и получше.
ImageNet
Это уже стандартная выборка для оценки точности классифицирующих моделей (см. примеры на рис. 2).
Здесь мы использовали ResNet’ы: ResNet-18 (11.7М весов) и ResNet-34 (21.8М весов), предобученные модели из PyTorch's Torchvision. Все активации ReLU заменили на ReLU6, батч нормализацию интегрировали в свертки и отквантовали, но первый и последний слои также не трогали.
Калибровочная выборка была меньше: 25 батчей по 64 изображения для ResNet-18 и 5 батчей по 64 изображения для ResNet-34.
Точности классификации top-1 и top-5 для выборки ImageNet показаны в таблице 3. Лучшая пара параметров для 4.6-битных моделей оказалась (, ) = (23, 23). Также как и на CIFAR-10 8-битные модели имеют сравнимое качество с вещественными, а 4-битные сильно хуже. 4.6-битные сети оказались почти ровно посередине. Это ухудшение не такое большое и с ним уже можно работать!
Таблица 3. Точность классификации на ImageNet.
Quantization |
ResNet-18 |
ResNet-34 |
|||
top-1, % |
top-5, % |
top-1, % |
top-5, % |
||
29 |
19 |
65.6 ± 0.3 |
86.7 ± 0.1 |
68.6 ± 0.1 |
88.5 ± 0.1 |
25 |
21 |
65.9 ± 0.2 |
86.9 ± 0.1 |
68.8 ± 0.3 |
88.7 ± 0.2 |
23 |
23 |
66.1 ± 0.1 |
87.0 ± 0.1 |
69.1 ± 0.2 |
88.9 ± 0.1 |
4 бита |
64.2 ± 0.2 |
85.7 ± 0.1 |
66.1 ± 0.3 |
87.0 ± 0.2 |
|
8 бит |
68.3 ± 0.1 |
88.3 ± 0.1 |
71.4 ± 0.1 |
90.1 ± 0.1 |
|
float32 |
68.7 |
88.5 |
72.3 |
90.8 |
TCIA
Это выборка с изображениями МРТ головного мозга, изображения RGB, размер каждого 256 на 256. На этих изображениях ставится задача сегментации: нужно найти и отметить аномалии. Эта задача прекрасно решается с помощью модели U-Net с 7.76М параметров. Мы взяли предобученную модель тут. Она принимает на вход изображение и выдает бинарную маску, которая отмечает аномалии.
Также как и в ResNet’ах мы заменили активации ReLU на ReLU6 и отквантовали сеть. Размер калибровочной выборки был 5 батчей по 10 изображений. Для оценки качества сегментации использовались средние значения Dice и Intersection over Union (IoU) для тех изображений, где аномалии были. Также посчитали метрики качества бинарной классификации (есть аномалия или нет): accuracy, precision, recall, ошибку 1-го рода (false positive rate) и ошибку 2-го рода (false negative rate). Метрики приведены в таблице 4, а примеры работы на рис. 3.
Таблица 4. Качество сегментации моделей U-Net на выборке TCIA.
float32 |
8 бит |
4.6 бита |
4 бита |
|
Dice ↑ |
0.7643 |
0.7843 ± 0.0008 |
0.769 ± 0.006 |
0.746 ± 0.009 |
IoU ↑ |
0.6875 |
0.7046 ± 0.0008 |
0.688 ± 0.005 |
0.662 ± 0.009 |
Accuracy ↑ |
0.8119 |
0.8124 ± 0.0013 |
0.781 ± 0.013 |
0.57 ± 0.06 |
Precision ↑ |
0.6654 |
0.6624 ± 0.0016 |
0.623 ± 0.015 |
0.45 ± 0.04 |
Recall ↑ |
0.9286 |
0.9447 ± 0.0009 |
0.948 ± 0.004 |
0.969 ± 0.013 |
Type I ↓ |
0.1631 |
0.1682 ± 0.0011 |
0.201 ± 0.013 |
0.42 ± 0.07 |
Type II ↓ |
0.0249 |
0.0193 ± 0.0003 |
0.0182 ± 0.0013 |
0.011 ± 0.005 |
Можно видеть, что 8-битная сеть даже немного лучше вещественной: у нее чуть больше ошибка 1-го рода, но при этом выше качество сегментации. 4-битной сети откровенно плохо как по метрикам, так и визуально: добавился заметный шум. А 4.6-битная модель работает разумным образом и демонстрирует качество близкое к качеству 8-битной и вещественной моделей.
Обсуждение
Наши эксперименты ясно показали, что даже при использовании простейших методов обучения, не слишком отличающихся от подходов для 4- и 8-битных моделей, 4.6-битные сети работают заметно лучше 4-битных при практически такой же скорости работы. Это было ожидаемо, так как 4.6-битная схема квантования предлагает больше дискретов, чем 4-битная, а мы только проверили это теоретическое обстоятельство на практике.
На сегодняшний день в нише быстрых моделей для центральных процессоров основную роль играют 8-битные модели за счет сочетания факторов: заметное ускорение и малые усилия по конвертации к 8-битному формату. Именно их наиболее актуально сравнивать с 4.6-битными. Здесь 4.6-битные сети однозначно выиграли по скорости, продемонстрировав ускорение на 30-40%. При использовании простых в имплементации методов обучения они несколько проиграли 8-битным по качеству, однако можно взять чуть более сложную 4.6-битную модель, которая даст сходное качество, но будет все еще работать быстрее исходной. Поэтому на практике мы получаем ускорение без снижения качества, а также активно работаем над усовершенствованием метода обучения 4.6-битных сетей.
На самом деле здесь есть много простора для фантазии:
параметры , позволяют модифицировать схему квантования под разные задачи,
никто не говорит, что они должны быть одинаковые в разных слоях модели,
AdaQuant все-таки устарел, есть и другие методы, которые могут подойти лучше.
Поэтому мы считаем 4.6-битные сети крайне перспективными для практического использования и дальнейшего усовершенствования. Мы уже используем их в наших системах распознавания паспортов, распознавания документов, а также подали заявку на патентование этой технологии в РФ и США.