
Всем привет. Я Андрей Бояренков, лидер кластера бизнес-моделей стрима "Разработка моделей КИБ и СМБ" банка ВТБ.
Наш кластер отвечает за:
выстраивание и внедрение процессов AutoML,
за разработку моделей для процессов: ПОДФТ \ Precollection \ ЖЦК (жизненного цикла клиента),
а также за разработку моделей цифровых помощников для различных подразделений банка, в т.ч. риск-андеррайтинга, операционных рисков и комплаенса.
Участие заказчика в процессах разработки и применения моделей является достаточно важным. Как правило, оно является ключевым на следующих этапах:
постановка задачи на разработку модели, включая определение сегмента и целевой переменной,
согласование лонг-листа фичей модели и методологии их расчета,
прием результатов разработки модели (подтверждение соответствия метрик качества модели изначально заявленным),
подтверждение бизнес-логики работы фичей в модели.
В данной статье хочу рассказать о том, какие на мой взгляд типы графиков необходимо построить, чтобы наиболее оптимальным образом показать заказчику логику работы фичей в моделях. В качестве ограничения установим, что целью является показать логику работы фичей не на уровне конкретного наблюдения, а на уровне выборки в целом.
Графики покажем на примере находящегося в открытом доступе датасете telecom_churn. В частности, его можно найти по указанным ссылкам:
1. https://www.kaggle.com/datasets/keyush06/telecom-churncsv
2. https://www.kaggle.com/datasets/nikkitha8/telecom-churn
3. https://www.kaggle.com/code/kashnitsky/topic-1-exploratory-data-analysis-with-pandas
4. https://habr.com/ru/companies/ods/articles/322626/
Ниже представлены основные шаги для обработки выборки telecom_churn в целях ее дальнейшего использования:
1. Сначала импортируем библиотеки которые пригодятся для нашего исследования.
2. Далее создадим датафрейм data с выборкой telecom_churn, скорректируем названия строк и присвоим бинарный тип данных целевой переменной churn. Детальный EDA (Explanatory Data Analysis) проводить не будем, т.к. цель статьи заключается только в том чтобы показать возможные графики для интерпретации работы фичей.
3. Для расширения признакового пространства также рассчитаем ряд дополнительных фичей.
4. Сделаем дополнительную предобработку типов данных.
5. Создадим лонг-лист фичей features и список категориальных фичей cat_feat.
6. Разделим выборку со стратификацией по churn на train (75%) и test (25%) для проверки качества модели на независимой выборке.
7. Выборку train разделим на train (60%) для обучения модели и val (15%) для использования критерия ранной остановки.
Код на Python для проведения данных действий с выборкой см. ниже:
import pandas as pd
import re
import shap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import make_scorer, roc_auc_score
from scipy import stats
import warnings
warnings.filterwarnings('ignore')
data = pd.read_csv('telecom_churn.csv', sep=',') # чтение датафрейма
data.columns = data.columns.str.replace(' ', '_') # корректировка названий строк
data['churn'] = data['churn'].map({True: 1, False: 0}).astype('int32')
#расчет суммы показателей за все время
col_minutes = [x for x in data.columns if len(re.findall('total_\w*_minutes', x)) != 0]
col_calls = [x for x in data.columns if len(re.findall('total_\w*_calls', x)) != 0]
col_charge = [x for x in data.columns if len(re.findall('total_\w*_charge', x)) != 0]
data['total_minutes'] = data[col_minutes].sum(axis=1)
data['total_calls'] = data[col_calls].sum(axis=1)
data['total_charge'] = data[col_charge].sum(axis=1)
data['charge_per_minute'] = data['total_charge'] / data['total_minutes'] #стоимость совокупной минуты
data['charge_per_call'] = data['total_charge'] / data['total_calls'] #стоимость одного звонка
data['minutes_per_call'] = data['total_minutes'] / data['total_calls'] #продолжительность одного звонка
#сделать категориальными фичи по которым уникальных значений от 3 до 9 включительно
for column in data.columns: if data[column].nunique() < 10 and data[column].nunique() >= 3: data[column] = data[column].astype('str') #сделать бинарными фичи по которым 2 уникальных значения и при этом они являются типами 'object' или 'bool'
bool_columns = []
for column in data.dtypes[(data.dtypes=='object')|(data.dtypes=='bool')].index: if data[column].nunique() == 2: bool_columns.append(column)
le = preprocessing.LabelEncoder()
for column in bool_columns: data[column] = le.fit_transform(data[column])
# для того чтобы показать как графики с интерпретацией фичей работают в том числе на категориальных фичах
# искусственно сделаем фичу 'customer_service_calls' категориальной (для каждого значения введем отдельное обозначение
# от 'A' до 'F', а для значений 6 и более создадим отдельную категорию 'G' в которую попадет чуть более 1% выборки
data['customer_service_calls'] = data['customer_service_calls'].replace( {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}).mask(data['customer_service_calls'] >= 6, 'G')
# Создадим лонг-лист фичей features и список категориальных фичей cat_feat (исключен идентификатор phone_number и таргет churn)
features = data.columns.to_list()
features = [x for x in features if x not in ['phone_number','churn']]
cat_feat = data.select_dtypes(exclude = [np.number]).columns.to_list()
cat_feat = [x for x in cat_feat if x not in ['phone_number','churn']]
# Разделим выборку со стратификацией по churn на трейн(75%) и тест(25%) для проверки качества модели на независимой выборке:
x_train,x_test,y_train,y_test=train_test_split(data, data['churn'], stratify = data['churn'], test_size = 0.25, random_state=42)
x_train,x_val,y_train,y_val=train_test_split(x_train, y_train, stratify = y_train, test_size = 0.20, random_state = 42)
Далее на всех фичах из лонг–листа обучим Catboost небольшой глубины и c небольшим количеством деревьев. Нам это нужно для получения списка наболее значимых фичей.
Полученная модель на тестовой выборке показала коэффициент Джини 83,08%.
Значимость оценим с помощью get_feature_importance() в основе которого алгоритм PredictionValueChange, который показывает, насколько в среднем изменится прогноз модели при изменении значения одной из фичей.
При изменении значения фичи замена осуществляется на среднее значение фичи по выборке, а для категориальных признаков — на самый частый категориальный класс.
Выберем пять наиболее значимых фич. Именно на них мы построим новую модель и проанализируем графики для интерпретации.
Как видно из таблицы пятью самыми значимыми получились: 'total_charge', 'customer_service_calls', 'international_plan', 'total_intl_calls' и 'number_vmail_messages'.
Для информации выведем в отдельной таблице аналитику по каждой из них. Сделать это можно с помощью стандартной функции describe с дополнительной обработкой для добавления полезной информации.
Код на Python для проведения данных действий см. ниже:
def gini(y_true, y_pred): gini = 2 roc_auc_score(y_true, y_pred) - 1 return gini
gini_scorer = make_scorer(gini, greater_is_better = True)
params = { "verbose": False, "eval_metric": 'Logloss', 'iterations': 1000, 'random_state': 42, 'early_stopping_rounds': 10, 'max_depth': 4
}
model = CatBoostClassifier(*params)
model.fit(x_train[features], y_train, eval_set = (x_val[features], y_val), cat_features = cat_feat)
y_pred_train = model.predict_proba(x_train[features]).T[1]
y_pred_val = model.predict_proba(x_val[features]).T[1]
y_pred_test = model.predict_proba(x_test[features]).T[1]
gini_train = np.round(gini(y_train, y_pred_train),3)
gini_val = np.round(gini(y_val, y_pred_val),3)
gini_test = np.round(gini(y_test, y_pred_test),3)
print('Джини бустинга на всех фичах:', 'train', gini_train, 'val', gini_val, 'test', gini_test)
fe_stats = pd.DataFrame({'feature_importance': model.get_feature_importance(), 'feature_names':features}).sort_values(by=['feature_importance'], ascending=False)
display(fe_stats[0:5])
features = fe_stats[0:5].feature_names.to_list()
cat_feat = ['customer_service_calls']
target = ['churn']
data = data[features + target].copy()
data_info = data.describe(percentiles = [0.01, 0.05, 0.5, 0.95, 0.99], include=list(np.unique(data[data.columns].dtypes.astype('str').values))).T
data_info['type'] = data[data.columns].dtypes
data_info['null'] = data[data.columns].isnull().sum()
data_info['null%'] = np.round(data[data.columns].isnull().mean() * 100, 1)
data_info['nunique'] = data[data.columns].nunique()
data_info['count'] = data_info['count'].astype('int')
data_info = data_info.drop('unique', axis = 1)
data_info.sort_values(by = ['type','nunique'], ascending = [False, False], inplace = True)
data_info=data_info.loc[:,['type','count','nunique','null','null%','top','freq','min','max','mean','std','1%','5%','50%','95%','99%']]
data_info[['min','max','mean','std','1%','5%','50%','95%','99%']] = data_info[['min','max','mean','std','1%','5%','50%','95%','99%']].astype(float).round(4)
display(data_info)
Джини бустинга на всех фичах: train 0.934 val 0.859 test 0.831
feature_importance |
feature_names |
|
21 |
39.692171 |
total_charge |
18 |
17.874816 |
customer_service_calls |
3 |
14.006796 |
international_plan |
16 |
6.891904 |
total_intl_calls |
5 |
5.593090 |
number_vmail_messages |
type |
count |
nunique |
null |
null% |
top |
freq |
min |
max |
mean |
std |
1% |
5% |
50% |
95% |
99% |
|
customer_service_calls |
object |
3333 |
7 |
0 |
0.0 |
B |
1181 |
NaN |
NaN |
NaN |
NaN |
NaN |
NaN |
NaN |
NaN |
NaN |
total_charge |
float64 |
3333 |
2678 |
0 |
0.0 |
NaN |
NaN |
22.93 |
96.15 |
59.4498 |
10.5023 |
33.8532 |
42.338 |
59.47 |
76.516 |
83.8396 |
number_vmail_messages |
int64 |
3333 |
46 |
0 |
0.0 |
NaN |
NaN |
0.00 |
51.00 |
8.0990 |
13.6884 |
0.0000 |
0.000 |
0.00 |
36.000 |
43.0000 |
total_intl_calls |
int64 |
3333 |
21 |
0 |
0.0 |
NaN |
NaN |
0.00 |
20.00 |
4.4794 |
2.4612 |
1.0000 |
1.000 |
4.00 |
9.000 |
13.0000 |
international_plan |
int32 |
3333 |
2 |
0 |
0.0 |
NaN |
NaN |
0.00 |
1.00 |
0.0969 |
0.2959 |
0.0000 |
0.000 |
0.00 |
1.000 |
1.0000 |
churn |
int32 |
3333 |
2 |
0 |
0.0 |
NaN |
NaN |
0.00 |
1.00 |
0.1449 |
0.3521 |
0.0000 |
0.000 |
0.00 |
1.000 |
1.0000 |
Далее обучим catboost только на пяти самых значимых фичах. Полученная модель на тестовой выборке показала коэффициент Джини 82,5%.
params = {
"verbose": False,
'eval_metric': 'Logloss',
'iterations': 1000,
'early_stopping_rounds': 10,
'depth': 4, #по умолчанию 6
'random_state': 42,
}
model = CatBoostClassifier(**params)
model.fit(x_train[features], y_train, eval_set = (x_val[features], y_val), cat_features = cat_feat)
y_pred_train = model.predict_proba(x_train[features]).T[1]
y_pred_val = model.predict_proba(x_val[features]).T[1]
y_pred_test = model.predict_proba(x_test[features]).T[1]
gini_train = np.round(gini(y_train, y_pred_train),3)
gini_val = np.round(gini(y_val, y_pred_val),3)
gini_test = np.round(gini(y_test, y_pred_test),3)
print('Джини бустинга на всех фичах:', 'train', gini_train, 'val', gini_val, 'test', gini_test)
Джини бустинга на всех фичах: train 0.904 val 0.865 test 0.825.
Интерпретировать фичи модели очень удобно с помощью библиотеки SHAP. Детально про алгоритмы расчета Shap написано много, посмотреть можно в следующих статьях:
https://habr.com/ru/articles/428213/
https://habr.com/ru/companies/wunderfund/articles/739744/
https://habr.com/ru/companies/ods/articles/599573/
https://habr.com/ru/companies/otus/articles/465329/
https://www.kaggle.com/code/dansbecker/shap-values
Ссылка на официальную документацию по Shap:
Сутево SHAP-value отвечают на вопрос: "Насколько изменится предсказание модели для конкретного наблюдения по сравнению со средним прогнозом, если мы добавим данную фичу, учитывая все возможные комбинации / перестановки фичей".
В первую очередь рекомендую смотреть график Shap.summary_plot. На нем показываются фичи на уровне модели в целом.
Код для вывода графика следующий:
explainer = shap.TreeExplainer(model)
shap_values = explainer(x_train[features].iloc[:,:])
shap.summary_plot(shap_values, x_train[features], max_display = len(features))
Сортировка на графике идет сверху вниз от самой важной фичи до менее важных по среднему абсолютному значению Shap.values.
Но нам также интересна дополнительная информация в виде значений важности фичей, которая выводится с помощью отдельного графика:
shap.plots.bar(shap_values, max_display = len(features))
Поэтому чтобы не создавать много графиков мы можем улучшить shap.summary_plot, добавив непосредственно на него среднее абсолютное значение Shap.values. Код в данном случае будет выглядеть следующим образом:
# Инициализация explainer
explainer = shap.TreeExplainer(model)
# Вычисление SHAP значений
shap_values = explainer(x_train[features].iloc[:,:])
# Вычисление mean(|SHAP value|) для каждого признака
mean_abs_shap = np.mean(np.abs(shap_values.values), axis=0)
# Создание подписей с SHAP значениями
feature_labels = [f"{features[i]} ({mean_abs_shap[i]*100:.2f})"
for i in range(len(features))]
# Построение графика с модифицированными подписями
shap.summary_plot(shap_values, x_train[features], feature_names=feature_labels, # Используем названия фичей как подписи
max_display=len(features), show=False)
# Добавляем заголовок с пояснением
plt.title("SHAP Summary Plot: значения в скобках - mean(|SHAP value|)", y = 1.05)
plt.tight_layout()
plt.show()

