Привет, друзья! Добро пожаловать в новый туториал из серии практических материалов по explanable AI (интерпретируемости моделей). Он посвящен методу интерпретации на основе вмешательства — RISE. В этом материале разобрана теоретическая постановка метода, подчеркнуты красивые математические идеи и переходы, и, конечно, реализован код для практики. Приглашаю к чтению! Ноутбук к туториалу доступен на гитхаб.

Методы интерпретации на основе вмешательства основаны на идее ответа на вопрос: на вопрос:

«Что произойдет с предсказанием модели если изменить или исключить отдельный признак?»

Cуть всех таких методов состоит в том, что мы изменяем входные данные x по заранее определённому правилу, пропускаем измененное изображение x' через модель и фиксируем дельту (\delta = f(x) - f(x')) выхода модели. Отсюда, важность признака определяется через чувствительность модели к его изменениям.

Зачем это всё?

Такие методы полезны в условиях, когда структура и параметры модели недоступны (например, работа через API). В отличие от градиентных подходов, они не требуют доступа к внутренним весам или производным модели.

Среди стандартных методов — Occlusion, Ablataion и Pertubation. А вот не совсем стандартный, но всё ещё часто встречающийся метод — RISE (*Randomized Input Sampling for Explanation of Black-box Models*), который в этом туториале и разберем.

Его ключевые шаги:

  1. Генерируем случайные бинарные маски на изображение.

  2. Применяем эти маски и прогоняем через модель.

  3. Смотрим, как меняется вероятность предсказания интересующего класса.

  4. Собираем важностную карту признаков как средневзвешенную комбинацию масок.

Метод решает проблему корректной оценки множества всех бинарных масок (а для изображения $32 \times 32$ таких масок уже $2^{32\times 32}$), и именно это ключевая компонента, которую интересно пронаблюдать.

Приступим!

Модель и данные

Для демонстрации возьмем предобученную ResNet50 и картинку кошки. Котики — это всегда хорошо. В прочем, вы можете подгрузить любое изображение и поработать с ним.

import requests, io

url = "https://github.com/SadSabrina/XAI-open_materials/blob/38216bfee2a38c96b4f1d16460040cb89527c3fe/yolo_nas_cam_tutorial/coco_test/cat.jpg?raw=True"
img = Image.open(requests.get(url, stream=True).raw).convert("RGB")

x = transform(img).unsqueeze(0)  # (1, 3, 224, 224)

plt.imshow(img)
plt.title("Исходное изображение")
plt.axis("off")
plt.show()

RISE: постановка и реализация

Метод RISE — стандартный метод основанный на возмущениях. Как и другие — он закрывает часть пикселей исходного изображения, но делает это на основе масок. На этом моменте справедливо заметить, что количество бинарных масок для картинки 224 \times 224(стандартный вход в ResNet) — это 2^{224\times 224}. Чтобы корректно оценить важность признака в RISE реализовано два красивых перехода, и они — центральное, на что стоит обратить внимание в методе.

Постановка

Пусть:

  • f: I \to \mathbb{R} — модель-чёрный ящик, которая по изображению I возвращает уверенность в целевом классе.

  • \Lambda = \{1,\dots,H\} \times \{1,\dots,W\} — множество координат пикселей \lambda.

  • M: \Lambda \to {0,1} — случайная бинарная маска с распределением $D$.

Важность пикселя \lambda определим как условное матожидание:

Imp_{I,f}(\lambda) = \mathbb{E}_M \big[ f(I \odot M) \mid M(\lambda)=1 \big],

где оператор ⊙ обозначает покоординатное умножение (применение маски к изображению).

Интуиция: если сохранение пикселя \lambda после маскирования существенно увеличивает вероятность целевого класса, то этот пиксель важен для прогноза.

Как я уже отметила, даже для изображения 32 \times 32 количество всех возможных масок равно 2^{32 \cdot 32}, что вычислительно невозможно, особенно, когда картинка не одна. Для борьбы с этим в RISE применяются два приёма:

1. Переход к безусловному распределению масок. Условное матожидание разворачивается через формулу Байеса.

2. Специальный способ маскирования с использованием билинейной интерполяции. Маски сначала генерируются в низком разрешении с заданной вероятностью сохранения пикселя, а затем увеличиваются до размера изображения методом билинейной интерполяции. Это сглаживает границы масок и делает объяснения более устойчивыми.

Hint 1. Развёртка через формулу Байеса

Исходная формула:

Imp_{I,f}(\lambda) = \mathbb{E}_M \big[ f(I \odot M) \mid M(\lambda)=1 \big]  = \sum_{m} f(I \odot m)\, P\!\big(M=m \mid M(\lambda)=1\big).

