Привет, Хабр! Сегодня рассмотрим невинный на первый взгляд параметр shuffle=True в train_test_split.

Под «перемешать» подразумевается применение псевдо-рандомного пермутационного алгоритма (обычно Fisher–Yates) к индексам выборки до того, как мы режем её на train/test. Цель — заставить train-и-test быть независимыми и одинаково распределёнными (i.i.d.). В scikit-learn эта логика зашита в параметр shuffle почти всех сплиттеров. В train_test_split он True по умолчанию, что прямо сказано в документации — «shuffle bool, default=True».

train_test_split

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    features,
    target,
    test_size=0.2,
    random_state=42,  # надо для репликабельности
    shuffle=True      # смело тасуем
)

Когда shuffle=True, функция:

  1. Генерирует случайную перестановку индексов (учитывая random_state).

  2. По ней делит данные.

Если присвоить shuffle=False, она просто берет “с головы” train_size строк, а “хвост” — test. Расклад очевиден из исходников и подтверждён в официальном доке.

Когда shuffle=False — обязательное условие

Временные ряды

Time-series — классика, где порядок — закон. Если мы перемешаем, то модель увидит будущее раньше прошлого и станет гадалкой. С точки зрения статистики это “look-ahead bias”. На том же Cross-Validated прямым текстом: в time-series нужно держать хронологию и юзать TimeSeriesSplit.

Зависимости внутри групп

Клинические данные, где несколько записей на одного пациента; логи пользователей, где одна сессия раскидана на десятки строк. Если рандомно раскидать строки по split’ам, то в test попадут “следы” тех же юзеров, что и в train, а это утечка через id-коррелированные признаки.

Продуктовые AB-эксперименты и всё, где важен session-level split

Тут групповая целостность — must-have. Мы шейкаем между группами, но не внутри.

Где случайность вредит обучению

  • Look-ahead bias — когда модель учится на будущей информации.

  • Target leakage — признак сформирован на основе целевой переменной или будущих значений.

  • Temporal leakage — метки пакуются по календарю: например, is_holiday. Если их перетасовать, тест узнает праздники раньше времени.

leakage — в целом сам по себе самый популярный баг ML-систем. Утечка часто выглядит невинно: добавили total_sales_next_month как фичу для модели, предсказывающей спрос — и получили 99 % R².

Как делать GroupShuffle или TimeSeriesSplit

GroupShuffleSplit

from sklearn.model_selection import GroupShuffleSplit

gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(gss.split(X, y, groups=user_id))

X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

GroupShuffleSplit гарантирует, что у каждого user_id ровно один сплит: либо train, либо test. Под капотом он рандомно тасует сами группы, а не записи.

TimeSeriesSplit

from sklearn.model_selection import TimeSeriesSplit

tscv = TimeSeriesSplit(n_splits=5, test_size=24*7)  # неделя в часах
for fold, (train_idx, test_idx) in enumerate(tscv.split(X)):
    model.fit(X.iloc[train_idx], y.iloc[train_idx])
    y_pred = model.predict(X.iloc[test_idx])
    ...

Сплиттер отдаёт растающий train и скользящее окно test. Шейка нет вовсе — порядок священен.

Кейс на примере магазина котиков

У нас зоомаркет Purrfect Shop. В базе лежат четыре ключевых таблицы:

Таблица

Что внутри

Гранулярность

customers

customer_id, демография, дата регистрации

1 строка на владельца

cats

cat_id, порода, цена, дата поступления

1 строка на котика

orders

order_id, customer_id, order_dt, total_sum

1 строка на чек

order_lines

order_id, cat_id, qty

1 строка на позицию в чеке

Мы хотим решить две задачи:

  1. Churn-классификация: предсказать, уйдёт ли клиент в течение 30 дней.

  2. Прогноз оборота на следующие 7 дней из тайм-серии.

Плюс обучаем CNN, которая по фото угадывает породу для автозаполнения карточек.

Churn-модель: где shuffle обязателен и где нельзя

Наивный, но опасный подход:

from sklearn.model_selection import train_test_split

X = features_df  # собрали признаки на уровне *клиента*
y = labels_df['will_churn']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, shuffle=True, random_state=42
)

Каждая строка — клиент, а значит зависимостей во времени нет. Shuffle здесь уместен: классы «уйдет/останется» раскиданы равномерно — модель не видит паттерна “первые 75 % клиентов — новые, последние — старые”.

Мы решаем “чуть улучшить” датасет и переходим на строки-чек. Один клиент = десятки чеков:

orders_df['will_churn'] = ...
X = orders_df.drop('will_churn', axis=1)
y = orders_df['will_churn']

# те же 5 строк кода:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, shuffle=True, random_state=42
)

Теперь половина чеков Петя-Котолюб попала в train, половина — в test.

Утечка: признаки, вроде «total_sum_last_3_orders», пересекаются. Результат — AUC = 0.97 на тесте, но в проде падаем до 0.68 и ловим.

Правильно: GroupShuffleSplit

from sklearn.model_selection import GroupShuffleSplit

gss = GroupShuffleSplit(test_size=0.25, n_splits=1, random_state=42)
train_idx, test_idx = next(gss.split(X, y, groups=orders_df['customer_id']))

X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

Теперь каждый клиент живёт только в одном сплите. AUC честно падает до 0.79 — зато в проде всё стабильно.

Прогноз оборота: shuffle=False, иначе бабах

Как ломается time series:

from sklearn.model_selection import train_test_split

# aggregated_df: daily revenue, lag-features, holidays, etc.
X = aggregated_df.drop('revenue', axis=1)
y = aggregated_df['revenue']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=True, random_state=42
)

Модель случайно видит 2025-06-01 в train, а 2025-05-20 в test.

Используем TimeSeriesSplit:

from sklearn.model_selection import TimeSeriesSplit
import lightgbm as lgb
import numpy as np

tscv = TimeSeriesSplit(n_splits=5, test_size=7)
scores = []

for fold, (tr, val) in enumerate(tscv.split(X)):
    model = lgb.LGBMRegressor(n_estimators=500, learning_rate=0.03)
    model.fit(X.iloc[tr], y.iloc[tr])
    preds = model.predict(X.iloc[val])
    rmse = np.sqrt(((preds - y.iloc[val])**2).mean())
    scores.append(rmse)
    print(f'Fold {fold}: RMSE={rmse:.2f}')

print(f'Mean CV RMSE: {np.mean(scores):.2f}')

Без shuffle, с растущим окном.

CNN для фото-котиков

Когда обучаем сверточку, порядок картинок не важен; наоборот, shuffle помогает стохастическому градиенту быстрее и стабильнее сходиться.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

train_ds = datasets.ImageFolder(
    root='cats/train',
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
)

train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,          # критично!
    num_workers=4,
    pin_memory=True
)
val_loader = DataLoader(
    val_ds,
    batch_size=32,
    shuffle=False,        # чтоб метрики не «плясали»
    num_workers=4
)

На примере нашего котомаркета видим три сценария:

  • Табличные i.i.d. данные — мешаем, чтобы избежать систематической ошибки.

  • Группы / время — бережём порядок, потому что утечка дороже.

  • Vision / NLP — мешаем внутри эпохи, но держим валидацию детерминированной.

Если сомневаетесь — не шейкайте.

Шейкать или не шейкать

Тип данных / задача

Шейкать?

Сплиттер

Комментарий

Классический табличный ML

Да

train_test_split

Базовая практика, избегаем order-bias.

Изображения/тексты без явных групп

Да

train_test_split + stratify

Сохраняем баланс классов.

Логи пользователей, несколько строк на ID

Нет

GroupShuffleSplit

Целостность группы важнее.

Временные ряды (прогноз спроса, финансы)

Нет

TimeSeriesSplit

Хронология важна.

Подготовка hold-out для AB-теста

Нет

GroupShuffleSplit

Сессии не должны пересекаться.

Kaggle с фейк-ID, но явной утечки нет

Скорее да

StratifiedKFold

Читайте описание соревнования.

Tiny-датасет ≤ 100 строк

Да, но фиксируйте seed

Любой

Варьируйте до 10-кратного CV для стабильности.


Вывод

Если чувствуете запах временной или групповой зависимости — уберите руку от shuffle=True и достаньте правильный сплиттер.


Готовите данные для моделей машинного обучения? Тогда знаете: неправильный сплит — и модель учит будущее, «подглядывает» в тест и в итоге проваливается в проде.

Если вам близки такие темы, как предотвращение утечек, GroupShuffle, TimeSeriesSplit, честные A/B‑тесты и грамотная работа с временными рядами — в Otus пройдут скоро открытые уроки, которые рекомендуем посетить:

Хотите больше? Загляните в каталог курсов — там есть всё: от ML‑специализации до продвинутого Python.

А чтобы ничего не пропустить, добавьте календарь открытых уроков — пусть он напомнит вам, когда стоит подключиться к трансляции.

Комментарии (0)