Посмотрев на этот график, можно определить степень влияния и направление влияния по каждой фиче.
В качестве напоминания о том, как интерпретировать графики Shap:
1. Каждая линия на графике представляет собой фактор модели.
2. Каждая точка для определенного фактора — это отдельный прогноз в выборке.
3. Расположение по оси Х меньше или больше нуля показывает, увеличивает ли фича или уменьшает прогноз относительно среднего по выборке.
4. Также показывает значительность влияния — чем дальше от нуля, тем соответственно значительнее было влияние фичи на конкретное предсказание. Например, если наблюдение по определенной фиче имеет значение Shap равное +0,1%, это означает, что значение фичи для данного наблюдения приводит к увеличению значения Shap на эту величину.
5. Справа располагается шкала значений фичей. Точка красного цвета означает, что значение фичи очень высокое, синего цвета — низкое значение фичи. Если множество прогнозов дают похожий результат для данной фичи, то это приводит к тому, что линия становится намного шире (точки начинают накапливаться).
Анализируя ось x возникает логичный вопрос — а в каких единицах измерения выводится график shap.summary_plot?
В shap.TreeExplainer есть следующие параметры, установленные по умолчанию:
feature_perturbation='tree_path_dependent', model_output = 'raw'.
Для бинарной классификации Shap-значения при параметрах по умолчанию показывают, насколько каждый признак отклоняет предсказание модели от среднего значения в пространстве логарифмических шансов (log-odds).
Базовое значение (среднее): explainer.expected_value (средний log-odds по выборке).
Формула: , где
— вероятность положительного класса.
Но всегда интересно посмотреть график Shap.summary_plot именно в вероятностях. В таком случае есть два варианта. Первый — указать feature_perturbation="interventional", model_output="probability". Но данный вариант в некоторых версиях Shap может не сработать, поэтому можно пойти обходным путем через ручное преобразование log-odds в вероятности.
Код и график при переводе на вероятности см. ниже:
try:
# Способ 1: Прямое получение вероятностей (может не работать в некоторых версиях), работает медленнее
explainer = shap.TreeExplainer(model, feature_perturbation="interventional", model_output="probability")
shap_values = explainer(x_train[features].iloc[:,:])
except ValueError:
# Способ 2: Обходной путь через ручное преобразование
explainer = shap.TreeExplainer(model)
shap_values_raw = explainer(x_train[features].iloc[:,:])
expected_value = explainer.expected_value
shap_values = 1/(1+np.exp(-(expected_value + shap_values_raw.values))) - 1/(1+np.exp(-expected_value))
shap.summary_plot(shap_values, x_train[features], max_display = len(features))

