Эта статья про картинки и классификацию. Небольшое исследование свойств, такой вот штрих к портрету MNIST (ну и подсказка в решении других подобных задач).

В сети есть множество публикаций об интерпретации той или иной нейронной сети и значимости и вкладе тех или иных точек в обучение. Есть масса работ про поиск усов, хвостов и других частей и их важности и значимости. Не буду сейчас подменять библиотекарей и составлять список. Просто расскажу о своем эксперименте.

Началось всё с отличного видео Доклад «Как думают роботы. Интерпретация ML-моделей» , просмотренного по совету одного умного человека и как любое толковое дело, поставило множество вопросов. Например: — насколько ключевые точки датасета уникальны?

Или другой вопрос: — в сети множество статей о том, как изменив одну точку картинки можно существенно исказить предсказание сети. Напомню, что в статье рассматриваем только задачи классификации. А насколько такая коварная точка уникальна? А есть ли такие точки в естественной последовательности MNIST и если их найти и выкинуть станет ли точность обучения нейронной сети выше?

Автор, следуя своему традиционному методу избавления от всего лишнего, решил не мешать кучу и выбрал простой, надежный и эффективный способ исследования поставленных вопросов:

в качестве экспериментальной задачи, примера для препарирования, выбрать знакомый всем MNIST ( yann.lecun.com/exdb/mnist ) и его классификацию.

В качестве подопытной сети выбрал классическую, рекомендуемую начинающим, примерную сеть команды KERAS
github.com/keras-team/keras/blob/master/examples/mnist_cnn.py

А само исследование решил провести очень просто.

Обучим сеть из KERAS с таким критерием останова, как отсутствие повышения точности на тестовой последовательности, т.е. учим сеть до тех пор, пока test_accuracy не станет существенно больше validation_accuracy и validation_accuracy в течении 15 эпох не улучшается. Другими словами сеть перестала обучаться и началось переобучение.

Из датасета MNIST сделаем 324 новых датасета путем отбрасывания групп точек и будем учить той же сетью на точно тех же условиях с теми же начальными весами.

Приступим, считаю правильным и верным выкладывать весь код, от первой до последней строчки. Даже если читатели видели его, очевидно, много раз.

Загружаем библиотеки и загружаем датасет mnist, если он еще не загружен.

Далее переводим его в формат 'float32' и нормируем в диапазон 0. — 1.

Подготовка закончена.

