Привет, Хабр!

Эта статья для тех, кто только‑только погружается в машинное обучение и ещё не до конца понимает, что скрывается за интересным вызовом model.fit(). Вы, возможно, уже настраивали ноутбуки, пробовали разные датасеты и, может, даже словили пару неожиданных ошибок — и это нормально.


Зачем копать глубже за fit()

На старте может казаться, что достаточно написать:

model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)

— и всё заработает. Но стоит проекту вырасти, можно столкнуться с подвохами:

  • Неожиданные NotFittedError при predict()

  • Упавшая память на больших выборках

  • Странное поведение при дообучении

  • Сложности в интеграции конвейеров и трансформеров

Зная почему так происходит, можно оптимизировать время обучения, контролировать ошибки и выстроить гибкий пайплайн. Сначала разберём, как fit() проверяет наши данные, а затем пройдёмся по остальным этапам.

Валидация данных

Сразу после вызова fit(X, y) у модели запускается внутренняя проверка — validatedata. И да, это не просто формальность:

  • Преобразование: если вы передали pandas.DataFrame, он конвертируется в numpy.ndarray.

  • Сверка размеров: число строк в X должно совпадать с длиной y.

  • Обработка пропусков: np.nan и разреженные форматы распознаются и обрабатываются.

  • Приведение типов: целочисленные и низкоточные данные автоматически приводятся к float64, чтобы алгоритмы могли нормально считать градиенты.

Представьте, вы случайно передали 100 строк признаков, а меток всего 99 — без этой проверки обучение просто рухнет где‑то в глубинах C‑библиотек с непонятным «segmentation fault». А так вы получите понятную ошибку и сможете сразу исправить проблему.

Куда же уходят ваши настройки и гиперпараметры?

Сохраняем гиперпараметры: BaseEstimator в действии

Все алгоритмы в scikit-learn наследуют BaseEstimator. При создании объекта, допустим:

model = LogisticRegression(C=0.1, penalty='l2')

— параметры C и penalty аккуратно ложатся в атрибут dict. Благодаря этому:

Модель можно клонировать clone(model) с теми же настройками. GridSearchCV переберёт все комбинации гиперпараметров. При сериализации joblib.dump можно быть уверенным, что вы не потеряете ни одного значения.

С настройками разобрались, дальше — где и как происходит само обучение.

Собственно обучение: что таится в _fit()

Метод fit() передаёт дело приватному _fit(), где и идёт основная математика.

  • Линейная регрессия решает нормальные уравнения:

    X_aug = np.hstack([np.ones((n,1)), X])  # добавляем единичный столбец
    theta = np.linalg.solve(X_aug.T.dot(X_aug), X_aug.T.dot(y))
    self.intercept_, self.coef_ = theta[0], theta[1:]

    Решаем систему уравнений, только на миллионах строк.

  • Стохастический градиентный спуск (SGDClassifier):

    w = np.zeros(n_features)
    for epoch in range(max_iter):
        for Xi, yi in shuffle(X, y):
            grad = compute_gradient(w, Xi, yi)
            w -= eta * grad

    Здесь важно правильно подобрать learning_rate (eta): слишком большой — не схватится за минимум, слишком маленький — будет ползти вечность.

  • Деревья решений рекурсивно разбивают выборку по лучшим признакам, чтобы минимизировать энтропию или MSE — об этом можно писать отдельный роман, но суть в том, что каждый сплит — это отдельное вычисление статистики.

Когда вычисления завершены, результаты куда‑то записываются — давайте взглянем, куда именно.

Куда деваются результаты — атрибуты с подчёркиванием

После обучения у модели появляются атрибуты, оканчивающиеся на _:

  • coef_, intercept_ у линейных моделей

  • classes_ у классификаторов

  • feature_importances_ у ансамблей деревьев

  • статистические буферы (n_iter_, история градиентов и пр.)

Чтобы убедиться, что модель обучена, я всегда использую:

from sklearn.utils.validation import check_is_fitted
check_is_fitted(model)

— и если что‑то упущено, получите дружелюбный NotFittedError.