Но Shap.summary_plot — очень верхнеуровневый график, на нем не всегда понятны детали работы фичи.
Например, по наиболее важной фиче total_charge видно, что с ростом значения фичи вероятность таргета также существенно растет, а вот что происходит при низких и средних значениях — из графика не вполне понятно. Поэтому многие фичи, особенно с нелинейной зависимостью лучше еще дополнительно посмотреть на отдельном графике shap.plots.scatter(shap_values[:,x])
В данный график можно добавить много полезной информации, и эту доработку удобнее сделать с помощью кастомного графика scatter plot. На нем детально видно изменение логики работы фичи в зависимости от ее значения.
В нашем конкретном случае по фиче total_charge видно, что сначала некоторые наблюдения даже увеличивают вероятность срабатывания таргета, потом происходит снижение, а далее резкий рост предсказаний.
Дополнительно можно отразить на графике:
1) Наблюдения соответствующие таргету (синие точки);
2) Динамика средних значений изменения предсказаний в зависимости от значений фичи;
3) Распределение значений фичи по выборке;
4) Среднее / медиана значений фичи.
# Ручное создание scatter plot с кастомными настройками
x = 'total_charge'
shap_values = explainer(x_train[features].iloc[:,:])
# Получаем данные
x_data = x_train[x]
y_data = shap_values[:, x].values
y_labels = y_train
# Создаем фигуру и основную ось
fig, ax1 = plt.subplots(figsize=(11, 6))
fig.patch.set_facecolor('white')
ax1.set_facecolor('white')
colors = np.where(y_train == 1, 'blue', 'red')
scatter = ax1.scatter(
x=x_train[x], y=shap_values[:, x].values, c=colors, cmap='coolwarm', alpha=1, edgecolors='w', linewidths=0.1)
#Добавляем линию средних значений
# Разбиваем на 20 бинов для расчета среднего
bin_means, bin_edges, = stats.binnedstatistic(x_data, y_data, statistic='mean', bins=20)
bin_centers = bin_edges[:-1] + np.diff(bin_edges)/2
# Рисуем линию средних
ax1.plot(bin_centers, bin_means, color='black', linewidth=2.2, linestyle='-.', label='Среднее SHAP')
# Настройки основной оси
ax1.set_ylabel('SHAP value', fontsize=14)
ax1.grid(True, alpha=0.3)
ax1.axvline(np.median(x_data), color='black', linestyle='-.', label='Медиана', linewidth=1)
#Добавляем вторую ось для распределения
ax2 = ax1.twinx() # Создаем вторую ось с общим X
ax2.hist(x_data, bins=50, color='skyblue', alpha=0.2, density=True)
ax2.set_ylabel('Плотность распределения', fontsize=14)
ax2.grid(False) # Отключаем сетку для второй оси
# Общие элементы
plt.title(f'SHAP значения и распределение', fontsize=14)
Text(0.5, 1.0, 'SHAP-значения и распределение')