Прямая работа с условной вероятностью неудобна, потому что она требует явно учитывать распределение P[M=m | M(\lambda)=1]. Поэтому авторы переходят к безусловному матожиданию, используя формулу Байеса:

P(M=m \mid M(\lambda)=1) = \frac{P(M=m, M(\lambda)=1)}{P(M(\lambda)=1)}.

Так как:

P(M=m,\, M(\lambda)=1) =\begin{cases}0, & m(\lambda)=0, \\[6pt]P(M=m), & m(\lambda)=1,\end{cases}

то нулевые слагаемые можно отбросить, а m(\lambda)=1 учесть явно. Получаем:

Imp_{I,f}(\lambda) = \frac{1}{P[M(\lambda)=1]} \sum_{m} f(I \odot m)\, m(\lambda)\, P[M=m].

Знаменатель P[M(\lambda)=1] = \mathbb{E}[M], так как математическое ожидание индикаторной (либо 0, либо 1) случайной величины равно вероятности события, которое она описывает, то есть:

\mathbb{E}[M(\lambda)] = 1 \cdot P(M(\lambda)=1) + 0 \cdot P(M(\lambda)=0) = P(M(\lambda)=1)

Далее, чтобы приближенно оценить сумму \sum_{m} f(I \odot m), m(\lambda), P[M=m],
применяется метод Монте-Карло.

Метод Монте-Карло

Метод Монте-Карло позволяет приближённо вычислить математическое ожидание с помощью случайного сэмплирования. Формально, если X — случайная величина, то её матожидание оценивается как \mathbb{E}[X] \approx \frac{1}{N}\sum_{i=1}^N X_i,где X_1, \dots, X_N— независимые реализации X.

В нашем случае:

  • случайная величина — это f(I \odot M)\, M(\lambda),

  • мы берём N случайных масок M_1, \dots, M_N,

  • для каждой маски вычисляем предсказание модели f(I \odot M_i),

  • умножаем его на M_i(\lambda), чтобы учитывать вклад только тех масок, где пиксель $ сохранён,

  • и усредняем полученные значения.

Таким образом, оценка важности пикселя с применением метода Монте-Карло имеет вид:

S_{I,f}^{MC}(\lambda) \approx \frac{1}{N \cdot \mathbb{E}[M]}  \sum_{i=1}^N f(I \odot M_i)\, M_i(\lambda).

Здесь

  • нормировка \tfrac{1}{N \cdot \mathbb{E}[M]} учитывает два фактора: (i) деление на $N$ соответствует усреднению по выборке, (ii) деление на \mathbb{E}[M] = P(M(\lambda)=1) корректирует результат с учётом вероятности сохранения пикселя;

  • M_i(\lambda) — значение i-й маски в позиции \lambda:
    если пиксель \lambda сохранён в маске M_i, то M_i(\lambda)=1,
    если замаскирован, то M_i(\lambda)=0.

    И таким образом, M_i(\lambda) выступает в роли индикатора, который «включает» вклад f(I \odot M_i) только в тех случаях, когда пиксель \lambda действительно участвует в изображении.

Отлично, с формулой разобрались и метод Монте-Карло узаконили. Второй прием в статье связан с масками.

Hint 2. Сэмплирование масок с интерполяцией

После перехода к безусловной вероятности не решает, однако, проблемы применимости метода Монте-Карло в пространстве всех масок 2^{H\times W}. При условии стандартного input size 224\times 224, прямое равновероятное сэмплирование пикселей даёт очень разные маски, которые, в свою очередь, даеют шумные оценки. Достичь устойчивости карты в таком случае возможно только при всё ещё большом количестве масок.

В качестве решения этой проблемы вводится сэмплирование через интерполяцию. А именно маски генерируются в 4 шага:

1. Формирование малых бинарных масок.
Генерируются маски размером h \times w. Каждый элемент маски задаётся распределением Бернулли:

— значение 1 с вероятностью p (пиксель сохранён),
— значение 0 с вероятностью 1-p (пиксель замаскирован).

2. Апсемплинг с помощью билинейной интерполяции.
Полученные малые маски увеличиваются до размера

(h+1) \cdot C_H \;\times\; (w+1) \cdot C_W,

где C_H \times C_W = \lfloor H/h \rfloor \times \lfloor W/w \rfloor — размер ячейки после апсемплинга. Благодаря билинейной интерполяции маска перестаёт быть строго бинарной и принимает значения в диапазоне $[0,1]$. Это сглаживает границы и делает итоговую карту важности более «непрерывной».