'''Trains a simple convnet on the MNIST dataset.
Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''

from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras.optimizers import *
from keras.callbacks import EarlyStopping

import numpy as np
import os

num_classes = 10

# input image dimensions
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= np.max(x_train)
x_test /= np.max(x_test)

XX_test = np.copy(x_test)
XX_train = np.copy(x_train)
YY_test = np.copy(y_test)
YY_train = np.copy(y_train)


print('x_train shape:', XX_train.shape)
print('x_test shape:', XX_test.shape)


Запомним в переменных название файлов модели и весов, также accuracy и loss нашей сети. Этого нет в исходном коде, но для эксперимента необходимо.

f_model = "./data/mnist_cnn_model.h5"
f_weights = "./data/mnist_cnn_weights.h5"
accu_f = 'accuracy'
loss_f = 'binary_crossentropy'

Сама сеть — точь в точь как на сайте
github.com/keras-team/keras/blob/master/examples/mnist_cnn.py.

Сохраняем сеть и весы на диск. Все наши попытки обучения будем запускать с одинаковыми начальными весами:

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()

model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=[loss_f], optimizer=Adam(lr=1e-4), metrics=[accu_f])
model.summary()

model.save_weights(f_weights)
model.save(f_model)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 24, 24, 64)        18496     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 12, 12, 64)        0         
_________________________________________________________________
dropout (Dropout)            (None, 12, 12, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 9216)              0         
_________________________________________________________________
dense (Dense)                (None, 128)               1179776   
_________________________________________________________________
dropout_1 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
=================================================================
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
_________________________________________________________________

Запустим обучение на исходном mnist, что бы получить ориентир, базовую эффективность.

x_test = np.copy(XX_test)
x_train = np.copy(XX_train)
s0 = 0

if os.path.isfile(f_model):
    model = load_model(f_model)
    model.load_weights(f_weights, by_name=False)

    step = 0
    while True:
        
        fit = model.fit(x_train, y_train,
                  batch_size=batch_size,
                  epochs=1,
                  verbose=0,
                  validation_data=(x_test, y_test)
                )
        
        current_accu = fit.history[accu_f][0]
        current_loss = fit.history['loss'][0]
        val_accu = fit.history['val_'+accu_f][0]
        val_loss = fit.history['val_loss'][0]
        print("\x1b[2K","accuracy {0:12.10f} loss {1:12.10f} step {2:5d} val_accu {3:12.10f} val_loss {4:12.10f}  ".                          format(current_accu, current_loss, step, val_accu, val_loss), end="\r")
    
        step += 1
        if val_accu > max_accu:
            s0 = 0
            max_accu = val_accu
        else:
            s0 += 1
        if current_accu * 0.995 > val_accu and s0 > 15:
            break
else:
    print("model not found ")

accuracy 0.9967333078 loss 0.0019656278 step   405 val_accu 0.9916999936 val_loss 0.0054226643  

Теперь начнем основной эксперимент. Мы берем для обучения из исходной поледовательности все 60000 размеченных картинок, и в них обнуляем всё, кроме квадрата 9х9. Получим 324 экспериментальных поледовательности и сравним результат обучения сети на них с обучением на оригинальной поледовательности. Обучаем ту же сеть с теми же начальными весами.

batch_size = 5000
s0 = 0
max_accu = 0.

for i in range(28 - 9):
    for j in range(28 - 9):
        print("\ni= ", i, "  j= ",j)
        x_test = np.copy(XX_test)
        x_train = np.copy(XX_train)

        x_train[:,:i,:j,:] = 0.
        x_test [:,:i,:j,:] = 0.

        x_train[:,i+9:,j+9:,:] = 0.
        x_test [:,i+9:,j+9:,:] = 0.

        if os.path.isfile(f_model):
            model = load_model(f_model)
            model.load_weights(f_weights, by_name=False)
        else:
            print("model not found ")
            break

        step = 0
        while True:
            
            fit = model.fit(x_train, y_train,
                      batch_size=batch_size,
                      epochs=1,
                      verbose=0,
                      validation_data=(x_test, y_test)
                    )
            
            current_accu = fit.history[accu_f][0]
            current_loss = fit.history['loss'][0]
            val_accu = fit.history['val_'+accu_f][0]
            val_loss = fit.history['val_loss'][0]
            print("\x1b[2K","accuracy {0:12.10f} loss {1:12.10f} step {2:5d} val_accu {3:12.10f} val_loss {4:12.10f}  ".   format(current_accu, current_loss, step, val_accu, val_loss), end="\r")
        
            step += 1
            if val_accu > max_accu:
                s0 = 0
                max_accu = val_accu
            else:
                s0 += 1
            if current_accu * 0.995 > val_accu and s0 > 15:
                break

Нет смысла выкладывать тут все 324 результата, если кому интересно, то могу выслать персонально. Расчет занимает несколько суток, это если кто хочет повторить.

Как оказалось сеть на обрезке 9х9 может обучаться как хуже, что очевидно, но и лучше, что совсем не очевидно.

Например:

i= 0 j= 14
accuracy 0.9972333312 loss 0.0017946947 step 450 val_accu 0.9922000170 val_loss 0.0054322388

i= 18 j= 1
accuracy 0.9973166585 loss 0.0019487827 step 415 val_accu 0.9922000170 val_loss 0.0053000450

Мы выбрасываем из картинок с рукописными цифрами всё, кроме квадрата 9х9 и качество обучения и распознавания у нас улучшается!

Так же явно видно, что такая особая область, повышающая качество сети, не одна. И не две, это две приведены как пример.

Итог этого эксперимента и предварительные выводы.

  • Любой естественный датасет, не думаю, что ЛеКун специально искажал что-то, содержит не только точки существенные для обучения, но и точки мешающие обучению. Задача поиска «вредных» точек становится актуальной, они есть, даже если их не видно.
  • Можно проводить стекинг и блендинг не только вдоль датасета, выбирая картинки группами, но и поперек, выбирая области картинок для разбиения и далее как обычно. Такой подход в данном случае повышает качество обучения и есть надежда, что в аналогичной задаче применение такого стекинга поперек позволит прибавить в качестве. А на том же kaggle.com несколько десятитысячных иногда (почти всегда ) позволяют существенно поднять свой авторитет и рейтинг.

Спасибо за внимание.