Графики PDP (Partial Dependence Plot) и ICE (Individual Conditional Expectation) могут хорошо дополнять Shap.Summary_plot и Scatter_plot с точки зрения раскрытия бизнес-логики фичи. PDP отвечает на вопрос: как в среднем меняется прогноз модели, если зафиксировать определенное значение исследуемой фичи, усредняя влияние всех остальных. ICE отвечает на вопрос: «Как меняется прогноз по конкретному наблюдению при изменении значения фичи, оставляя остальные неизменными». Более детальную информацию про графики PDP и ICE можно посмотреть по следующей ссылке: https://scikit-learn.org/stable/modules/partial_dependence.html
Методика расчета PDP (усредняет все кривые, показывая общий тренд):
Шаг 1: фиксируем возможные значения фичи в выборке;
Шаг 2: для каждого значения фичи подставляем его во все наблюдения и пересчитываем прогноз по каждому наблюдению;
Шаг 3: далее усредняем значения оценок по всем наблюдениям каждого значения фичи;
Шаг 4: на данных усредненных значениях прогнозов строим график PDP.
Методика расчета ICE (показывает индивидуальные зависимости для каждого наблюдения):
Шаг 1: фиксируем диапазон значений для фичи;
Шаг 2: для каждого наблюдения подставляем все значения данного признака и пересчитываем прогноз по каждому наблюдению;
Шаг 3: выводим кривую для каждого наблюдения.
Каждая линия на графике — это отдельное наблюдение в выборке с посчитанным Shap-value в разрезе каждого значения исследуемой фичи. PDP для total_charge показывает, что с ростом значения фичи предсказание в среднем немного снижается. Но посмотрев на ICE мы видим, что для основной части клиентов предсказание не изменилось, а для небольшой доли наблюдений предсказание снизилось очень сильно. Именно за счет данных наблюдений сложилась такая ситуация с бизнес-логикой. Соответственно можно провести отдельный анализ и понять причины данных изменений.
Т.е. с помощью ICE можно находить аномальные наблюдения или клиентов, у которых зависимости резко отличаются от общих.
Недостаток ICE заключается в том, что на больших датасетах бывает трудно интерпретировать (много пересекающихся линий). Также ICE показывает логику работы фичи, но не показывает статистическую значимость эффекта.
У графика ICE есть два варианта визуализации:
При 'centered' = True из всех значений кривой вычитается предсказание при стартовом значении фичи, поэтому все кривые в точке старта графика формально начинаются с Y = 0. Этот вариант может быть полезным для упрощения анализа, т.к. все кривые выровнены по одной стартовой точке, что фильтрует индивидуальные смещения, акцентируя внимание именно на форме зависимости.
Если отключить нормализацию 'centered' = False, кривые будут начинаться с фактических значений предсказаний для каждой кривой.
Оба варианта графика представим ниже.
from sklearn.inspection import PartialDependenceDisplay
fig, ax = plt.subplots(figsize=(10, 4))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')
features_info = {"features": ['total_charge'], "kind": "both", "centered": False,}
common_params = {"subsample": 500, "n_jobs": 2, "grid_resolution": 30, "random_state": 42, 'method':'auto'} display = PartialDependenceDisplay.from_estimator(model, x_train[features].iloc[:,:], features_info, ax = ax, common_params)
= display.figure.suptitle("ICE and PDP total_charge", fontsize=12)