3. Случайный сдвиг.
Маска сдвигается на случайное число пикселей вдоль обеих координатных осей. Такой шаг обеспечивает дополнительную вариативность и устраняет артефакты регулярной сетки.

4. Обрезка (cropping).
После сдвига берётся окно размера H \times W, совпадающее с размером исходного изображения и маска приводится к нужным размерам для покоординатного умножения с I.

Эти хитрости дают сглаженные, но при этом они всё ещё основанные на вероятностном распределении с параметром p, маски.

Реализация


По сути, для реализации необходимо три шага:

  1. Сделать маски;

  2. Применить маски к изображению;

  3. Получить прогноз и усредить.

Полный кусочек кода для реализации масок выглядит так:

  # sizes
  _, C, H, W = x.shape
  (C_H, C_W) = (H//s, W//s)
  up_h, up_w = (s+1)*C_H, (s+1)*C_W

  grid = (torch.rand(N,1,s,s, device=x.device) < p).float()
  ups = F.interpolate(grid, size=(up_h, up_w), mode="bilinear",
                      align_corners=False)

  # shifts and crop
  masks = torch.empty(N,1,H,W, device=x.device)
  
  for i in range(N):
    # сдвиг по x на случайное число от 0 до 32 (256-224)
    dx = torch.randint(0, C_H, (1,), device=x.device).item() 
    # аналогично по y
    dy = torch.randint(0, C_W, (1,), device=x.device).item() 
    masks[i,0] = ups[i,0, dx:dx+H, dy:dy+W]
Маски с параметром p=0.t
Маски с параметром p=0.t
Маски x изображение
Маски x изображение

И далее, остается только получить прогноз и усреднить.

scores = []
class_id = model(x).softmax(dim=1).argmax().item()


with torch.no_grad():
  for i in tqdm(range(0, N)):

    stack = masks[i] * x        # (B,1,H,W)*(1,C,H,W) -> (B,C,H,W)
    out = model(stack).softmax(dim=1)        # вероятности
    scores.append(out)          # (B,)

  scores = torch.cat(scores, dim=0)                # (N,)
  # Берём столбец интересующего класса: (N,)
  scores = scores[:, class_id]

  # Свернём MC-оценку: (H*W,) = (1 x N) @ (N x H*W)
  sal = torch.matmul(scores.view(1, N), masks.view(N, H * W)).view(H, W)
  sal = sal / (N * p)



plt.imshow(x[0].permute(1, 2, 0).numpy())
plt.imshow(sal.cpu(), cmap="jet", alpha=0.5)
plt.axis("off")
plt.title("RISE")
plt.show()
RISE на 10 масках
RISE на 10 масках

Стоит заметить, что метод тем точнее, чем больше масок мы генерируем изначально (в силу всё равно большого пространства возможных масок, даже с введенными улучшениями).

RISE 10, 100, 1000
RISE 10, 100, 1000

Ручная реализация остается чувствительной к количеству наблюдений и их всё ещё необходимо достаточно много. Но метод математически прекрасен, мне хотелось этим поделиться :)

На этом, туториал завершается. Вы молодцы, если дошли до конца. Подводя итог, в туториале мы разобрали применение метода RISE для интерпретации. Для этого:

  • познакомились с идеей perturbation-based методов и формулировкой RISE через условное матожидание;

  • аккуратно (надеюсь) вывели математическую постановку и показали переход через формулу Байеса;

  • реализовали алгоритм генерации случайных масок и обсудили трюк с билинейной интерполяцией;

  • собрали Monte-Carlo аппроксимацию и построили карты важности для конкретного класса;

  • визуализировали результаты и убедились, что метод работает даже для полностью «чёрных ящиков».

RISE позволяет получать интерпретации без доступа к весам и градиентам модели, что делает его универсальным инструментом. Однако стоит помнить, что качество карт важности зависит от числа сгенерированных масок и параметров распределения (например, вероятности сохранения пикселя). Попробуйте поэкспериментировать с параметрами и сравнить полученные карты с другими методами (например, CAM, по которому тоже есть туториал).

Основные материалы туториала:

[1](https://arxiv.org/abs/1806.07421) — оригинальная статья о RISE,

[2](https://github.com/eclique/RISE) — официальный репозиторий,

Спасибо за чтение! Другие туториалы публикую в блоге (https://t.me/jdata_blog) и на [гитхаб](https://github.com/SadSabrina/XAI-open_materials/tree/main)!

Со всем самым добрым и спасибо за чтение,
Ваш Дата-автор.

Комментарии (0)