Прочитав несколько известных статей по сегментации спутниковых снимков земли, я решил попробовать создать и обучить свою модель нейросети для этой задачи. И конечно, в процессе возникало много вопросов, своими ответами на которые я решил поделиться в рамках этого туториала. Поделиться так подробно и просто, как это было бы понятно таким новичкам, как я.

  1. Идея

  2. Импорты

  3. Датасет

    1. Загрузка

    2. Организация

    3. Вспомогательные функции и класс датасета

  4. Модель

  5. Обучение и тестирование модели

    1. Метрики

      1. Pixel Accuracy

      2. IoU

    2. Обучение

      1. Исследование обучения

      2. Сохранение модели

    3. Тестирование

  6. Заключение

Идея

Одним из главных источников идеи и подхода для создания и обучения моей нейросети стала статья LULC Segmentation of RGB Satellite Image Using FCN-8 [перевод на русский]: берем датасет спутниковых изображений земли Gaofen Image Dataset → разделяем изображения на подизображения, вместо уменьшения и аугментации исходных изображений: так она будет лучше обучаться → определяем модель нейросети с архитектурой Fully Convolutional Network (FCN, FCN-8) → обучаем ее. Единственное, вместо модели FCN-8 я решил выбрать доступную в PyTorch предобученную модель ResNet101, написанную по архитектуре FCN (FCN.ResNet101). Также вместо большой версии датасета (Large) я выбрал упрощенную версию (Fine) — она включает в себя 10 оригинальных и сегментированных по 15 классам изображений в формате tif и размером 7200x6800. Модель FCN.ResNet101 предполагает размер входных изображений 224x224, а так-как размер исходных изображений чуть больше, чем может вместиться целое кол-во подизображений размера 224x224, мы обрезаем исходные изображения до 7168x6720 (именно обрезаем, чтобы избежать потери детализации); таким образом, у нас получается 960 подизображений с одного исходного изображения.

Приступим к реализации этой идеи!

Импорты

Импортируем все необходимые библиотеки:

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from os import listdir
from typing import Tuple, List, Dict, Generator, Any
from IPython.display import clear_output

Датасет

Сайт датасета находится по этой ссылке, а две версии датасета (Fine и Large) на OneDrive. Но здесь я буду загружать его со своего Google Drive.

Загрузка

Если вы пользуетесь не Google Colab, вам может понадобиться установить модуль gdown:

pip3 install gdown 

Загрузим архивы датасета:

%%capture # clear output
!gdown "https://drive.google.com/uc?id=1--fNMFRXmBRDQPwdFqh54Sx7_RNE6z4a&confirm=t" # download image_RGB.zip
!gdown "https://drive.google.com/uc?id=1-3y-bJK-QZapo3nwZEXdVLWDjkVO4AYN&confirm=t" # download label_15classes.zip

И разархивируем их:

%%capture
!unzip image_RGB.zip
!unzip label_15classes.zip

Организация

Чтобы не нагружать оперативную память всеми подизображениями, можно сохранить каждое как файл и организовать их подгрузку. Сделаем это, но перед этим определимся с классами.

Упрощенная версия датасета, как уже было сказано, содержит сегментированные по 15 классам изображения. Их описание можно посмотреть в readme.txt на OneDrive:

label information of 15 classes:
industrial land
RGB:    200,    0,    0
urban residential
RGB:    250,    0, 150
rural residential
RGB:    200, 150, 150
traffic land
RGB:    250, 150, 150
paddy field
RGB:    0,     200,    0
irrigated land
RGB:    150,  250,   0
...

И так-как в вышеупомянутой статье модели создавали и обучали для сегментации одного целевого класса, выберем один из них, например, Irrigated Land: он достаточно хорошо представлен на изображениях датасета, а потому с ним будет достаточно просто. Но также важно добавить еще один класс — Background, так-как он важен для более корректного обучения модели.

Инициализируем словари с классами — classes и classes_by_id (последний пригодится для одной из функций далее):

classes = {
    (0, 0, 0): (0, '__background__'),
    (150, 250, 0): (1, 'irrigated_land')
}

