В этой небольшой статье мы научим нейросеть решать задачу умножения перестановок длины 5 (группа ) и визуализируем результаты обучения с помощью методов проекции t-SNE (и понизим размерность PCA) и алгоритма UMAP. Мы убедимся в том, что даже элементарная модель может "неосознанно" провести бинарную классификацию перестановок. Однако с более тонкой задачей кластеризации по цикловой структуре модель будет испытывать затруднения.
Вступление
Этот эксперимент был вдохновлен моей недавной вышедшей статьей на Хабре про извлечение корня из перестановок (большое спасибо за комментарии и активность!). И мне стало любопытно узнать - может ли НС выделить такие признаки, чтобы линейно сепарировать перестановки одной цикловой группы от других?
Например, я помнил, что в работе Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets был дан пример такой классификации при далеком-далеком очень долгом обучении и огромном количестве параметров (модель проходит этап переобучения и входит в состояние обобщения, гроккинга) на задаче умножения перестановок в S5:
Получится ли у нас добиться такого же результата, только менее тяжеловесной нейросетью?
Модель
Используем простейшую нейронную сеть с двумя линейными слоями и слоем дропаута. Функция активации - ReLU. Под спойлером вы можете увидеть процесс обучения и создание визуализации.
Скрытый текст
from itertools import permutations
S5 = list(permutations(range(5)))
perm_to_idx = {perm: idx for idx, perm in enumerate(S5)}
idx_to_perm = {idx: perm for idx, perm in enumerate(S5)}
def perm_multiply(p1, p2):
return tuple(p1[p2[i]] for i in range(5))
train_data = []
for p1, p2 in itertools.product(S5, repeat=2):
product = perm_multiply(perm_multiply(p1, p2), p1)
train_data.append((perm_to_idx[p1], perm_to_idx[p2], perm_to_idx[product]))
train_data = np.array(train_data)
class PermutationNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(PermutationNet, self).__init__()
self.fc1 = nn.Linear(input_size * 2, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
self.dropout = nn.Dropout(0.3)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
input_size = len(S5)
hidden_size = 256
output_size = len(S5)
model = PermutationNet(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def one_hot_encode(index, size):
vec = np.zeros(size)
vec[index] = 1
return vec
X_train = np.array([np.concatenate([one_hot_encode(p1, input_size), one_hot_encode(p2, input_size)])
for p1, p2, _ in train_data])
y_train = np.array([product for _, _, product in train_data])
X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
epochs = 1000
for epoch in range(epochs):
optimizer.zero_grad()
outputs = model(X_train)
loss = criterion(outputs, y_train)
loss.backward()
optimizer.step()
if (epoch + 1) % 100 == 0:
print(f'Эпоха [{epoch+1}/{epochs}], Лосс: {loss.item():.4f}')
from sklearn.decomposition import PCA
import numpy as np
plt.figure(figsize=(12, 12))
with torch.no_grad():
last_layer_activations = model.fc2.weight.cpu().numpy()
pca = PCA(n_components=15)
pca_result = pca.fit_transform(last_layer_activations)
tsne = TSNE(n_components=2, random_state=40)
tsne_results = tsne.fit_transform(last_layer_activations)
for structure, color in color_map.items():
indices = [i for i, perm in enumerate(S5) if cycle_structure(perm) == structure]
plt.scatter(tsne_results[indices, 0], tsne_results[indices, 1], c=color, label=f'{structure}-цикл', s=100, alpha=0.7)
# for i in indices:
# perm = S5[i]
# perm_str = str([j + 1 for j in perm])
# plt.text(tsne_results[i, 0] + 0.03, tsne_results[i, 1] + 0.03, perm_str, fontsize=8)
plt.legend(title='Цикловые структуры', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title('t-SNE проекция на задаче предсказания p = xyx')
plt.show()
Хорошая новость
Почти с первых эпох модель улавливает важную закономерность - какие перестановки чётные, а какие нечётные. Определить четность перестановки (или четность количества инверсий) по цикловой структуре можно следующим образом: перестановка нечётна, если количество циклов четной длины нечётно.
По правую сторону от оранжевой линии лежат владения нечетных перестановок - (2,3)-циклы, (4)-циклы, (2)-циклы. По левую - четные. Более того, такое разграничение сохранится и при обучении на другую операцию .
Странная новость
На данном графике показаны только четные перестановки. Также на выходной слой был применен алгоритм снижения размерности до 3 (PCA) и только потом передан в t-SNE размерности 2. Аналогично для нечетных:
Можно ли обосновать эту закономерность математически? Или это обусловлено странной настройкой PCA и t-SNE? Ответы на эти вопросы автор не знает, поэтому милости прошу в комментарии. Скорее всего я упускаю нечто очевидное.
Печальная новость
Увы, не удалось достичь красивой картинки, как в вступлении. Сепарировать цикловые структуры одинаковой четности оказалось непосильной задачей.
Даже после 10000 эпох наблюдаем почти ту же (плачевную) картину:
Разграничить как-то цикловые структуры не представляется возможным, уж слишком они хаотично разбросаны.
Отступление
Модель теперь предсказывает класс перестановки. Используем два сверточных слоя GNN (GCNConv из torch_geometric) и два линейных слоя. Перестановки теперь кодируются как орграфы. За 500 эпох с теми же параметрами достигнем того же результата. Однако это не является успешным выполнением нашей задачи - модель должна неявно отразить закономерность.
Можно ли сделать вывод о том, что нейросеть из двух линейных слоев не способна на "усвоение" закономерности, что позволит разделить перестановки на цикловые группы при умножении? Я не могу сказать об этом наверняка, поэтому в этой сумбурной микростатье делюсь этими результатами. Будет интересно узнать Ваши мнения в комментариях.
ioleynikov
я не вижу сути проблемы. Логично рассматривать нейронную сеть как еще одно представление математической функции в ряд, аналогичный разложению Тейлора или Фурье. Кстати, для временных рядов очень эффективны представления в виде авторегрессионных моделей типа (LPC) это чистой воды полный аналог back propagation.