from sklearn.inspection import PartialDependenceDisplay
fig, ax = plt.subplots(figsize=(10, 4))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')
features_info = {"features": ['total_charge'], "kind": "both", "centered": True,}
common_params = {"subsample": 500, "n_jobs": 2, "grid_resolution": 30, "random_state": 42, 'method':'auto'} display = PartialDependenceDisplay.from_estimator(model, x_train[features].iloc[:,:], features_info, ax = ax, common_params)
= display.figure.suptitle("ICE and PDP total_charge (centered)", fontsize=12)

Графики, которые мы показывали, хорошо подходят под непрерывные значения фичи, но мало подходят под категориальные.
Как мы видели на Shap.summary_plot, категориальные фичи не были раскрашены, детально понять ситуацию было сложно.
График plotbar поможет детальнее раскрыть логику работы для категориальных фичей.
График plt.boxplot (ящик с усами) визуализирует основные описательные статистики распределения данных. На нем представлена следующая информация:
1. Ящик (Box)
Границы ящика: Нижняя граница (Q1) — первый квартиль (25-й процентиль), Верхняя граница (Q3) — третий квартиль (75-й процентиль).
Линия внутри ящика (оранжевая линия): Медиана (Q2, 50-й процентиль).
Ширина ящика показывает межквартильный размах (Interquartile range) (IQR = Q3 – Q1) — диапазон, где сосредоточено 50% данных.
2. Усы (Whiskers)
Верхний ус: Q3 + 1.5 IQR (максимальное значение, не считая выбросов). Нижний ус: Обычно Q1 – 1.5 IQR (минимальное значение, не считая выбросов). Усы могут дополнительно настраиваться (например, до 95% процентилей). Отдельно также добавим в каждый ящик линию средних значений (зеленый пунктир) и отдельно общую линию (красный пунктир), соединяющую средние значения между ящиками, чтобы наглядно видеть динамику. Ниже под график добавим гистограмму с распределением количества значений в каждом ящике.
x = 'customer_service_calls'
values = shap_values[:,features.index(x)]
data = x_train[x]
categories = sorted(list(set(x_train[x])))
groups = []
means = []
for category in categories: relevant_values = values.values[values.data == category] groups.append(relevant_values) means.append(np.mean(relevant_values))
labels = [u for u in categories] # Создаём сетку графиков: верх — boxplot, низ — гистограмма
fig, (ax_box, ax_hist) = plt.subplots(nrows=2, sharex=True, gridspec_kw={"height_ratios": (0.6, 0.4)}, figsize=(8, 6)) #plt.figure(figsize=(8, 5))
ax_box.boxplot(groups, labels = labels, showmeans=True, meanline=True)
ax_box.set_ylabel('Shap values', size=15)
ax_box.set_xlabel(x, size=15)
# Добавление линии, соединяющей средние
ax_box.plot(range(1, len(categories) + 1), means, marker='.', color='red', linestyle='--', linewidth=1, label='Средние')
ax_box.grid(True, linestyle='--', alpha=0.5) # Добавляем подписи
for i, category in enumerate(categories): ax_hist.text(i+1, plt.ylim()[0], f"n={len(groups[i])}", ha="center", va='bottom') # Гистограмма количества наблюдений на нижней панели
counts = [len(x) for x in groups]
ax_hist.bar(range(1, len(categories)+1), counts, color="skyblue", edgecolor="black")
ax_hist.set_ylabel("Количество")
plt.tight_layout()
plt.show()

