Привет, Хабр!
Эта статья для тех, кто только‑только погружается в машинное обучение и ещё не до конца понимает, что скрывается за интересным вызовом model.fit(). Вы, возможно, уже настраивали ноутбуки, пробовали разные датасеты и, может, даже словили пару неожиданных ошибок — и это нормально.
Зачем копать глубже за fit()
На старте может казаться, что достаточно написать:
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
— и всё заработает. Но стоит проекту вырасти, можно столкнуться с подвохами:
Неожиданные
NotFittedErrorприpredict()Упавшая память на больших выборках
Странное поведение при дообучении
Сложности в интеграции конвейеров и трансформеров
Зная почему так происходит, можно оптимизировать время обучения, контролировать ошибки и выстроить гибкий пайплайн. Сначала разберём, как fit() проверяет наши данные, а затем пройдёмся по остальным этапам.
Валидация данных
Сразу после вызова fit(X, y) у модели запускается внутренняя проверка — validatedata. И да, это не просто формальность:
Преобразование: если вы передали
pandas.DataFrame, он конвертируется вnumpy.ndarray.Сверка размеров: число строк в
Xдолжно совпадать с длинойy.Обработка пропусков:
np.nanи разреженные форматы распознаются и обрабатываются.Приведение типов: целочисленные и низкоточные данные автоматически приводятся к
float64, чтобы алгоритмы могли нормально считать градиенты.
Представьте, вы случайно передали 100 строк признаков, а меток всего 99 — без этой проверки обучение просто рухнет где‑то в глубинах C‑библиотек с непонятным «segmentation fault». А так вы получите понятную ошибку и сможете сразу исправить проблему.
Куда же уходят ваши настройки и гиперпараметры?
Сохраняем гиперпараметры: BaseEstimator в действии
Все алгоритмы в scikit-learn наследуют BaseEstimator. При создании объекта, допустим:
model = LogisticRegression(C=0.1, penalty='l2')
— параметры C и penalty аккуратно ложатся в атрибут dict. Благодаря этому:
Модель можно клонировать clone(model) с теми же настройками. GridSearchCV переберёт все комбинации гиперпараметров. При сериализации joblib.dump можно быть уверенным, что вы не потеряете ни одного значения.
С настройками разобрались, дальше — где и как происходит само обучение.
Собственно обучение: что таится в _fit()
Метод fit() передаёт дело приватному _fit(), где и идёт основная математика.
-
Линейная регрессия решает нормальные уравнения:
X_aug = np.hstack([np.ones((n,1)), X]) # добавляем единичный столбец theta = np.linalg.solve(X_aug.T.dot(X_aug), X_aug.T.dot(y)) self.intercept_, self.coef_ = theta[0], theta[1:]Решаем систему уравнений, только на миллионах строк.
-
Стохастический градиентный спуск (
SGDClassifier):w = np.zeros(n_features) for epoch in range(max_iter): for Xi, yi in shuffle(X, y): grad = compute_gradient(w, Xi, yi) w -= eta * gradЗдесь важно правильно подобрать
learning_rate(eta): слишком большой — не схватится за минимум, слишком маленький — будет ползти вечность. Деревья решений рекурсивно разбивают выборку по лучшим признакам, чтобы минимизировать энтропию или MSE — об этом можно писать отдельный роман, но суть в том, что каждый сплит — это отдельное вычисление статистики.
Когда вычисления завершены, результаты куда‑то записываются — давайте взглянем, куда именно.
Куда деваются результаты — атрибуты с подчёркиванием
После обучения у модели появляются атрибуты, оканчивающиеся на _:
coef_,intercept_у линейных моделейclasses_у классификаторовfeature_importances_у ансамблей деревьевстатистические буферы (
n_iter_, история градиентов и пр.)
Чтобы убедиться, что модель обучена, я всегда использую:
from sklearn.utils.validation import check_is_fitted
check_is_fitted(model)
— и если что‑то упущено, получите дружелюбный NotFittedError.
Различия между fit(), transform() и fit_transform()
fit(X, y): готовит модель к работе — валидация и вычисление внутренних параметров.transform(X): применяет уже посчитанные параметры для преобразования данных (нормализация, PCA и др.).fit_transform(X): объединение первых двух шагов для трансформеров, экономя одну итерацию по данным.
Например:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X) # за один проход
X_new = scaler.transform(X_test) # применяем те же параметры
Иногда модели умеют оптимизировать fit_transform(), объединяя вычисления в один цикл.
Но что если объём данных огромен или они приходят потоком? Тогда без partial_fit() не обойтись.
Онлайн-обучение и partial_fit()
Когда данные не помещаются в память или приходят постоянно, используем partial_fit():
from sklearn.linear_model import SGDClassifier
clf = SGDClassifier(max_iter=1, tol=None)
# Первый батч: нужно явно указать все классы
clf.partial_fit(X_batch1, y_batch1, classes=np.unique(y_full))
for Xb, yb in get_next_batches():
clf.partial_fit(Xb, yb)
Каждый вызов берёт новую порцию данных и продолжает обучение с текущих весов, не забывая историю градиентов.
А если хочется добавить новые деревья в лес без пересчёта старых? Тогда выручит warm_start.
warm_start:
У многих ансамблей RandomForest, GradientBoosting есть опция warm_start. Вот как она работает на практике:
-
Создаём лес из 50 деревьев:
rf = RandomForestClassifier(n_estimators=50, warm_start=True) rf.fit(X_train, y_train) -
Хотим добавить ещё 50 деревьев:
rf.n_estimators = 100 rf.fit(X_train, y_train) # новые 50 деревьев дописываются к существующим
Это позволяет постепенно наращивать сложность модели, не теряя уже обученные структуры.
Часто мы объединяем разные шаги в единый конвейер — посмотрим, как Pipeline и GridSearchCV взаимодействуют с fit().
Pipeline и GridSearchCV
С Pipeline мы объединяем несколько шагов:
from sklearn.pipeline import Pipeline
pipe = Pipeline([
('scaler', StandardScaler()),
('clf', LogisticRegression())
])
pipe.fit(X_train, y_train)
Последовательность действий:
scaler.fit_transform(X_train)clf.fit(X_scaled, y_train)
Если этот пайплайн завернуть в GridSearchCV, для каждого набора параметров он будет прогонять fit() заново, сохранять результаты и выбирать наилучший вариант.
Параметры обучения и колбэки
В чистом scikit-learn у fit() самые распространённые доп. параметры — это sample_weight и флаги для проверки входных данных. Но многие сторонние библиотеки (XGBoost, LightGBM) через знакомый интерфейс fit() принимают:
early_stopping_rounds— остановка по валидационной метрикеeval_set— данные для валидацииverbose— подробный лог обучения
А что делать, если модель может упаковаться в мультиядерный режим? Тут пригодится n_jobs.
Отладка и профилирование fit()
Чтобы понять, куда уходит время и память, рекомендую:
-
cProfile:
python -m cProfile train.py line_profiler и
@profileдля детального разбора функций-
memory_profiler:
from memory_profiler import profile @profile def train(): model.fit(X, y) verbose у моделей для по‑шаговых логов
Что важно запомнить
В итоге — fit() это последовательность: валидация, сохранение параметров, запуск fit() и запись результатов. Для потоковых данных используйте partialfit(), для дообучения ансамблей — warm_start. Собирайте всё в Pipeline, подключайте GridSearchCV и не забывайте про логирование и профилирование.
Хотите освоить машинное обучение с нуля и стать уверенным Junior-специалистом? Онлайн-курс Otus «Machine Learning. Basic» — это ваш шанс получить знания от экспертов отрасли и прокачать навыки на реальных проектах.
За время обучения вы не только разберётесь в теории, но и будете решать реальные задачи, анализировать данные и строить свои первые модели. Прокачайте компетенции, которые востребованы на рынке, и начните карьеру в одном из самых перспективных направлений!