Привет, Хабр!
Эта статья — о том, как кастомизировать функции потерь в CatBoost.
Стандартные функции потерь хороши для типовых задач, но в нашей суровой жизни часто требуются специфичные решения. Например, может понадобиться усилить внимание модели на редких классах или минимизировать разные типы ошибок в зависимости от их влияния на бизнес..
Кастомные функции в CatBoost
CatBoost позволяет задавать свои функции потерь через класс с методом calc_ders_range
. Этот метод должен возвращать градиенты и Гессиан (вторые производные), которые алгоритм использует для обновления дерева.
Ручной MSE
Начнем с простого примера, где вручну. создадим MSE
для регрессии. Это поможет понять, как работает механизм кастомизации:
from catboost import CatBoostRegressor, Pool
class CustomMSE:
def calc_ders_range(self, approxes, targets, weights):
der1 = []
der2 = []
for approx, target in zip(approxes, targets):
error = approx - target
der1.append(error) # Градиент (разница между предсказанием и истиной)
der2.append(1.0) # Гессиан (в MSE фиксируем на 1)
return zip(der1, der2)
approxes
и targets
— это предсказанные и реальные значения.
Градиент ( в кодеder1
) — ошибка, которую модель должна минимизировать. В данном случае это разница между approx
и target
.
Гессиан (der2
) — второй порядок производной. Здесь он фиксирован как 1, что в целом допустимо для простых задач, как MSE.
Усиление ошибки на редких классах
Теперь пример с классификацией, где усилим ошибку на редких классах. Скажем, есть классы 0 и 1, и 1 встречается реже. Чтобы модель хоть как-то задумалась над ошибками на этом классе, добавим к ним вес.
class RareClassLoss:
def __init__(self, class_weight=10):
self.class_weight = class_weight
def calc_ders_range(self, approxes, targets, weights):
der1 = []
der2 = []
for approx, target in zip(approxes, targets):
diff = approx - target
weight = self.class_weight if target == 1 else 1 # Вес редкого класса
der1.append(diff * weight)
der2.append(weight)
return zip(der1, der2)
Вес ошибки для редкого класса: добавляем атрибут
class_weight
, который усиливает вклад ошибки, еслиtarget == 1
(редкий класс).Градиент и Гессиан: ошибка
diff
умножается на вес. Чем больше вес, тем сильнее модель будет обращать внимание на ошибки по редкому классу.
Где использовать этот подход? Например, в задачах, где важно не пропускать примеры с редким классом — например, в задачах классификации с дисбалансом классов.
Асимметричные ошибки — ложноположительные и ложноотрицательные
В задачах, где цена ложноположительной и ложноотрицательной ошибки различается, нужна асимметричная функция потерь. Добавим веса для разных типов ошибок.
class AsymmetricLoss:
def __init__(self, false_positive_weight=2, false_negative_weight=5):
self.false_positive_weight = false_positive_weight
self.false_negative_weight = false_negative_weight
def calc_ders_range(self, approxes, targets, weights):
der1 = []
der2 = []
for approx, target in zip(approxes, targets):
diff = approx - target
# Ложноположительная ошибка
if diff > 0:
der1.append(diff * self.false_positive_weight)
der2.append(self.false_positive_weight)
# Ложноотрицательная ошибка
else:
der1.append(diff * self.false_negative_weight)
der2.append(self.false_negative_weight)
return zip(der1, der2)
Параметры для ложноположительных и ложноотрицательных ошибок. используем разные веса, передавая их через false_positive_weight
и false_negative_weight
.
Условие на diff
: если ошибка положительная (модель ошиблась «вверх»), то применяем применяем вес false_positive_weight
, иначе — false_negative_weight
.
Итак, используем этот подход там, где одни типы ошибок обходятся дороже других, например, когда пропуск критичнее ложного срабатывания.
Как оптимизировать все это дело
Конечно же, кастомные функции могут снижать скорость обучения. Для этого можно:
Использовать
Numba
. Эта библиотека может ускорить Python-код, если функция сложная. Подробнее с ней можно ознакомиться здесь.Лишние условия замедляют выполнение. Заменяйте их на векторизированные операции, если это допустимо в вашем проекте.
Тестируйте кастомную функцию на небольшом наборе данных, чтобы убедиться, что она не тормозит и ведёт себя корректно.
Подробнее с библиотекой можно ознакомиться здесь.
А на странице курса Machine Learning от OTUS вы можете зарегистрироваться на бесплатный вебинар: «Обучение с учителем: разбираем задачу классификации».