Привет, Хабр!

Если вы когда-либо имели дело с временными рядами, то, вероятно, слышали о Darts. А для тех, кто ещё в танке: Darts — это мощный инструмент, который поддерживает мультиварибельные временные ряды и легко интегрируется с PyTorch и TensorFlow.

Зачем же тестировать временные ряды, когда в классическом машинном обучении всё так просто с кросс-валидацией? Временные ряды обладают своей изюминкой: они подвержены временным зависимостям, сезонности, трендам и другим радостям жизни. Так что, если вы хотите, чтобы ваши модели не провалились на тестах, время разобраться с их особенностями!

Основы тестирования временных рядов

Перед тем как приступать к работе с данными, нужно убедиться, что они в правильном формате.

Формат данных

Первым делом нужно проверить временные метки на корректность. Важно, чтобы не было пропусков! С помощью pandas это делается легко:

import pandas as pd

# Загружаем данные
data = pd.read_csv('your_data.csv', parse_dates=['timestamp'], index_col='timestamp')

# Проверка на пропуски
missing_values = data.isnull().sum()
print(f"Пропуски в данных:\n{missing_values}")

# Проверка на равномерность временных меток
time_diff = data.index.to_series().diff()
print(f"Минимальный интервал между временными метками: {time_diff.min()}")
print(f"Максимальный интервал между временными метками: {time_diff.max()}")

Распределение данных

Важно убедиться, что данные равномерно распределены. Если выйдет так, что обнаружатся большие интервалы без наблюдений, может возникнуть необходимость в ресэмплинге или интерполяции. Помните, что данные должны быть в форме, удобной для анализа.

Перед тем как запустить свою модель, выполним несколько предобработок, которые могут существенно повлиять на точность прогнозов.

Скейлинг: Приведение данных к общему масштабу поможет вашему алгоритму быстрее схватывать закономерности. Используйте StandardScaler или MinMaxScaler из scikit-learn:

from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data.values.reshape(-1, 1))

Логарифмирование: Если ваши данные демонстрируют экспоненциальный рост, логарифмирование станет вашим лучшим другом:

import numpy as np

data_log = np.log(data + 1)  # Добавляем 1, чтобы избежать log(0)

Иммутация пропусков: Обработка пропущенных значений — ещё один важный шаг:

data_filled = data.interpolate(method='linear')

Метрики для оценки моделей

Когда дело доходит до оценки моделей временных рядов, стандартные метрики могут вас разочаровать. Временные ряды требуют специфических подходов:

RMSE: Это показатель отклонения предсказаний от реальных значений.

from sklearn.metrics import mean_squared_error

rmse = np.sqrt(mean_squared_error(y_true, y_pred))
print(f"RMSE: {rmse}")

MAE: Более устойчивый к выбросам, MAE дает ясное представление о точности модели.

mae = np.mean(np.abs(y_true - y_pred))
print(f"MAE: {mae}")

MASE: Эта метрика помогает сравнивать качество модели с наивным подходом.

mase = np.mean(np.abs(y_true - y_pred)) / np.mean(np.abs(y_true - np.roll(y_true, 1)[1:]))
print(f"MASE: {mase}")

Как протестировать данные и прогнозы с помощью Darts

Итак, начнем с загрузки данных. Допустим, есть данные о продажах, хранящиеся в CSV-файле.

import pandas as pd
from darts import TimeSeries

# Загружаем данные
data = pd.read_csv('sales_data.csv', parse_dates=['date'], index_col='date')

# Создаем временной ряд
series = TimeSeries.from_dataframe(data, 'date', 'sales')
print("Данные загружены и преобразованы в временной ряд:")
print(series)

Перед тем как двигаться дальше, убедимся, что с временным рядом всё в порядке. Проверим на наличие пропусков и аномалий:

# Проверка на пропуски
if series.isnull().any():
    print("Обнаружены пропуски в данных!")
else:
    print("Пропусков нет!")

# Визуализация данных
series.plot(title='График продаж', xlabel='Дата', ylabel='Количество продаж')

Теперь, когда мы уверены в целостности данных, проведём их предобработку.

Логарифмирование:

import numpy as np

# Логарифмирование
series_log = TimeSeries.from_dataframe(np.log1p(data.set_index('date')['sales']))
series_log.plot(title='Логарифмированные данные')

Скейлинг:

from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
scaled_data = scaler.fit_transform(data['sales'].values.reshape(-1, 1))
scaled_series = TimeSeries.from_dataframe(pd.DataFrame(scaled_data, index=data['date']))
scaled_series.plot(title='Масштабированные данные')

Теперь всё готово к построению модели. В Darts есть множество моделей, и мы будем использовать N‑BEATS.

from darts.models import NBEATSModel

# Определяем модель
model = NBEATSModel(input_chunk_length=30, output_chunk_length=10, n_epochs=100)

# Обучаем модель
model.fit(series_log)
print("Модель N-BEATS успешно обучена.")

Теперь протестируем модель на различных горизонтах прогнозирования:

# Прогноз на 10 шагов вперед
forecast_horizon = 10
predictions = model.predict(forecast_horizon)

# Визуализация прогноза
predictions.plot(label='Прогноз N-BEATS', title='Прогнозирование на 10 шагов вперед')
series_log.plot(label='Исторические данные')

После получения прогнозов важно оценить их точность с помощью метрик:

from darts.utils.statistics import mean_absolute_error, mean_squared_error

# Оценка точности
mae = mean_absolute_error(series_log[-forecast_horizon:], predictions)
rmse = np.sqrt(mean_squared_error(series_log[-forecast_horizon:], predictions))

print(f"MAE: {mae:.4f}")
print(f"RMSE: {rmse:.4f}")

Также стоит визуализировать остатки:

# Остатки
residuals = series_log[-forecast_horizon:] - predictions
residuals.plot(title='Остатки прогноза')

Если остатки показывают закономерности, это сигнализирует о проблемах в модели.

Как улучшить прогноз

Если N-BEATS не даёт желаемых результатов, можно попробовать другие подходы:

Изменение модели: Попробуйте, например, модель Prophet.

from darts.models import Prophet

# Обучаем модель Prophet
prophet_model = Prophet()
prophet_model.fit(series)

# Прогнозируем
prophet_predictions = prophet_model.predict(forecast_horizon)

# Визуализация
prophet_predictions.plot(label='Прогноз Prophet')
series_log.plot(label='Исторические данные')

Тонкая настройка гиперпараметров: Используйте кросс-валидацию для подбора гиперпараметров.

from darts.models import NBEATSModel
from sklearn.model_selection import GridSearchCV

# Определяем параметры для поиска
param_grid = {
    'input_chunk_length': [10, 30],
    'output_chunk_length': [5, 10],
    'n_epochs': [100, 200],
}

# Настраиваем GridSearchCV
grid_search = GridSearchCV(estimator=NBEATSModel(), param_grid=param_grid, scoring='neg_mean_absolute_error')
grid_search.fit(series_log)

print(f"Лучшие параметры: {grid_search.best_params_}")

Анализ различных горизонтов: Прогнозируйтена разных интервалах и проверяйте точность.

for horizon in [1, 5, 10]:
    pred = model.predict(horizon)
    error = mean_absolute_error(series_log[-horizon:], pred)
    print(f"MAE для горизонта {horizon}: {error:.4f}")

Вот и всё! Darts предоставляет мощные инструменты для анализа временных рядов, так что не бойтесь экспериментировать с разными моделями и гиперпараметрами. Удачи вам и пусть ваши прогнозы всегда будут точными!

Подробнее с библиотекой можно ознакомиться здесь.


А в ближайшие дни пройдут открытые уроки по ML и CV, которые можно посетить бесплатно:

  • 7 октября: «Word2Vec — классика векторных представлений слов для решения задач текстовой обработки». Узнать подробнее

  • 10 октября: «OpenCV: Как Начать Работать с Компьютерным Зрением». Узнать подробнее

Комментарии (3)


  1. alexxxdevelop
    05.10.2024 11:14

    1. Какое оборудование требуется для использования и обучения таких моделей?

    2. На каких основаниях модель делает прогнозирование? Анализирует как-то по своему или в обучение можно закладывать паттерны технического анализа?


    1. badcasedaily1 Автор
      05.10.2024 11:14

      Для небольших моделей подойдет обычный ноутбук i5, 8ГБ ОЗУ. Для больших данных — лучше использовать GPU.

      Модели строят прогнозы на основе исторических данных. Паттерны теханализа можно заложить через features


      1. alexxxdevelop
        05.10.2024 11:14

        Ни на один вопрос не ответили

        1. Это любой школьник знает. Я спросил конкретно ваши модели из статьи вы тестировали на каком оборудовании? Или вы просто скопировали код откуда-то?

        2. Что за features? Это где?