Можно провести дополнительную аналитику и посчитать корреляцию между значениями фичи и shap-значениями. Корреляция значений признаков с shap показывает, как связаны исходные значения фичей с их вкладом в предсказание модели. Соответственно, при сильной положительной корреляции чем больше значение фичи, тем сильнее он увеличивает предсказание. При отрицательной корреляции, в свою очередь, чем больше значение фичи, тем сильнее он уменьшает предсказание. Если корреляция околонулевая, то нет линейной зависимости между значением фичи и его влиянием. Неожиданно низкие корреляции могут говорить о нелинейных зависимостях (например, U-образная кривая) или о сильных взаимодействиях с другими фичами. Для примера покажем график корреляции в нашем кейсе. Видим, что по фиче total_charge достаточно высокая корреляция, что говорит о линейной зависимости фичи, а по фиче total_intl_calls —корреляция очень низкая.
# Вычисление корреляций
features.remove('customer_service_calls')
correlations = {}
for i, feat in enumerate(features): correlations[feat] = np.corrcoef(x_train[feat], shap_values.values[:,i])[0,1]
# График
plt.figure(figsize=(8, 2))
sns.barplot(x = list(correlations.values()), y = list(correlations.keys()), palette = "vlag")
plt.axvline(0, color = 'black', linestyle = '--')
plt.title('Корреляция значений фичи с Shap')
plt.xlabel('Коэффициент корреляции Пирсона')
plt.xlim(-1, 1)
plt.show()
features = fe_stats[0:5].feature_names.to_list()

