Для создания изображений с помощью GAN я буду использовать Tensorflow.
Генеративно-состязательная сеть (GAN) — это модель машинного обучения, в которой две нейронные сети соревнуются друг с другом, чтобы быть более точными в своих прогнозах.
Как работают GAN?
Первым шагом в создании GAN является определение желаемого конечного результата и сбор начального набора обучающих данных на основе этих параметров. Затем эти данные рандомизируются и передаются в генератор до тех пор, пока они не достигнут базовой точности в получении результатов.
После этого сгенерированные изображения передаются в дискриминатор вместе с фактическими точками данных из исходной концепции. Дискриминатор фильтрует информацию и возвращает вероятность от 0 до 1, чтобы представить подлинность каждого изображения (1 соответствует реальному, а 0 соответствует ложному). Эти значения затем проверяются на точность и повторяются до тех пор, пока не будет достигнут желаемый результат.
Зачем генерировать изображение ЭКГ?
Я создал проект coronarography.ai . В нем на вход подается изображение ЭКГ, а на выходе мы получаем наличие патологии магистральных артерий сердца. Мне стало интересно проверить принципиальную возможность генерации изображений ЭКГ и сравнить полученные изображения с реальными.
В дальнейшем мне удалось одновременно генерировать синтетические табличные данные и изображения ЭКГ с сохраненной потоковой зависимостью. (расскажу в следующей статье).
Используемые библиотеки
import pandas as pd
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from PIL import Image
from tensorflow.keras import layers
import time
import tensorflow as tf
from IPython import display
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import matplotlib
from matplotlib.pyplot import figure
from sklearn.preprocessing import MinMaxScaler
import joblib
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
print(tf.__version__)
%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (18,10)
import moviepy.editor as mpe
Загрузка и подготовка набора данных
Здесь мы загружаем изображения ЭКГ из папки в массив. Преобразовываем в одноканальное (черно-белое) изображение, нормализуем его.
data_image = []
for k in os.listdir('../AI_coronarography/DATA_WORK/DATA_WORK/ЭКГ'):
if k.endswith('.jpg'):
img = Image.open('../AI_coronarography/DATA_WORK/DATA_WORK/ЭКГ/'+k)
img = img.convert('L')
img = img.resize((800, 800))
data_image += [(np.array(img) - 127.5) / 127.5]
Пример загруженного ЭКГ изображения
Определим размер батча и перемешаем изображения.
train_images = np.array(data_image).reshape(np.array(data_image).shape[0], 800, 800, 1).astype('float32')
BUFFER_SIZE = 100
BATCH_SIZE = 10
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
Определим структуру генератора
def make_generator_model():
input_1 = Input(shape=(100, ), name = "Input_image")
x = Dense(25*25*256, use_bias=False)(input_1)
x = BatchNormalization()(x)
x = LeakyReLU()(x)
x = Reshape((25, 25, 256))(x)
x = Conv2DTranspose(256, (1, 1), strides=(1, 1), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(32, (7, 7), strides=(2, 2), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(16, (9, 9), strides=(2, 2), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(1, (11, 11), strides=(2, 2), padding='same', use_bias=False, activation='tanh')(x)
model = Model(inputs=input_1, outputs=x)
return model
Проверим необученный генератор.
generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray');
Создадим дискриминатор
def make_discriminator_model():
input_1 = Input(shape=(800, 800, 1), name = "Input_image")
x = Conv2D(16, (11, 11), strides=(2, 2), padding='same')(input_1)
x = LeakyReLU()(x)
x = Dropout(0.3)(x)
x = Conv2D(32, (9, 9), strides=(2, 2), padding='same')(x)
x = LeakyReLU()(x)
x = Dropout(0.3)(x)
x = Conv2D(64, (7, 7), strides=(2, 2), padding='same')(x)
x = LeakyReLU()(x)
x = Dropout(0.3)(x)
x = Conv2D(128, (5, 5), strides=(2, 2), padding='same')(x)
x = LeakyReLU()(x)
x = Dropout(0.3)(x)
x = Conv2D(256, (3, 3), strides=(2, 2), padding='same')(x)
x = LeakyReLU()(x)
x = Dropout(0.3)(x)
x = Conv2D(256, (1, 1), strides=(1, 1), padding='same')(x)
x = LeakyReLU()(x)
x = Dropout(0.3)(x)
x = Flatten()(x)
x = Dense(1)(x)
model = Model(inputs=input_1, outputs=x)
return model
Мы используем еще не обученный дискриминатор, чтобы классифицировать
сгенерированные изображения как настоящие или фейковые. Модель будет
обучена выводить положительные значения для реальных изображений и
отрицательные значения для поддельных изображений.
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
#result
tf.Tensor([[1.6453338e-05]], shape=(1, 1), dtype=float32)
Определим функции потерь и оптимизаторы для обеих моделей
Потери дискриминатора.
Этот метод количественно определяет, насколько хорошо дискриминатор способен отличать настоящие изображения от фейковых. Он сравнивает предсказания дискриминатора для реальных изображений с массивом 1 и предсказания дискриминатора для поддельных (сгенерированных) изображений с массивом 0.
Потери генератора.
Потеря генератора определяет, насколько хорошо он смог обмануть дискриминатор. Интуитивно, если генератор работает хорошо, дискриминатор классифицирует поддельные изображения как настоящие (или 1).
Оптимизаторы дискриминатора и генератора отличаются, потому что я обучаю две сети по отдельности.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
Сохраним чекпойнты.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
Определение цикла обучения
Цикл обучения начинается с того, что генератор получает на вход случайное начальное число. Это число используется для создания ЭКГ. Затем дискриминатор используется для классификации реальных изображений (извлеченных из обучающего набора) и поддельных изображений (сгенерированных генератором). Потери рассчитываются для каждой из этих моделей, а градиенты используются для обновления генератора и дискриминатора.
EPOCHS = 6000
noise_dim = 100
num_examples_to_generate = 1
seed = tf.random.normal([num_examples_to_generate, noise_dim])
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
train_step(image_batch)
display.clear_output(wait=True)
generate_and_save_images(generator,
epoch + 1,
seed)
if (epoch + 1) % 500 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
seed)
Создание и сохранение изображений
def generate_and_save_images(model, epoch, test_input):
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(12, 12))
plt.imshow(predictions[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_per_epoch/image_at_epoch_{:04d}.png'.format(epoch), bbox_inches='tight', pad_inches=0)
plt.show()
Обучение модели
Вызовите метод train(), определенный выше, для одновременного обучения генератора и дискриминатора. Важно, чтобы генератор и дискриминатор не подавляли друг друга (например, чтобы они обучались с одинаковой скоростью).
В начале обучения сгенерированные изображения выглядят как случайный шум. По мере того, как нейронные сети будут учиться, сгенерированные изображения ЭКГ будут выглядеть все более и более реальными.
train(train_dataset, EPOCHS)
Видео обучения нейронной сети, генерирующей изображения ЭКГ.
Сравнение реальных и сгенерированных изображений ЭКГ.
Комментарии (4)
vassabi
03.01.2023 22:38+6мне одному кажется что это из пушки по воробьям и было бы гораздо лучше, чтобы восстанавливать из ЭКГ график точек. Из них уже моделировать одномерную ампдитуду, а потом из него формировать обратно в изображение ЭКГ ?
или это такая новая нормальность?
sairus777
04.01.2023 04:31Нет, не вам одному. Программистом можешь ты не быть, но тренькать сетки ты обязан...
Thyroxine
04.01.2023 16:54Нет, вам не кажется. И если уж у автора есть желание потренировать сетки, то для временных рядов (а ЭКГ это именно он) лучше подходят рекуррентные сети, например LSTM, как в статье в Nature Electrocardiogram generation with a bidirectional LSTM-CNN generative adversarial network.
GbrtR