Тема заинтересовала и руки зачесались проверить, а как это устроено у
Была выбрана простая сеть из примеров Keras в которую добавил одну строку. Нас интересует насколько упорядоченность входной обучающей последовательности mnist влияет на результат обучения MLP.
Результат получился неожиданным и странным, пришлось перепроверять многократно, но перейдем к делу и конкретике.
Идея эксперимента проста и обычна — обучаем MLP из keras на общедоступном mnist и получаем ориентир, после обучаем на последовательностях 01234567890123..7890123. Как студентов учат — немного бейсик, немного ассемблер, немного fortran, и т.д. и сравним с исходным обучением. Результат вполне ожидаем, исходная последовательность учит лучше, но порядок такой же. Вот график 64 испытаний
А теперь будем учить сеть так, подаем все картинки с «0», потом все «1», потом все «2» и так до «9» и результат получается никакой!, сеть просто не учится. Интуитивно ожидаешь результат сопоставимый, хуже или лучше — это уже детали, но вот таблица результатов обучения 64 раза
('Test accuracy:', 0.9708)
('Test accuracy:', 0.97689999999999999)
('Test accuracy:', 0.1009)
step 1
('Test accuracy:', 0.97689999999999999)
('Test accuracy:', 0.97219999999999995)
('Test accuracy:', 0.1009)
step 2
('Test accuracy:', 0.97330000000000005)
('Test accuracy:', 0.97609999999999997)
('Test accuracy:', 0.1028)
step 3
('Test accuracy:', 0.97040000000000004)
('Test accuracy:', 0.97160000000000002)
('Test accuracy:', 0.1135)
step 4
('Test accuracy:', 0.97370000000000001)
('Test accuracy:', 0.97050000000000003)
('Test accuracy:', 0.098199999999999996)
step 5
('Test accuracy:', 0.96999999999999997)
('Test accuracy:', 0.96909999999999996)
('Test accuracy:', 0.1009)
step 6
('Test accuracy:', 0.97589999999999999)
('Test accuracy:', 0.97540000000000004)
('Test accuracy:', 0.1028)
step 7
('Test accuracy:', 0.97360000000000002)
('Test accuracy:', 0.97350000000000003)
('Test accuracy:', 0.1135)
step 8
('Test accuracy:', 0.97740000000000005)
('Test accuracy:', 0.97109999999999996)
('Test accuracy:', 0.1135)
step 9
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.97089999999999999)
('Test accuracy:', 0.1135)
step 10
('Test accuracy:', 0.96930000000000005)
('Test accuracy:', 0.9708)
('Test accuracy:', 0.1028)
step 11
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.97099999999999997)
('Test accuracy:', 0.1135)
step 12
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.97519999999999996)
('Test accuracy:', 0.1009)
step 13
('Test accuracy:', 0.97719999999999996)
('Test accuracy:', 0.97370000000000001)
('Test accuracy:', 0.1135)
step 14
('Test accuracy:', 0.97489999999999999)
('Test accuracy:', 0.97189999999999999)
('Test accuracy:', 0.1135)
step 15
('Test accuracy:', 0.9758)
('Test accuracy:', 0.97219999999999995)
('Test accuracy:', 0.10489999999999999)
step 16
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.97529999999999994)
('Test accuracy:', 0.1135)
step 17
('Test accuracy:', 0.97819999999999996)
('Test accuracy:', 0.97170000000000001)
('Test accuracy:', 0.1009)
step 18
('Test accuracy:', 0.97850000000000004)
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.1009)
step 19
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.97589999999999999)
('Test accuracy:', 0.0974)
step 20
('Test accuracy:', 0.97699999999999998)
('Test accuracy:', 0.97319999999999995)
('Test accuracy:', 0.1135)
step 21
('Test accuracy:', 0.97309999999999997)
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.1009)
step 22
('Test accuracy:', 0.97560000000000002)
('Test accuracy:', 0.97519999999999996)
('Test accuracy:', 0.1135)
step 23
('Test accuracy:', 0.97619999999999996)
('Test accuracy:', 0.97450000000000003)
('Test accuracy:', 0.1009)
step 24
('Test accuracy:', 0.97689999999999999)
('Test accuracy:', 0.97430000000000005)
('Test accuracy:', 0.1028)
step 25
('Test accuracy:', 0.97609999999999997)
('Test accuracy:', 0.97599999999999998)
('Test accuracy:', 0.1135)
step 26
('Test accuracy:', 0.97840000000000005)
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.1028)
step 27
('Test accuracy:', 0.96909999999999996)
('Test accuracy:', 0.97019999999999995)
('Test accuracy:', 0.1135)
step 28
('Test accuracy:', 0.9738)
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.1009)
step 29
('Test accuracy:', 0.97460000000000002)
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.1135)
step 30
('Test accuracy:', 0.97640000000000005)
('Test accuracy:', 0.97170000000000001)
('Test accuracy:', 0.1042)
step 31
('Test accuracy:', 0.97409999999999997)
('Test accuracy:', 0.95650000000000002)
('Test accuracy:', 0.089200000000000002)
step 32
('Test accuracy:', 0.97689999999999999)
('Test accuracy:', 0.97109999999999996)
('Test accuracy:', 0.1135)
step 33
('Test accuracy:', 0.97370000000000001)
('Test accuracy:', 0.97340000000000004)
('Test accuracy:', 0.1009)
step 34
('Test accuracy:', 0.97699999999999998)
('Test accuracy:', 0.97150000000000003)
('Test accuracy:', 0.1135)
step 35
('Test accuracy:', 0.97250000000000003)
('Test accuracy:', 0.97140000000000004)
('Test accuracy:', 0.1009)
step 36
('Test accuracy:', 0.97589999999999999)
('Test accuracy:', 0.96950000000000003)
('Test accuracy:', 0.1055)
step 37
('Test accuracy:', 0.97519999999999996)
('Test accuracy:', 0.96509999999999996)
('Test accuracy:', 0.1135)
step 38
('Test accuracy:', 0.97299999999999998)
('Test accuracy:', 0.9728)
('Test accuracy:', 0.1028)
step 39
('Test accuracy:', 0.96909999999999996)
('Test accuracy:', 0.97240000000000004)
('Test accuracy:', 0.1009)
step 40
('Test accuracy:', 0.97399999999999998)
('Test accuracy:', 0.96479999999999999)
('Test accuracy:', 0.1135)
step 41
('Test accuracy:', 0.97799999999999998)
('Test accuracy:', 0.97319999999999995)
('Test accuracy:', 0.1135)
step 42
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.96340000000000003)
('Test accuracy:', 0.1009)
step 43
('Test accuracy:', 0.97740000000000005)
('Test accuracy:', 0.97170000000000001)
('Test accuracy:', 0.1009)
step 44
('Test accuracy:', 0.97160000000000002)
('Test accuracy:', 0.97389999999999999)
('Test accuracy:', 0.1135)
step 45
('Test accuracy:', 0.97599999999999998)
('Test accuracy:', 0.97360000000000002)
('Test accuracy:', 0.1033)
step 46
('Test accuracy:', 0.97389999999999999)
('Test accuracy:', 0.97019999999999995)
('Test accuracy:', 0.1135)
step 47
('Test accuracy:', 0.97650000000000003)
('Test accuracy:', 0.97619999999999996)
('Test accuracy:', 0.10290000000000001)
step 48
('Test accuracy:', 0.97409999999999997)
('Test accuracy:', 0.9647)
('Test accuracy:', 0.1009)
step 49
('Test accuracy:', 0.97240000000000004)
('Test accuracy:', 0.97450000000000003)
('Test accuracy:', 0.1135)
step 50
('Test accuracy:', 0.97570000000000001)
('Test accuracy:', 0.97040000000000004)
('Test accuracy:', 0.1135)
step 51
('Test accuracy:', 0.97250000000000003)
('Test accuracy:', 0.97219999999999995)
('Test accuracy:', 0.1135)
step 52
('Test accuracy:', 0.97230000000000005)
('Test accuracy:', 0.97309999999999997)
('Test accuracy:', 0.1135)
step 53
('Test accuracy:', 0.9758)
('Test accuracy:', 0.97230000000000005)
('Test accuracy:', 0.1135)
step 54
('Test accuracy:', 0.97770000000000001)
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.089200000000000002)
step 55
('Test accuracy:', 0.97340000000000004)
('Test accuracy:', 0.96919999999999995)
('Test accuracy:', 0.1135)
step 56
('Test accuracy:', 0.97170000000000001)
('Test accuracy:', 0.97070000000000001)
('Test accuracy:', 0.1028)
step 57
('Test accuracy:', 0.97670000000000001)
('Test accuracy:', 0.97330000000000005)
('Test accuracy:', 0.1135)
step 58
('Test accuracy:', 0.97589999999999999)
('Test accuracy:', 0.97370000000000001)
('Test accuracy:', 0.1033)
step 59
('Test accuracy:', 0.9748)
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.10290000000000001)
step 60
('Test accuracy:', 0.97409999999999997)
('Test accuracy:', 0.97099999999999997)
('Test accuracy:', 0.1009)
step 61
('Test accuracy:', 0.9758)
('Test accuracy:', 0.97450000000000003)
('Test accuracy:', 0.1135)
step 62
('Test accuracy:', 0.97529999999999994)
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.1028)
step 63
('Test accuracy:', 0.97240000000000004)
('Test accuracy:', 0.96809999999999996)
('Test accuracy:', 0.1135)
Не знаю как у людей, но тут чисто только MLP и получается, что обучать её можно не абы как, не на всех последовательностях.
Обучать ИИ так, как обучают в некоторых местах: сначала только бейсик, потом только фортран, потом только ассемблер и т.д. не приведет к успеху. / :-) / Если выявленная особенность присуща всем процессам обучения, как людям так и роботам, то все программы вузов нужно тщательно исследовать.
from keras import backend as K_B
from keras.datasets import mnist
from keras.layers import Input, Dense, Dropout
from keras.models import Sequential
from keras.optimizers import RMSprop
from keras.utils import np_utils
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
def MLP(ind):
model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(width * height,)))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy'])
if (ind == 0): # начальные веса в каждой тройке испытаний должны быть одинаковыми
model.save_weights('weights.h5') # натив, сохраняем веса
else:
model.load_weights('weights.h5', by_name = False) # эксперимент, восстанавливаем веса
history = model.fit(X_train, Y_train,
shuffle = False, # добавлена эта строка что бы запретить keras перемешивать батч
batch_size=batch_size,
epochs=epochs,
verbose=0,
validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, verbose=0)
print('Test accuracy:', score[1])
K_B.clear_session()
# сессию уничтожаем в возрождаем преднамеренно, что бы начальные веса отличались
return(score[1])
batch_size = 12
epochs = 12
hidden_size = 512
(X_train, y_train), (X_test, y_test) = mnist.load_data()
num_train, width, height = X_train.shape
num_test = X_test.shape[0]
num_classes = np.unique(y_train).shape[0]
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255.
X_test /= 255.
X_train = X_train.reshape(num_train, height * width)
X_test = X_test.reshape(num_test, height * width)
XX_train = np.copy(X_train)
yy_train = np.copy(y_train)
Y_train = np_utils.to_categorical(y_train, num_classes)
Y_test = np_utils.to_categorical(y_test, num_classes)
steps = 64
st = np.arange(steps, dtype='int')
res_N = np.arange((steps), dtype='float')
res_1 = np.arange((steps), dtype='float')
res_2 = np.arange((steps), dtype='float')
for n in xrange(steps):
# __ натив
X_train = np.copy(XX_train)
y_train = np.copy(yy_train)
Y_train = np_utils.to_categorical(y_train, num_classes)
print ' step ', n
res_N[n] = MLP(0)
# __ 00..0011..1122..2233.. .. 8899..99
perm = np.arange(num_train, dtype='int')
cl = np.zeros(num_classes, dtype='int')
for k in xrange(num_train):
if (cl[yy_train[k]] * num_classes + yy_train[k] < num_train):
perm[ cl[yy_train[k]] * num_classes + yy_train[k] ] = k
cl[yy_train[k]] += 1
for k in xrange(num_train):
X_train[k,...] = XX_train[perm[k],...]
for k in xrange(num_train):
y_train[k] = yy_train[perm[k]]
Y_train = np_utils.to_categorical(y_train, num_classes)
res_2[n] = MLP(2)
# __ 0123..78901..7890123..789
perm = np.arange(num_train, dtype='int')
j = 0
for k in xrange(num_classes):
for i in xrange(num_train):
if (yy_train[i] == k):
perm[j] = i
j += 1
for k in xrange(num_train):
X_train[k,...] = XX_train[perm[k],...]
y_train[k] = yy_train[perm[k]]
Y_train = np_utils.to_categorical(y_train, num_classes)
res_1[n] = MLP(1)
Другие типы сетей не проверял и эта проверка занимает на моей не очень теслистой тесле много часов.
Комментарии (8)
SADKO
25.01.2018 13:53Ничего удивительного, посмотрите внимательно как происходит обучение MLP, а обучать их можно по разному НО, когда примеры не перемешаны, а «потенциал коррекции весов» (назовём это так) убывает, получается что под первый пример она ещё как-то затачивается, а под последующие уже никак, приехали, либо кишка тонка, либо взаимный выпил в зависимости от того как будем учить.
Вообще тема вроде-бы баянистая в популярной литературе не раз перетёртая.fedorro
26.01.2018 12:58Скорее наоборот, после всего скопа примеров последнего класса веса затачиваются на этот класс, затирая результаты тренировок на остальных предыдущих классах. Проверить просто — оставить в тестовом множестве только первый класс (нули) или только последний (девятки), и посмотреть какие будут ошибки.
ChePeter Автор
26.01.2018 15:18Если оставить только последние «9» (5949 шт) MLP учится и результат
step 0
('Test accuracy:', 0.97360000000000002)
('Test accuracy:', 0.9758)
('Test accuracy:', 0.77300000000000002)
step 1
('Test accuracy:', 0.97460000000000002)
('Test accuracy:', 0.97509999999999997)
('Test accuracy:', 0.88370000000000004)
step 2
('Test accuracy:', 0.97550000000000003)
('Test accuracy:', 0.97299999999999998)
('Test accuracy:', 0.83389999999999997)
Эта последовательность «9» в конце приводит к переобучению, но предыдущие результаты не затираются, когда они есть.
Статья, в том числе, о том, что на одинаковой последовательности MLP не учится совсем.fedorro
26.01.2018 16:16но предыдущие результаты не затираются, когда они есть.
Ваш эксперимент как раз показывает, что это не так.
Test accuracy как раз и равен около 0.1, т.к. на девятках, которые были последними, тестовые примеры проходят успешно, и девятки как раз составляют 1/10 от общего числа тестовых примеров, с учетом равномерного распределения. И если в тестовом наборе (x_test, y_test) оставить только 9ки, то Test accuracy будет стремится к единице, даже при упорядочивании примеров, и к нули, если из тестового множества убрать все девятки.
ChePeter Автор
27.01.2018 09:29Если оставить только последние «9» (5949 шт) в интересующем варианте (0..01..12..23...78..89..9) т.е. последнюю последовательность «9», а начало вернуть, как в исходной последовательности, тогда получится результат выше.
PS. Извиняюсь за телеграфный стиль
ChePeter Автор
26.01.2018 17:50Проверка гипотезыfrom keras import backend as K_B from keras.datasets import mnist from keras.layers import Input, Dense, Dropout from keras.models import Sequential from keras.optimizers import RMSprop from keras.utils import np_utils import numpy as np import matplotlib.pyplot as plt %matplotlib inline batch_size = 12 epochs = 12 hidden_size = 512 (X_train, y_train), (X_test, y_test) = mnist.load_data() num_train, width, height = X_train.shape num_test = X_test.shape[0] num_classes = np.unique(y_train).shape[0] X_train = X_train.astype('float32') X_test = X_test.astype('float32') X_train /= 255. X_test /= 255. X_train = X_train.reshape(num_train, height * width) X_test = X_test.reshape(num_test, height * width) XX_train = np.copy(X_train) yy_train = np.copy(y_train) XX_test = np.copy(X_test) yy_test = np.copy(y_test) perm = np.arange(num_train, dtype='int') j = 0 for k in xrange(num_classes): for i in xrange(num_train): if (yy_train[i] == k): perm[j] = i j += 1 for k in xrange(num_train): X_train[k,...] = XX_train[perm[k],...] y_train[k] = yy_train[perm[k]] Y_train = np_utils.to_categorical(y_train, num_classes) Y_test = np_utils.to_categorical(y_test, num_classes) model = Sequential() model.add(Dense(512, activation='relu', input_shape=(width * height,))) model.add(Dropout(0.2)) model.add(Dense(512, activation='relu')) model.add(Dropout(0.2)) model.add(Dense(num_classes, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer=RMSprop(), # RMSprop(),'adam' metrics=['accuracy']) history = model.fit(X_train, Y_train, shuffle = False, batch_size=batch_size, epochs=epochs, verbose=2, validation_data=(X_test, Y_test)) for n in xrange(10): j = 0 i = 0 for k in xrange(num_test): if (yy_test[k] >= 8 ): X_test[j,...] = XX_test[k,...] y_test[j] = yy_test[k] else: X_test[j,...] = X_test[i,...] y_test[j] = y_test[i] i += 1 j += 1 Y_test = np_utils.to_categorical(y_test, num_classes) score = model.evaluate(X_test, Y_test, verbose=0) print ' remove ', 9 - n print('Test accuracy:', score[1])
Hedgehogues
Думаю, что этот эффект связан с тем, что сеть затачивается под одну (или несколько цифр) последних цифр, если подавать ей примеры все скопом: сначала все "1", затем все "2" и т.д. Связано это с тем, что у нейронов нет памяти и проблемой затухающего градиента. Если хочется, чтобы сеть корректно работала на всех последовательностях, то, вероятно, нужно смотреть в сторону LSTM. Вопрос только зачем. Разве что как исследование.
San_tit
Не думаю, что LSTM тут поможет: проблема не сколько в архитектуре сети, сколько в принципе обучения. Условно, сначала нейроны скатываются по максимальному градиенту для одной цифры, а потом для другой. При этом, если цифры идут блоками (сначала 1, потом 2 итд), то пересечения множеств нейронов, используемых для распознавания каждой цифры будут значительно больше, чем при случайном порядке.