В этой небольшой статье мы научим нейросеть решать задачу умножения перестановок длины 5 (группа S_5) и визуализируем результаты обучения с помощью методов проекции t-SNE (и понизим размерность PCA) и алгоритма UMAP. Мы убедимся в том, что даже элементарная модель может "неосознанно" провести бинарную классификацию перестановок. Однако с более тонкой задачей кластеризации по цикловой структуре модель будет испытывать затруднения.

Вступление

Этот эксперимент был вдохновлен моей недавной вышедшей статьей на Хабре про извлечение корня из перестановок (большое спасибо за комментарии и активность!). И мне стало любопытно узнать - может ли НС выделить такие признаки, чтобы линейно сепарировать перестановки одной цикловой группы от других?

Например, я помнил, что в работе Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets был дан пример такой классификации при далеком-далеком очень долгом обучении и огромном количестве параметров (модель проходит этап переобучения и входит в состояние обобщения, гроккинга) на задаче умножения перестановок в S5:

Каждому цвету соответствует своя цикловая структура. Красота.
Каждому цвету соответствует своя цикловая структура. Красота.

Получится ли у нас добиться такого же результата, только менее тяжеловесной нейросетью?

Модель

Используем простейшую нейронную сеть с двумя линейными слоями и слоем дропаута. Функция активации - ReLU. Под спойлером вы можете увидеть процесс обучения и создание визуализации.

Скрытый текст
from itertools import permutations

S5 = list(permutations(range(5)))

perm_to_idx = {perm: idx for idx, perm in enumerate(S5)}
idx_to_perm = {idx: perm for idx, perm in enumerate(S5)}

def perm_multiply(p1, p2):
    return tuple(p1[p2[i]] for i in range(5))

train_data = []
for p1, p2 in itertools.product(S5, repeat=2):
    product = perm_multiply(perm_multiply(p1, p2), p1)
    train_data.append((perm_to_idx[p1], perm_to_idx[p2], perm_to_idx[product]))

train_data = np.array(train_data)


class PermutationNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(PermutationNet, self).__init__()
        self.fc1 = nn.Linear(input_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

input_size = len(S5)
hidden_size = 256
output_size = len(S5)

model = PermutationNet(input_size, hidden_size, output_size)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def one_hot_encode(index, size):
    vec = np.zeros(size)
    vec[index] = 1
    return vec


X_train = np.array([np.concatenate([one_hot_encode(p1, input_size), one_hot_encode(p2, input_size)])
                    for p1, p2, _ in train_data])
y_train = np.array([product for _, _, product in train_data])

X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)

epochs = 1000
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Эпоха [{epoch+1}/{epochs}], Лосс: {loss.item():.4f}')



from sklearn.decomposition import PCA
import numpy as np

plt.figure(figsize=(12, 12))

with torch.no_grad():
    last_layer_activations = model.fc2.weight.cpu().numpy()

pca = PCA(n_components=15)
pca_result = pca.fit_transform(last_layer_activations)

tsne = TSNE(n_components=2, random_state=40)
tsne_results = tsne.fit_transform(last_layer_activations)

for structure, color in color_map.items():
    indices = [i for i, perm in enumerate(S5) if cycle_structure(perm) == structure]
    plt.scatter(tsne_results[indices, 0], tsne_results[indices, 1], c=color, label=f'{structure}-цикл', s=100, alpha=0.7)

    # for i in indices:
    #     perm = S5[i]
    #     perm_str = str([j + 1 for j in perm])
    #     plt.text(tsne_results[i, 0] + 0.03, tsne_results[i, 1] + 0.03, perm_str, fontsize=8)


plt.legend(title='Цикловые структуры', bbox_to_anchor=(1.05, 1), loc='upper left')

plt.title('t-SNE проекция на задаче предсказания p = xyx')
plt.show()

Хорошая новость

Обучаем операции умножения перестановок в S5. p = xy
Обучаем операции умножения перестановок в S5. p = xy

Почти с первых эпох модель улавливает важную закономерность - какие перестановки чётные, а какие нечётные. Определить четность перестановки (или четность количества инверсий) по цикловой структуре можно следующим образом: перестановка нечётна, если количество циклов четной длины нечётно.

По правую сторону от оранжевой линии лежат владения нечетных перестановок - (2,3)-циклы, (4)-циклы, (2)-циклы. По левую - четные. Более того, такое разграничение сохранится и при обучении на другую операцию p = xyx.

Обучаем НС предсказанию результата выражения xyx
Обучаем НС предсказанию результата выражения xyx

Странная новость

И снова можно провести линию. Но чем она обоснована?
И снова можно провести линию. Но чем она обоснована?

На данном графике показаны только четные перестановки. Также на выходной слой был применен алгоритм снижения размерности до 3 (PCA) и только потом передан в t-SNE размерности 2. Аналогично для нечетных:

Аналогично для нечетных
Аналогично для нечетных

Можно ли обосновать эту закономерность математически? Или это обусловлено странной настройкой PCA и t-SNE? Ответы на эти вопросы автор не знает, поэтому милости прошу в комментарии. Скорее всего я упускаю нечто очевидное.

Печальная новость

Увы, не удалось достичь красивой картинки, как в вступлении. Сепарировать цикловые структуры одинаковой четности оказалось непосильной задачей.

UMAP. Четная сепарация остается, но желаемого результата нет.
UMAP. Четная сепарация остается, но желаемого результата нет.

Даже после 10000 эпох наблюдаем почти ту же (плачевную) картину:

Вытянулись вдоль линии
Вытянулись вдоль линии

Разграничить как-то цикловые структуры не представляется возможным, уж слишком они хаотично разбросаны.

Отступление

С помощью GNN удалось получить желаемый результат
С помощью GNN удалось получить желаемый результат

Модель теперь предсказывает класс перестановки. Используем два сверточных слоя GNN (GCNConv из torch_geometric) и два линейных слоя. Перестановки теперь кодируются как орграфы. За 500 эпох с теми же параметрами достигнем того же результата. Однако это не является успешным выполнением нашей задачи - модель должна неявно отразить закономерность.

Можно ли сделать вывод о том, что нейросеть из двух линейных слоев не способна на "усвоение" закономерности, что позволит разделить перестановки на цикловые группы при умножении? Я не могу сказать об этом наверняка, поэтому в этой сумбурной микростатье делюсь этими результатами. Будет интересно узнать Ваши мнения в комментариях.

Комментарии (2)


  1. ioleynikov
    21.10.2024 10:30

    я не вижу сути проблемы. Логично рассматривать нейронную сеть как еще одно представление математической функции в ряд, аналогичный разложению Тейлора или Фурье. Кстати, для временных рядов очень эффективны представления в виде авторегрессионных моделей типа (LPC) это чистой воды полный аналог back propagation.


  1. Limpo444
    21.10.2024 10:30

    Ебать как все понятно (нет)