Привет, Хабр!
Эта статья для тех, кто только‑только погружается в машинное обучение и ещё не до конца понимает, что скрывается за интересным вызовом model.fit()
. Вы, возможно, уже настраивали ноутбуки, пробовали разные датасеты и, может, даже словили пару неожиданных ошибок — и это нормально.
Зачем копать глубже за fit()
На старте может казаться, что достаточно написать:
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
— и всё заработает. Но стоит проекту вырасти, можно столкнуться с подвохами:
Неожиданные
NotFittedError
приpredict()
Упавшая память на больших выборках
Странное поведение при дообучении
Сложности в интеграции конвейеров и трансформеров
Зная почему так происходит, можно оптимизировать время обучения, контролировать ошибки и выстроить гибкий пайплайн. Сначала разберём, как fit()
проверяет наши данные, а затем пройдёмся по остальным этапам.
Валидация данных
Сразу после вызова fit(X, y)
у модели запускается внутренняя проверка — validatedata
. И да, это не просто формальность:
Преобразование: если вы передали
pandas.DataFrame
, он конвертируется вnumpy.ndarray
.Сверка размеров: число строк в
X
должно совпадать с длинойy
.Обработка пропусков:
np.nan
и разреженные форматы распознаются и обрабатываются.Приведение типов: целочисленные и низкоточные данные автоматически приводятся к
float64
, чтобы алгоритмы могли нормально считать градиенты.
Представьте, вы случайно передали 100 строк признаков, а меток всего 99 — без этой проверки обучение просто рухнет где‑то в глубинах C‑библиотек с непонятным «segmentation fault». А так вы получите понятную ошибку и сможете сразу исправить проблему.
Куда же уходят ваши настройки и гиперпараметры?
Сохраняем гиперпараметры: BaseEstimator в действии
Все алгоритмы в scikit-learn
наследуют BaseEstimator
. При создании объекта, допустим:
model = LogisticRegression(C=0.1, penalty='l2')
— параметры C
и penalty
аккуратно ложатся в атрибут dict
. Благодаря этому:
Модель можно клонировать clone(model)
с теми же настройками. GridSearchCV
переберёт все комбинации гиперпараметров. При сериализации joblib.dump
можно быть уверенным, что вы не потеряете ни одного значения.
С настройками разобрались, дальше — где и как происходит само обучение.
Собственно обучение: что таится в _fit()
Метод fit()
передаёт дело приватному _fit()
, где и идёт основная математика.
-
Линейная регрессия решает нормальные уравнения:
X_aug = np.hstack([np.ones((n,1)), X]) # добавляем единичный столбец theta = np.linalg.solve(X_aug.T.dot(X_aug), X_aug.T.dot(y)) self.intercept_, self.coef_ = theta[0], theta[1:]
Решаем систему уравнений, только на миллионах строк.
-
Стохастический градиентный спуск (
SGDClassifier
):w = np.zeros(n_features) for epoch in range(max_iter): for Xi, yi in shuffle(X, y): grad = compute_gradient(w, Xi, yi) w -= eta * grad
Здесь важно правильно подобрать
learning_rate
(eta
): слишком большой — не схватится за минимум, слишком маленький — будет ползти вечность. Деревья решений рекурсивно разбивают выборку по лучшим признакам, чтобы минимизировать энтропию или MSE — об этом можно писать отдельный роман, но суть в том, что каждый сплит — это отдельное вычисление статистики.
Когда вычисления завершены, результаты куда‑то записываются — давайте взглянем, куда именно.
Куда деваются результаты — атрибуты с подчёркиванием
После обучения у модели появляются атрибуты, оканчивающиеся на _
:
coef_
,intercept_
у линейных моделейclasses_
у классификаторовfeature_importances_
у ансамблей деревьевстатистические буферы (
n_iter_
, история градиентов и пр.)
Чтобы убедиться, что модель обучена, я всегда использую:
from sklearn.utils.validation import check_is_fitted
check_is_fitted(model)
— и если что‑то упущено, получите дружелюбный NotFittedError
.
Различия между fit(), transform() и fit_transform()
fit(X, y)
: готовит модель к работе — валидация и вычисление внутренних параметров.transform(X)
: применяет уже посчитанные параметры для преобразования данных (нормализация, PCA и др.).fit_transform(X)
: объединение первых двух шагов для трансформеров, экономя одну итерацию по данным.
Например:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X) # за один проход
X_new = scaler.transform(X_test) # применяем те же параметры
Иногда модели умеют оптимизировать fit_transform()
, объединяя вычисления в один цикл.
Но что если объём данных огромен или они приходят потоком? Тогда без partial_fit()
не обойтись.
Онлайн-обучение и partial_fit()
Когда данные не помещаются в память или приходят постоянно, используем partial_fit()
:
from sklearn.linear_model import SGDClassifier
clf = SGDClassifier(max_iter=1, tol=None)
# Первый батч: нужно явно указать все классы
clf.partial_fit(X_batch1, y_batch1, classes=np.unique(y_full))
for Xb, yb in get_next_batches():
clf.partial_fit(Xb, yb)
Каждый вызов берёт новую порцию данных и продолжает обучение с текущих весов, не забывая историю градиентов.
А если хочется добавить новые деревья в лес без пересчёта старых? Тогда выручит warm_start
.
warm_start:
У многих ансамблей RandomForest
, GradientBoosting
есть опция warm_start
. Вот как она работает на практике:
-
Создаём лес из 50 деревьев:
rf = RandomForestClassifier(n_estimators=50, warm_start=True) rf.fit(X_train, y_train)
-
Хотим добавить ещё 50 деревьев:
rf.n_estimators = 100 rf.fit(X_train, y_train) # новые 50 деревьев дописываются к существующим
Это позволяет постепенно наращивать сложность модели, не теряя уже обученные структуры.
Часто мы объединяем разные шаги в единый конвейер — посмотрим, как Pipeline
и GridSearchCV
взаимодействуют с fit()
.
Pipeline и GridSearchCV
С Pipeline
мы объединяем несколько шагов:
from sklearn.pipeline import Pipeline
pipe = Pipeline([
('scaler', StandardScaler()),
('clf', LogisticRegression())
])
pipe.fit(X_train, y_train)
Последовательность действий:
scaler.fit_transform(X_train)
clf.fit(X_scaled, y_train)
Если этот пайплайн завернуть в GridSearchCV
, для каждого набора параметров он будет прогонять fit()
заново, сохранять результаты и выбирать наилучший вариант.
Параметры обучения и колбэки
В чистом scikit-learn
у fit()
самые распространённые доп. параметры — это sample_weight
и флаги для проверки входных данных. Но многие сторонние библиотеки (XGBoost, LightGBM) через знакомый интерфейс fit()
принимают:
early_stopping_rounds
— остановка по валидационной метрикеeval_set
— данные для валидацииverbose
— подробный лог обучения
А что делать, если модель может упаковаться в мультиядерный режим? Тут пригодится n_jobs
.
Отладка и профилирование fit()
Чтобы понять, куда уходит время и память, рекомендую:
-
cProfile:
python -m cProfile train.py
line_profiler и
@profile
для детального разбора функций-
memory_profiler:
from memory_profiler import profile @profile def train(): model.fit(X, y)
verbose у моделей для по‑шаговых логов
Что важно запомнить
В итоге — fit()
это последовательность: валидация, сохранение параметров, запуск fit()
и запись результатов. Для потоковых данных используйте partialfit()
, для дообучения ансамблей — warm_start
. Собирайте всё в Pipeline
, подключайте GridSearchCV
и не забывайте про логирование и профилирование.
Хотите освоить машинное обучение с нуля и стать уверенным Junior-специалистом? Онлайн-курс Otus «Machine Learning. Basic» — это ваш шанс получить знания от экспертов отрасли и прокачать навыки на реальных проектах.
За время обучения вы не только разберётесь в теории, но и будете решать реальные задачи, анализировать данные и строить свои первые модели. Прокачайте компетенции, которые востребованы на рынке, и начните карьеру в одном из самых перспективных направлений!