Различия между fit(), transform() и fit_transform()

  • fit(X, y): готовит модель к работе — валидация и вычисление внутренних параметров.

  • transform(X): применяет уже посчитанные параметры для преобразования данных (нормализация, PCA и др.).

  • fit_transform(X): объединение первых двух шагов для трансформеров, экономя одну итерацию по данным.

Например:

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)  # за один проход
X_new = scaler.transform(X_test)    # применяем те же параметры

Иногда модели умеют оптимизировать fit_transform(), объединяя вычисления в один цикл.

Но что если объём данных огромен или они приходят потоком? Тогда без partial_fit() не обойтись.

Онлайн-обучение и partial_fit()

Когда данные не помещаются в память или приходят постоянно, используем partial_fit():

from sklearn.linear_model import SGDClassifier
clf = SGDClassifier(max_iter=1, tol=None)
# Первый батч: нужно явно указать все классы
clf.partial_fit(X_batch1, y_batch1, classes=np.unique(y_full))
for Xb, yb in get_next_batches():
    clf.partial_fit(Xb, yb)

Каждый вызов берёт новую порцию данных и продолжает обучение с текущих весов, не забывая историю градиентов.

А если хочется добавить новые деревья в лес без пересчёта старых? Тогда выручит warm_start.

warm_start:

У многих ансамблей RandomForest, GradientBoosting есть опция warm_start. Вот как она работает на практике:

  1. Создаём лес из 50 деревьев:

    rf = RandomForestClassifier(n_estimators=50, warm_start=True)
    rf.fit(X_train, y_train)
  2. Хотим добавить ещё 50 деревьев:

    rf.n_estimators = 100
    rf.fit(X_train, y_train)  # новые 50 деревьев дописываются к существующим

Это позволяет постепенно наращивать сложность модели, не теряя уже обученные структуры.

Часто мы объединяем разные шаги в единый конвейер — посмотрим, как Pipeline и GridSearchCV взаимодействуют с fit().

Pipeline и GridSearchCV

С Pipeline мы объединяем несколько шагов:

from sklearn.pipeline import Pipeline

pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('clf', LogisticRegression())
])
pipe.fit(X_train, y_train)

Последовательность действий:

  1. scaler.fit_transform(X_train)

  2. clf.fit(X_scaled, y_train)

Если этот пайплайн завернуть в GridSearchCV, для каждого набора параметров он будет прогонять fit() заново, сохранять результаты и выбирать наилучший вариант.

Параметры обучения и колбэки

В чистом scikit-learn у fit() самые распространённые доп. параметры — это sample_weight и флаги для проверки входных данных. Но многие сторонние библиотеки (XGBoost, LightGBM) через знакомый интерфейс fit() принимают:

  • early_stopping_rounds — остановка по валидационной метрике

  • eval_set — данные для валидации

  • verbose — подробный лог обучения

А что делать, если модель может упаковаться в мультиядерный режим? Тут пригодится n_jobs.

Отладка и профилирование fit()

Чтобы понять, куда уходит время и память, рекомендую:

  • cProfile:

    python -m cProfile train.py
  • line_profiler и @profile для детального разбора функций

  • memory_profiler:

    from memory_profiler import profile
    @profile
    def train():
        model.fit(X, y)
  • verbose у моделей для по‑шаговых логов

Что важно запомнить

В итоге — fit() это последовательность: валидация, сохранение параметров, запуск fit() и запись результатов. Для потоковых данных используйте partialfit(), для дообучения ансамблей — warm_start. Собирайте всё в Pipeline, подключайте GridSearchCV и не забывайте про логирование и профилирование.


Хотите освоить машинное обучение с нуля и стать уверенным Junior-специалистом? Онлайн-курс Otus «Machine Learning. Basic» — это ваш шанс получить знания от экспертов отрасли и прокачать навыки на реальных проектах.

За время обучения вы не только разберётесь в теории, но и будете решать реальные задачи, анализировать данные и строить свои первые модели. Прокачайте компетенции, которые востребованы на рынке, и начните карьеру в одном из самых перспективных направлений!

Комментарии (0)