Чтобы определить взаимодействие total_intl_calls с другими фичами, посмотрим на нее более детально на графике Shap.dependence_plot(). Можно вывести характер взаимодействия total_intl_calls с другой интересующей нас фичей прямо указав ее в interaction_index. Если такую не указывать, то shap.dependence_plot по умолчанию выберет для раскраски точек на графике фичу с наибольшим взаимодействием. Чем сильнее Shap-значения основной фичи меняются в зависимости от значений другой фичи, тем выше взаимодействие. Формально это оценивается через дисперсию Shap-значений одной фичи, объяснённых другой фичей. Например, если при высоком значении international_plan Shap-значения total_intl_calls резко вырастут, а при низком снизятся,
то international_plan автоматически будет выбрана для раскраски точек (см. пример на графике).
# не указываем фичу для проверки взаимодействия, выводится фича с наибольшим взаимодействием - international_plan
shap.dependence_plot('total_intl_calls', shap_values.values, x_train[features].iloc[:,:])

# указываем фичу с которой хотим проверить взаимодействие - 'number_vmail_messages'
shap.dependence_plot('total_intl_calls', shap_values.values, x_train[features].iloc[:,:], interaction_index = 'number_vmail_messages')

Shap.dependence_plot позволяет уведить неочевидные зависимости, например, что наличие international_plan по разному влияет на прогноз по клиентам с небольшим и большим количеством звонков total_intl_calls. Более детально про Shap.dependence_plot можно прочитать в документации.
На этом завершаем. Надеюсь, что данная статья была полезной для вас.
Какие еще интересные возможности есть у графиков, которые были бы полезны для исследований, пишите в комментариях.