classes_by_id = dict()
for rgb, (id, name) in classes.items():
  classes_by_id[id] = (rgb, name)

Теперь подизображения. Напишем функцию, которая будет возвращать генератор подизображений из исходного изображения:

def get_subimages_generator(
    image: Image.Image,
    subimage_size: Tuple[int, int, int]
) -> Generator[Image.Image, None, None]:
  for r in range(image.size[1] // subimage_size[1]):
    for c in range(image.size[0] // subimage_size[0]):
      yield image.crop(box=(
              c * subimage_size[0],
              r * subimage_size[1],
              (c + 1) * subimage_size[0],
              (r + 1) * subimage_size[1]
          )
      )

Создадим директорию dataset с директориями для каждого типа подизображений изображений датасета (оригинальные и сегментированные):

!mkdir dataset dataset/originals dataset/labeleds

Наконец, напишем функцию, которая будет сохранять все подизображения в соответствующие директории под строковыми id (они понадобятся нам далее). Функция будет проходится по каждым сегментированным подизображениям, фильтровать их по минимальным процентам представленности определенного цвета и, собственно, сохранять:

def save_dataset_subimages(classes_filter: Dict[Tuple[int, int, int], float]):

  for i, filename in enumerate(listdir('image_RGB/')):
      basename = filename[:filename.find('.tif')]
      
      image = Image.open(fp=f'image_RGB/{basename}.tif').crop(box=(16, 40, 7200 - 16, 6800 - 40))
      image_labeled = Image.open(fp=f'label_15classes/{basename}_label.tif').crop(box=(16, 40, 7200 - 16, 6800 - 40))
      subimages = get_subimages_generator(image=image, subimage_size=(224,224))
      subimages_labeleds = get_subimages_generator(image=image_labeled, subimage_size=(224,224))

      for si, subimage in enumerate(subimages):
        subimage_labeled = next(subimages_labeleds)
        
        # classes filter
        do_continue = False
        subimage_labeled_colors_dict = {rgb: count for count, rgb in subimage_labeled.getcolors()}
        for rgb, min_percent in classes_filter.items():
          if rgb not in subimage_labeled_colors_dict \
              or subimage_labeled_colors_dict[rgb] * 100 / 50176 < min_percent:
              # 50176 = subimage width * subimage height
            do_continue = True
            break
        if do_continue:
          continue

        subimage.save(fp=f'dataset/originals/i{i}si{si}.tif')
        subimage_labeled.save(fp=f'dataset/labeleds/i{i}si{si}_labeled.tif')

Задействуем ее:

save_dataset_subimages(
    classes_filter={
        (150, 250, 0): 5
    }
)

Команда ниже поможет нам узнать, сколько всего оригинальных, а значит и сегментированных, подизображений было сохранено:

!ls -lR dataset/originals/*.tif | wc -l

Всего для целевого класса Irrigated Land сохранено 4166 оригинальных, а значит и сегментированных, подизображений.

Вспомогательные функции и класс датасета

Для подгрузки и использования подизображений необходимо определить несколько функций и класс датасета, которые будут их обрабатывать и представлять так, как это будет удобно для работы с моделью. А для этого необходимо представлять оригинальные и сегментированные подизображения в виде тензоров (типе torch.Tensor), причем для большинства моделей, в том числе и для FCN.ResNet101, размерность тензоров должна предполагать наличие партий (batches): (batches, width, height, channels), где width, height и channels — ширина, высота и кол-во каналов изображения соответственно, в нашем случае — 224, 224 и 3 (R, G, B) соответственно. В случае с сегментированными изображениями, оно должно быть представлено еще и в размерности (batches, classes, width, height), где classes — количество сегментируемых классов, в нашем случае — 2; то есть должно представлять из себя маску, где для каждого класса указывается отдельная маска 224x224, которая указывает 0 или 1, на каком пикселе присутствует определенный класс, а на каком нет. Собственно, эти значения (0 или 1) для каждого пикселя каждой маски модель для семантической сегментации и должна предсказывать, оперируя вероятностью от 0 до 1.

Наконец, приступим к коду. Определим функции чтобы…

  • собирать два соответствующих подизображения (оригинальное и сегментированное) по их строковым id, которые были определены при их сохранении:

def get_dataset_subimage(dataset_subimage_id: str) -> Tuple[Image.Image, Image.Image]:

  subimage = Image.open(fp=f'dataset/originals/{dataset_subimage_id}.tif')
  subimage_labeled = Image.open(fp=f'dataset/labeleds/{dataset_subimage_id}_labeled.tif')

  return subimage, subimage_labeled
  • получать маску из сегментированного изображения и наоборот:

В первой функции проходимся по каждому пикселю сегментированного изображения и присваиваем маске для определенного класса определенного пикселя маски значение 1.0, при этом, если для пикселя не находится класса, относим его к классу Background. Получаем массив размерностью (classes, width, height):

def get_image_mask_from_labeled(
    image_labeled: Image.Image,
    classes: Dict[Tuple[int, int, int], Tuple[int, str]]
) -> np.ndarray:

  image_mask = np.zeros(shape=(len(classes),image_labeled.size[0],image_labeled.size[1]))

  image_labeled_ndarray = np.array(object=image_labeled)
  for r in np.arange(stop=image_labeled_ndarray.shape[0]):
    for c in np.arange(stop=image_labeled_ndarray.shape[1]):
      class_rgb = tuple(image_labeled_ndarray[r][c])
      class_value = classes.get(class_rgb)
      if class_value != None:
        image_mask[class_value[0]][r][c] = 1.0
      else:
        image_mask[0][r][c] = 1.0

  return image_mask

Во второй функции получаем массив из индексов классов с максимальными вероятностями среди других классов для каждого пикселя, а затем присваиваем каждому пикселю выходного изображения цвет в соответствии с его классом. То есть, обратно получаем массив размерностью (width, height, channels):

def get_image_labeled_from_mask(
    image_mask: np.ndarray,
    classes_by_id: Dict[Tuple[int, int, int], Tuple[int, str]]
) -> Image.Image:

  image_labeled_ndarray = np.zeros(
      shape=(image_mask.shape[1],image_mask.shape[2],3),
      dtype=np.uint8
  )

  image_mask_hot = image_mask.argmax(axis=0)
  for r in np.arange(stop=image_mask_hot.shape[0]):
    for c in np.arange(stop=image_mask_hot.shape[1]):
      class_id = image_mask_hot[r][c]
      class_by_id_value = classes_by_id.get(class_id)
      image_labeled_ndarray[r][c] = np.array(object=class_by_id_value[0])
  
  image_labeled = Image.fromarray(obj=image_labeled_ndarray)
  
  return image_labeled
  • чтобы предобрабатывать изображения (переводить в тензор):

def image_preprocess(image: Image.Image) -> torch.Tensor:
  return torchvision.transforms.ToTensor()(pic=image)
  • переводить два соответствующих подизображения в тензоры:

def get_dataset_subimage_tensor(
    subimage: Image.Image,
    subimage_labeled: Image.Image,
    classes: Dict[Tuple[int, int, int], Tuple[int, str]],
    dtype: torch.FloatType = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

  subimage_tensor = image_preprocess(image=subimage)
  subimage_mask_tensor = torch.tensor(
      data=get_image_mask_from_labeled(
          image_labeled=subimage_labeled,
          classes=classes
      ),
      dtype=dtype
  )

  return subimage_tensor, subimage_mask_tensor

Проверим, как они работают. Загружаем два соответствующих подизображения по их строковому id:

subimage, subimage_labeled = get_dataset_subimage(dataset_subimage_id='i0si0')

Выводим их с помощью модуля matplotlib.pyplot:

fig, ax = plt.subplots(ncols=2)
ax[0].imshow(subimage)
ax[1].imshow(subimage_labeled)
plt.show()

Проверим корректность обработки масок. Переводим сегментированное подизображение в маску:

subimage_mask = get_image_mask_from_labeled(
    image_labeled=subimage_labeled,
    classes=classes
)

Переводим маску обратно в сегментированное подизображение:

subimage_labeled_from_mask = get_image_labeled_from_mask(
    image_mask=subimage_mask,
    classes_by_id=classes_by_id
)

Сравниваем оригинальное подизображение с сегментированным из маски:

fig, ax = plt.subplots(ncols=2)
ax[0].imshow(subimage)
ax[1].imshow(subimage_labeled_from_mask)
plt.show()

Работает отлично!

Осталось только определить специальный класс датасета (с помощью абстрактного класса torch.utils.data.Dataset), который вместе с даталоадером (типом torch.utils.data.DataLoader) помогут удобно подгружать подизображения. И чтобы все они не занимали много памяти, в самом классе датасета мы определим их строковые id, а по запросу на получение элемента (вызову перегруженного метода __getitem__) будем уже загружать их, переводить в тензоры и возвращать. При этом важно создать несколько датасетов — обучающий и тестовый.

Но перед этим, нам нужно собрать все id подизображений:

def get_dataset_subimages_id() -> List[str]:
  return [
      filename[:filename.find('.tif')]
      for filename in listdir(path='dataset/originals/')
  ]

Теперь разобъем список id на обучающий и тестовый. Для этого используем функцию train_test_split из модуля sklearn, которая поможет разбить их на обучающую и тестовую выборки, где длина обучающей выборки будет составлять 90%, а тестовой — 10%:

train_dataset_subimages_id, test_dataset_subimages_id = train_test_split(get_dataset_subimages_id(), train_size=0.9)

Наконец, определим класс датасета. Дополнительно для даталоадера нам понадобится перегрузить метод __len__ (для получение длины датасета) — в качестве длины датасета в нашем случае будет выступать длина списка id:

class Dataset(torch.utils.data.Dataset):
  def __init__(self,
      dataset_subimages_id: List[str],
      classes: Dict[Tuple[int, int, int], Tuple[int, str]],
  ):
    self.dataset_subimages_id = dataset_subimages_id
    self.classes = classes
  
  def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    subimage = Image.open(fp=f'dataset/originals/{self.dataset_subimages_id[idx]}.tif')
    subimage_labeled = Image.open(fp=f'dataset/labeleds/{self.dataset_subimages_id[idx]}_labeled.tif')
    subimage_tensor, subimage_mask_tensor = get_dataset_subimage_tensor(
        subimage=subimage,
        subimage_labeled=subimage_labeled,
        classes=self.classes
    )
    return subimage_tensor, subimage_mask_tensor
  
  def __len__(self) -> int:
    return len(self.dataset_subimages_id)

Используем вышеописанный класс для создания датасетов:

train_dataset = Dataset(
    dataset_subimages_id=train_dataset_subimages_id,
    classes=classes,
)
test_dataset = Dataset(
    dataset_subimages_id=test_dataset_subimages_id,
    classes=classes
)

Даталоадеры удобны тем, что с помощью них можно удобно разбивать датасет на партии и перемешивать. Оба этих параметра оказываются достаточно важными при обучении. Для нашей модели наиболее подходящим для обучающей выборки, по результатам моих экспериментов, оказывается batch_size=32 и shuffle=True.

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1
)

Модель

Можно было бы написать свой класс модели с помощью абстрактного класса torch.nn.Module, однако, как уже говорилось в начале, я решил использовать готовую модель с предобученными на датасете ImageNet1K весами из модуля torchvisionFCN.ResNet101. Причем, c предобученными весами не всей модели, а только слоев backbone’а (так называемых извлекателей признаков), которые изображены на рисунке архитектуры FCN ниже между изображением котика с собачкой и слоя pixelwise prediction.

Рисунок из статьи Fully Convolutional Network for Semantic Segmentation, иллюстрирующий возможность FCN обучаться семантической сегментации.
Рисунок из статьи Fully Convolutional Network for Semantic Segmentation, иллюстрирующий возможность FCN обучаться семантической сегментации.

Но для начала, нужно инициализировать переменную device, которая поможет нам автоматически определять, на какой оперативной памяти (CPU или GPU) она будет находится и на какой оперативной памяти будут проводится вычисления нейросети. При обучении нейросети, особенно с большой архитектурой, важно пользоваться преимуществами распараллеливания вычислений на GPU, чтобы модель быстрее обрабатывала входные данные:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Наконец, инициализируем модель с помощью модуля torchvision и определим кол-во выходных классов и веса backbone’а:

%%capture
model = torchvision.models.segmentation.fcn_resnet101(
    num_classes=len(classes),
    weights_backbone=torchvision.models.ResNet101_Weights.IMAGENET1K_V1
).to(device=device)

Напишем функцию для предсказывания по входному изображению:

def predict(
    image: Image.Image,
    model: torch.nn.Module,
    device: torch.DeviceObjType,
) -> Image.Image:

  image_tensor = image_preprocess(image=image)

  with torch.no_grad():
    output_image_mask = model(image_tensor.unsqueeze(0).to(device))['out'][0].cpu().numpy()

  predicted_image_labeled = get_image_labeled_from_mask(
      image_mask=output_image_mask,
      classes_by_id=classes_by_id
  )

  return predicted_image_labeled

И проверим ее:

predicted_image_labeled = predict(
    image=subimage,
    model=model,
    device=device
)
fig, ax = plt.subplots(ncols=3)
ax[0].imshow(subimage)
ax[1].imshow(predicted_image_labeled)
ax[2].imshow(subimage_labeled)
plt.show()

Хорошо, она на что-то способна.

Обучение и тестирование модели

Метрики

При обучении и тестировании модели важно пользоваться различными метриками, чтобы следить за ее развитием. И если метрику потерь мы получаем из функции потерь (ее мы определим после), то в модулях torch и torchvision таких важных для семантической сегментации метрик как Pixel Accuracy и Intersection over Union (IoU) не содержится. Хотя они есть в torchmetrics, но я решил написать свои версии Pixel Accuracy и IoU, чтобы понять, что они из себя представляют. Хороший и простой обзор по этим и другим метрикам для задач сегментации провел Джереми Джордан на своем сайте.

Pixel Accuracy

Альтернативной метрикой для оценки семантической сегментации является простой отчет о проценте пикселей на изображении, которые были правильно классифицированы. Точность пикселей обычно указывается для каждого класса отдельно, а также глобально для всех классов. (Джереми Джордан)

Мы можем вычислить ее с помощью формулы:

pixel \space accuracy = \frac{correct \space pixels}{correct \space pixels + uncorrect \space pixels}

Напишем функцию для вычисления метрики Pixel Accuracy. Причем, определим классификацированные пиксели так, как мы бы их отрисовывали, то есть с помощью получения масива из индексов классов с максимальными вероятностями среди других классов для каждого пикселя:

def metric_pixel_accuracy(
    y_pred: torch.Tensor,
    y_true: torch.Tensor
) -> float:

  y_pred_argmax = y_pred.argmax(dim=1)
  y_true_argmax = y_true.argmax(dim=1)

  correct_pixels = (y_pred_argmax == y_true_argmax).count_nonzero()
  uncorrect_pixels = (y_pred_argmax != y_true_argmax).count_nonzero()
  result = (correct_pixels / (correct_pixels + uncorrect_pixels)).item()

  return result

IoU

Метрика Intersection over Union (IoU), также называемая индексом Жаккара, по сути, является методом количественной оценки процентного перекрытия между целевой маской и нашим прогнозируемым результатом. Эта метрика тесно связана с коэффициентом Дайса, который часто используется в качестве функции потерь во время обучения. (Джереми Джордан)

Ее мы можем вычислить с помощью формулы:

IoU = \frac{target \cap prediction}{target \cup prediction}

Наишем функцию для вычисления метрики IoU. Определим классифицированными только те пиксели, вероятность которых выше 0.51 (51%):

def metric_iou(
    y_pred: torch.Tensor,
    y_true: torch.Tensor
) -> float:

  y_pred_hot = y_pred >= 0.51

  intersection = torch.logical_and(y_pred_hot, y_true).count_nonzero()
  union = torch.logical_or(y_pred_hot, y_true).count_nonzero()
  result = (intersection / union).item()

  return result

Обучение

Определим функцию для обучения, которая будет подгружать подизображения, подавать модели входные тензоры и получать выходные, вычислять на их основе потерю, оптимизировать обучение, вычислять и выводить историю метрик по каждой партии, и все это определенное кол‑во эпох. Причем, чтобы оперативная память не перегружалась мы будем удалять уже не нужные громоздкие переменные и очищать оперативную память графического процессора, если мы его используем:

def train(
    model: torch.nn.Module,
    device: torch.DeviceObjType,
    train_dataloader: torch.utils.data.DataLoader,
    loss_fn: Any,
    optim_fn: Any,
    epochs: int
) -> Dict[str, List[float]]:

  history_metrics = {
      'loss': list(),
      'pixel_accuracy': list(),
      'iou': list()
  }

  for e in range(1, epochs + 1):
    for b, data in enumerate(train_dataloader, start=1):
      subimage_tensor, subimage_mask_tensor = data
      
      if device.type == 'cuda':
        subimage_tensor = subimage_tensor.to(device)
        subimage_mask_tensor = subimage_mask_tensor.to(device)
      
      optim_fn.zero_grad()
      output = model(subimage_tensor)
      loss = loss_fn(output['out'], subimage_mask_tensor)
      loss.backward()
      optim_fn.step()

      loss_item = loss.item()
      pixel_accuracy = metric_pixel_accuracy(output['out'], subimage_mask_tensor)
      iou = metric_iou(output['out'], subimage_mask_tensor)

      history_metrics['loss'].append(loss_item)
      history_metrics['pixel_accuracy'].append(pixel_accuracy)
      history_metrics['iou'].append(iou)

			# dynamic output
      clear_output()
      print(
          'Epoch: {}. Batch: {}. Loss: {:.3f} | Pixel Accuracy: {:.3f} | IoU: {:.3f}'.format(
              e, b,
              loss, pixel_accuracy, iou
          )
      )

      # memory clear
      del subimage_tensor, subimage_mask_tensor, output, loss
      if device.type == 'cuda':
        torch.cuda.empty_cache()
  
  return history_metrics

Определим также функции потерь и оптимизации, а также кол‑во эпох. Для нашей модели и для большинства моделей нейросетей они являются необходимыми, ведь они помогают проводить и оптимизировать обучение. В качестве функции потерь я выбрал популярную для задач семантической сегментации CrossEntropyLoss (torch.nn.CrossEntropyLoss). В качестве функции оптимизации — AdamW (torch.optim.AdamW), и важно определиться с ее параметрами: по результатам моих экспериментов, наиболее подходящим начальным коэффициентом обучения (learning rate, lr) оказывается 0.0001 (1e-4). Наиболее подходящее кол‑во эпох, опять‑же, по результатам моих экспериментов, оказывается 8, ведь по моим наблюдениям дальше он перестает обучаться, хотя в упомянутом в начале исследовании использовали целых 100 эпох, но для своей версии датасета.

Все готово к обучению. Запустим его:

history_metrics = train(
    model=model,
    device=device,
    train_dataloader=train_dataloader,
    loss_fn=torch.nn.CrossEntropyLoss(),
    optim_fn=torch.optim.AdamW(params=model.parameters(), lr=1e-4),
    epochs=8
)

Этот процесс оказывается достаточно долгим даже при использовании GPU в Google Colab. Возможно, часть кода оказывается недостаточно эффективным, поэтому будет очень хорошо, если вы сможете оптимизировать его для своих задач.

Исследование обучения

Выведем историю метрик с помощью matplotlib.pyplot:

plt.plot(
    history_metrics['loss'], 'red',
    history_metrics['pixel_accuracy'], 'green',
    history_metrics['iou'], 'blue',
)
plt.title('History Metrics in Training\nep=8, bs=32, loss_fn=CrossEntopyLoss(), optim_fn=AdamW(lr=1e-4)')
plt.xlabel('Batch')
plt.ylabel('Value')
plt.legend(('Loss', 'Pixel Accuracy', 'IoU'))
plt.show()

Хотя развитие модели видно достаточно явно: потери уменьшаются, метрики Pixel Accuracy и IoU увеличиваются, а значит точность и предсказания увеличивается. Однако показатели представляются для меня слишком колеблющимися. Это может значит, что модель после каждой партии модель проводит слишком резкие изменения. Все это можно регулировать с помощью дополнительной обработки датасета, изменения размера партий (batch_size), выбора определенных функций потерь, параметров функции оптимизации (learning rate и других) и т. п. Все это входит в спектр моего дальнейшего изучения, поэтому, если вы в этом хорошо разбираетесь, я буду рад, если вы поделитесь со мной своими знаниями.

Сохранение модели

Сохраним модель с помощью функции save из модуля torch в формате h5:

torch.save(
		model.state_dict(),
		'remezova_fcn_resnet101_ep8bs32lr1e-4.h5'
)

Тестирование

Напишем похожую на train функцию test, только без элементов обучения и c выводом медианы метрик из истории метрик с каждой партией:

def test(
    model: torch.nn.Module,
    device: torch.DeviceObjType,
    test_dataloader: torch.utils.data.DataLoader,
    loss_fn: Any,
) -> Dict[str, List[float]]:

  history_metrics = {
      'pixel_accuracy': list(),
      'iou': list()
  }

  for b, data in enumerate(test_dataloader, start=1):
    subimage_tensor, subimage_mask_tensor = data
    
    if device.type == 'cuda':
      subimage_tensor = subimage_tensor.to(device)
      subimage_mask_tensor = subimage_mask_tensor.to(device)
    
    with torch.no_grad():
      output = model(subimage_tensor)

    pixel_accuracy = metric_pixel_accuracy(output['out'], subimage_mask_tensor)
    iou = metric_iou(output['out'], subimage_mask_tensor)

    history_metrics['pixel_accuracy'].append(pixel_accuracy)
    history_metrics['iou'].append(iou)

    clear_output()
    print(
        'Batch: {}. median Pixel Accuracy: {:.3f} | median IoU: {:.3f}'.format(
            b,
            np.median(a=history_metrics['pixel_accuracy']),
            np.median(a=history_metrics['iou'])
        )
    )

    # memory clear
    del subimage_tensor, subimage_mask_tensor, output
    if device.type == 'cuda':
      torch.cuda.empty_cache()
  
  return history_metrics

Проведем тестирование:

test_history_metrics = test(
    model=model,
    device=device,
    test_dataloader=test_dataloader,
    loss_fn=torch.nn.CrossEntropyLoss()
)

И тем не менее получаем достаточно хороший результат:

Batch: 417. median Pixel Accuracy: 0.703 | median IoU: 0.523

Визуализируем результаты предсказаний. Выведем предсказания по пяти тестовым подизображениям:

fig, ax = plt.subplots(nrows=3, ncols=5, figsize=(16,8))
for si, (subimage_tensor, subimage_mask_tensor) in enumerate(test_dataloader):
  subimage = torchvision.transforms.ToPILImage()(pic=subimage_tensor[0])
  subimage_labeled = get_image_labeled_from_mask(
      image_mask=subimage_mask_tensor[0].cpu().numpy(),
      classes_by_id=classes_by_id
  )
  predicted_subimage_labeled = predict(
      image=subimage,
      model=model,
      device=device
  )
  ax[0][si].imshow(subimage)
  ax[1][si].imshow(predicted_subimage_labeled)
  ax[2][si].imshow(subimage_labeled)
  
  if si == 5 - 1:
    break

plt.show()
Примеры предсказаний. Сверху вниз: подизображение, маска предсказанного подизображения и маска истинного подизображения
Примеры предсказаний. Сверху вниз: подизображение, маска предсказанного подизображения и маска истинного подизображения

Заключение

Мы подготовили датасет, создали, обучили и протестировали модель для сегментации спутниковых снимков. Не смотря на некоторые недостатки, модель показывает достаточно хорошие результаты. Однако этот туториал и весь код будут продолжать изменяться или дополняться с развитием моих знаний или ваших рекомендаций.

Этот туториал вышел достаточно большим, но я надеюсь, я смог достаточно хорошо ответить на те вопросы по нейросетям для семантической сегментации, которые, возможно, возникли и у вас; особенно, если вы такой же новичок, как и я.

Также вы можете воспользоваться более структурированным кодом этой модели в репозитории Remezova на GitHub. ❤️

Комментарии (0)