Недавняя статья об новой архитектуре нейронных сетей на основе теоремы Колмогорова-Арнольда (KAN — Kolmogorov-Arnold Networks) вызвала большой ажиотаж: уже было представлено множество вариаций того, как правильно создавать такие сети, ведутся горячие дебаты, а рабочая ли схема и имеет ли право на жизнь и многое другое. Цель этой статьи постараться ответить на простой вопрос: могут ли KAN справляться с компьютерным зрением?
Исходный код всех экспериментов данной статьи можете найти по ссылке.
Пока ещё не ушли далеко, сразу же скажем, что у этой статьи два автора: Иван Дрокин, который написал оригинал на английском языке, проводил эксперименты и написал собственную библиотеку свёрток для KAN — torch-conv-kan. За вольный перевод для читателей Хабра, многочисленные ревью и правки оригинала статьи, а также часть глупых мемов отвечал Антон Клочков — автор телеграм-канала MLE шатает Produnction (нет, в названии нет опечатки) про нейронные сети, мемы, языки программирования и рассуждения об общих топиках в IT.
Все рассуждения в статье ведутся от лица Ивана.
Основы
Давайте начнем с самого фундаментального — математики. Если говорить кратко, то многослойные перцептроны (MLP — Multi-Layer Perceptron) представляют собой нелинейную функцию от взвешенной суммы входных данных, тогда как сети Колмогорова-Арнольда представляют собой сумму нелинейных унарных функций от входных данных.
Работоспособность MLP основывается на теореме Цыбенко или более известной в широких кругах, как универсальной теореме апроксимации. Если не вдаваться в детали, то она утверждает, что любую функцию можно апроксимировать с любой точностью с помощью одного достаточно широкого скрытого слоя с нелинейной функцией активации.
С другой стороны, как уже было написано в начале статьи, работа KAN базируется на основе теоремы Колмогорова-Арнольда:
Авторы статьи "KAN: Kolmogorov-Arnold Networks" разработали новый тип архитектуры нейронных сетей: обучаемые активации на рёбрах и суммирование результатов в узлах. В противовес этому, в MLP в узлах применяется фиксированная нелинейная функция, а на ребрах производится линейная проекция входов.
Что же это за "обучаемые активации" и как их обучать? В оригинальной статье было предложено использовать B-сплайны (с классными визуализациями можно глянуть следующую статью). Сплайн-функция степени представляет собой кусочно-полиномиальную функцию степени . Места, где отрезки соединяются, называются узлами. Ключевое свойство сплайн-функций заключается в том, что они и их производные могут быть непрерывными, в зависимости от кратности узлов.
Плюсы и минусы KAN
Авторы выделяют следующие плюсы использование KAN относительно MLP:
Выше качество: KAN показали точность выше на большом множестве задач в сравнении с MLP. Первые могут более эффективно (в терминах качества) представлять сложные многомерные функции (спасибо теореме Колмогорова-Арнольда), что положительно сказывается на качестве.
Интерпретируемость: MLP — это чёрный ящик, сложно сказать, что происходит у них внутри, а KAN может предложить интересные возможности. Например, можно разложить сложную многомерную функцию на более простые компоненты, анализ которых может служить инсайтом о поведении модели, о работе на конкретных данных.
Гибкость и обобщаемость: за счёт обучаемости активаций можно лучше находить нелинейные зависимости в данных, что также ведёт к обобщаемости (но это не так просто).
Устойчивость к шумным данным и Adversarial Attacks: способность KAN улавливать более устойчивые представления данных с помощью адаптивных функций активации позволяют KAN быть более устойчивым к шумам и атакам.
Но есть и определённые сложности с KAN (No free lunch, помните?):
Чувствительность к гиперпараметрам: как и любая другая нейронная сеть, KAN чувствительна к множеству гиперпараметров, таких как learning rate, силу регуляризации, самой архитектуре. Проблема подбора правильных параметров остаётся актуальной и может в значительной степени влиять на сходимость. Тут стоит отметить, что на настоящий момент нет чётких рецептов того, как именно тюнить разного рода гиперпараметры (L1/L2 веса, параметры активаций, dropout коэффициенты и т.д.). Это чем-то напоминает времена, когда только начиналось развитие свёрточных нейронных сетей.
Высокие вычислительные затраты: адаптивные функции активации (например, B-сплайны) могут требовать больше ресурсов, чем классические MLP, что отражается на длительности и стоимости обучения и инференса;
Сложность моделей и масштабируемость: KAN масштабируемы в терминах гибкости модели, но при этом более глубокие сети ещё сильнее увеличивают вычислительные затраты. Масштабирование KAN на большие датасеты и сложные задачи, где вычислительная эффективность и интерпретируемость выходят на передний план, является пока сложной задачей.
Разновидности KAN
Некоторые из проблем, описанные выше, частично были решены в следующих модификациях.
В первую очередь, была представлена версия Fast KAN, в которой B-сплайны заменены радиальными базисными функциями (RBF). Это изменение помогает уменьшить вычислительные затраты на использование сплайнов.
Также появилось несколько Polynomial KAN: Лежандра, Чебышева, Якоби, Грама, Бернштейна, и Wavelet-based KAN. Про каждый из них можно узнать подробнее, перейдя по соответствующим ссылкам.
Свёртки для KAN
Для начала вспомним, что такое свёрточный слой? Наиболее распространенный тип свёрточного слоя — это слой двухмерной свёртки, обычно сокращаемый как Conv2D. В этом слое фильтр (или ядро) "скользит" по двухмерным входным данным (например, картинкам), выполняя поэлементное умножение. Результаты суммируются в одно число. Ядро выполняет ту же операцию для каждой позиции, по которой оно скользит, преобразуя двухмерную матрицу признаков в другую.
Хотя свёртки в одномерном и трёхмерном пространствах имеют тот же принцип, они используют разные ядра, размеры входных данных и выходных данных. Однако для упрощения мы сосредоточимся на двухмерной свёрточном слое. Если вы хотите углубиться в эту тему, прочитайте этот хороший пост.
Как правило, после слоя свёртки применяются слой нормализации (например, BatchNorm, InstanceNorm и т.д.) и нелинейные активации (ReLU, LeakyReLU, SiLU и т.д.).
Более формально: предположим, что у нас есть входное изображение размером . Для упрощения мы опустим ось канала (т.е. рассмотрим одноканальные изображения), она добавляет еще одно суммирование по оси канала. Итак, сначала нам нужно выполнить свертку с нашим ядром размером :
Затем применяем батч нормализацию и нелинейность, например, ReLU:
Свёрточный слой в KAN будет работать иначе: ядро будет содержать не обучаемые веса (конкретные числа), а унарные, обучаемые нелинейные функции. В этом случае ядро "скользит" по двухмерным входным данным, выполняя поэлементное применение функций активации из этого фильтра. Результатом каждого из применений будет число, которые потом суммируются.
Опять же, более формально: давайте применим к нашему входному изображению свёрточный слой в терминах теоремы Колмогорова-Арнольда:
Каждая является унарной нелинейной обучаемой функцией со своими обучаемыми параметрами. В оригинальной статье авторы предлагают использовать функцию следующего вида:
где и — обучаемые параметры, — B-сплайн. — это esidual activation functions, похожие на residual connctions из ResNet сетей. В оригинальной статье авторы предлагают выбрать в качестве функцию SiLU:
Как было уже упомянуто выше, недавно было предложено использовать вместо сплайнов полиномиальные функции или RBF.
Подытожим вышесказанное: в классических свёрточных слоях ядра содержат просто веса (числа), тогда как KAN-свёртки содержат унарные функции:
Эксперименты
Реализацию различных типов KAN-свёрток, моделей, датасетов, сетапов экспериментов и многое другое можете найти моём репозитории torch-conv-kan.
MNIST
Итак, давайте начнём эксперименты со всем знакомого MNISTа.
Бейзлайн модели — простая нейронная сеть, состоящая из четырёх свёрточных слоёв. Для уменьшения размерности во втором и третьем свёрточных слоях используется параметр dilation=2
.
Количество каналов в свёрточных слоях одинаково для всех моделей: 32, 64, 128, 256. После свёрток применяется Global Average Pooling, за которым следует линейный выходной слой. Кроме того, для регуляризиции используется Dropout слой: с параметром p = 0.25
в свёрточных слоях и p = 0.5
перед выходным слоем.
Пример реализации с использованием Pytorch и torch-conv-kan
import torch
import torch.nn as nn
from kan_convs import KANConv2DLayer
class SimpleConvKAN(nn.Module):
def __init__(
self,
layer_sizes,
num_classes: int = 10,
input_channels: int = 1,
spline_order: int = 3,
groups: int = 1):
super(SimpleConvKAN, self).__init__()
self.layers = nn.Sequential(
KANConv2DLayer(input_channels, layer_sizes[0], spline_order, kernel_size=3, groups=1, padding=1, stride=1,
dilation=1),
KANConv2DLayer(layer_sizes[0], layer_sizes[1], spline_order, kernel_size=3, groups=groups, padding=1,
stride=2, dilation=1),
KANConv2DLayer(layer_sizes[1], layer_sizes[2], spline_order, kernel_size=3, groups=groups, padding=1,
stride=2, dilation=1),
KANConv2DLayer(layer_sizes[2], layer_sizes[3], spline_order, kernel_size=3, groups=groups, padding=1,
stride=1, dilation=1),
nn.AdaptiveAvgPool2d((1, 1))
)
self.output = nn.Linear(layer_sizes[3], num_classes)
self.drop = nn.Dropout(p=0.25)
def forward(self, x):
x = self.layers(x)
x = torch.flatten(x, 1)
x = self.drop(x)
x = self.output(x)
return x
Заметьте, что в случае классических свёрточных слоев, структура слоёв была бы примерно следующей: Conv2D -> Batch Normalization -> ReLU
.
Для проведения экспериментов используются аугментации, которые вы можете посмотреть под катом.
Пример аугментаций с использованием torchvision
from torchvision.transforms import v2
transform_train = v2.Compose([
v2.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1)),
v2.ColorJitter(brightness=0.2, contrast=0.2),
v2.ToTensor(),
v2.Normalize((0.5,), (0.5,))
])
Кроме того, нам также необходимо исследовать влияние различных слоёв нормализации внутри свёрток KAN и влияние L1 регуляризации. В столбце Norm Layer
во всех таблицах указывается, какой слой нормализации использовался во время эксперимента, а в столбце Affine
указывается, был ли параметр affine
слоя нормализации BatchNorm2d
установлен как True
или False
.
Все эксперименты были выполнены с использованием NVIDIA RTX 3090 с идентичными параметрами.
Model |
Accuracy |
Parameters |
Eval Time, s |
Norm Layer |
Affine |
L1 |
---|---|---|---|---|---|---|
SimpleConv, 4 layers |
99.42 |
101066 |
0.7008 |
BatchNorm2D |
False |
0 |
SimpleKANConv, 4 layers |
96.80 |
3488814 |
2.8306 |
InstanceNorm2D |
False |
0 |
SimpleKANConv, 4 layers |
99.41 |
3489774 |
2.6362 |
InstanceNorm2D |
True |
0 |
SimpleKANConv, 4 layers |
99.00 |
3489774 |
2.6401 |
BatchNorm2D |
True |
0 |
SimpleKANConv, 4 layers |
98.89 |
3489774 |
2.4138 |
BatchNorm2D |
True |
1e-05 |
Модель на основе классических свёрток работает куда лучше нейронной сети с KAN-свёртами, которая к тому же имеет в 34 раза больше параметров и требует куда большего времени исполнения. Кажется, это не та "революция" в нейронных сетях, которую мы ожидали.
Давайте попробуем воспользоваться подходом Fast KAN и заменим сплайны на RBF:
Model |
Accuracy |
Parameters |
Eval Time, s |
Norm Layer |
Affine |
L1 |
SimpleConv, 4 layers |
99.42 |
101066 |
0.7008 |
BatchNorm2D |
True |
0 |
SimpleFastKANConv, 4 layers |
99.26 |
3488810 |
1.5636 |
InstanceNorm2D |
False |
0 |
SimpleFastKANConv, 4 layers |
99.01 |
3489260 |
1.7406 |
InstanceNorm2D |
True |
0 |
SimpleFastKANConv, 4 layers |
97.65 |
3489260 |
1.5999 |
BatchNorm2D |
True |
0 |
SimpleFastKANConv, 4 layers |
95.62 |
3489260 |
1.6158 |
BatchNorm2D |
True |
1e-05 |
По времени стало получше, но при этом сами результаты неоднозначные: где-то стало лучше, где-то хуже.
Тогда попробуем ещё одну функцию вместо сплайнов: дискретные полиномы Чебышева. Сам факт что полиномы дискретны в теории должно дать неплохие результаты для обработки дискретных данных — изображений и текстов. Давайте узнаем, так ли это:
Model |
Accuracy |
Parameters |
Eval Time, s |
Norm Layer |
Affine |
L1 |
SimpleConv, 4 layers |
99.42 |
101066 |
0.7008 |
BatchNorm2D |
True |
0 |
SimpleKAGNConv, 4 layers |
98.21 |
487866 |
1.8506 |
InstanceNorm2D |
False |
0 |
SimpleKAGNConv, 4 layers |
99.46 |
488826 |
1.8813 |
InstanceNorm2D |
True |
0 |
SimpleKAGNConv, 4 layers |
99.49 |
488826 |
1.7253 |
BatchNorm2D |
True |
0 |
SimpleKAGNConv, 4 layers |
99.44 |
488826 |
1.8979 |
BatchNorm2D |
True |
1e-05 |
Новые KAN-свёртки выдают качество чуть лучше традиционной модели, но работают в 2.5 раза медленнее и имеют почти в пять раз больше параметров. L1 регуляризация немного снижает производительность модели, но это область для дальнейших улучшений.
CIFAR 100
Бейзлайн модели здесь меняется. Теперь у нас будет восемь свёрточных слоёв . Для уменьшения размерности во втором, третьем и шестом свёрточных слоях используется параметр dilation=2
. Для каждого слоя число каналов равно соответвующе 16, 32, 64, 128, 256, 256, 512, 512. Всё остальное в модели остается без изменений.
Аугментации, которые используются в экспериментах с CIFAR 100 можете посмотреть под катом.
Пример аугментаций с использованием torchvision
from torchvision.transforms import v2
from torchvision.transforms.autoaugment import AutoAugmentPolicy
transform_train = v2.Compose([
v2.RandomHorizontalFlip(p=0.5),
v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
v2.AutoAugment(AutoAugmentPolicy.IMAGENET),
v2.AutoAugment(AutoAugmentPolicy.SVHN),
v2.TrivialAugmentWide()]),
v2.ToTensor(),
v2.Normalize((0.5,), (0.5,))
])
Model |
Accuracy |
Parameters |
Eval Time, s |
Norm Layer |
Affine |
L1 |
SimpleConv, 8 layers |
57.52 |
1187172 |
1.8265 |
BatchNorm2D |
True |
0 |
SimpleKAGNConv, 8 layers |
29.39 |
22655732 |
2.5358 |
InstanceNorm2D |
False |
0 |
SimpleKAGNConv, 8 layers |
48.56 |
22659284 |
2.0454 |
InstanceNorm2D |
True |
0 |
SimpleKAGNConv, 8 layers |
59.27 |
22659284 |
2.6460 |
BatchNorm2D |
True |
0 |
SimpleKAGNConv, 8 layers |
58.07 |
22659284 |
2.2583 |
BatchNorm2D |
True |
1e-05 |
KAN-свёртки на основе дискретных полиномов Чебышева показывают лучшее качество, хотя и с бОльшими временными затратами и существенно бОльшим количеством параметров (более чем в 20 раз больше). BatchNorm2D
, по-видимому, является лучшим вариантом для нормализации внутренних признаков в KAN-свёртках на дискретных полиномов Чебышева.
Использование последних кажется многообещающим для дальнейших экспериментов на ImageNet1k и других "реальных" наборах данных.
Заключение
Итак, могут ли KAN справляться с компьютерным зрением? Кажется, что да! Смогут ли они заменить классические CNN? Это еще предстоит выяснить.
MLP используется уже много лет и нуждается в обновлении. Мы уже видели подобные изменения. Например, шесть лет назад сети Long Short-Term Memory (LSTM), которые долгое время были основой для моделирования последовательностей, были заменены трансформерами в качестве стандартного строительного блока для архитектуры языковых моделей. Похожий сдвиг для MLP был бы интригующим.
Свёрточные нейронные сети, которые доминировали в течение многих лет (и до сих пор являются основой для компьютерного зрения), в конечном итоге были частично заменены визуальными трансформерами (ViT). Возможно, пришло время для нового лидера в этой области?
Однако прежде чем это произойдет, сообществу необходимо найти эффективные методы для обучения сетей Колмогорова-Арнольда, свёрточных сетей Колмогорова-Арнольда (ConvKAN) и ViT-KAN, а также решить проблемы, которые связаны с этими моделями (например, скорость инференса, количество параметров).
Хотя меня очень вдохновляет эта новая архитектура и начальные эксперименты показывают обнадеживающие результаты, я остаюсь несколько скептичен. Необходимы ещё эксперименты. Оставайтесь с нами, мы собираемся углубиться в эту тему.
imageman
А можно ли упростить вашу KAN реализацию с тем, что бы число параметров было сопоставимым с SimpleConv, 8 layers? Что там с переобучением? Может из-за большого числа параметров банально быстро уходит в переобучение? Сделайте тесты хотя бы для CIFAR 100.
brivang
Хороший поинт, спасибо за предложение, попробую провести эксперимент и приложить результаты. По поводу переобучения - сейчас пачка экспериментов гоняется по поводу оверфита и регуляризаций в целом - но это уже будет отдельная статья