Привет, Хабр!

Сегодня разберём RecBole — универсальный фреймворк на PyTorch, который отвечает на три насущных вопроса любого ML-инженера рекомендаций:

  • Как быстро обкатать десятки алгоритмов (от классического MF до SASRec и KGAT) на собственном датасете — без сотни скриптов?

  • Как хранить все настройки в одном YAML, а не в трёх сотнях аргументов CLI?

  • Как получить честное сравнение метрик и сразу вынести лучший чекпоинт в прод?

Рассмотрим подробнее под катом.

Установка и подготовка данных

pip install recbole>=1.2
python -m recbole.quick_start.run_recbole --model=BPR --dataset=ml-1m

У RecBole есть встроенная заготовка датасетов (ml-1m, yelp, amazon-*). Свой датасет кидаем в папку dataset/<name>/ в формате Atomic Files:

файл

обязательные поля

комментарий

<name>.inter

user_id, item_id, rating, timestamp

минимум две первые колонки

<name>.item

item_id, genre, year, …

любые side-фичи

<name>.user

user_id, age, city, …

optional

Parquet читается быстрее, но RecBole «проглатывает» и CSV.

Автоматический сплит

В recbole.yaml достаточно:

split_ratio: [0.8, 0.1, 0.1]   # train/valid/test
group_by_user: True            # чтобы у каждого юзера были все статусы

Всё, никаких ручных датафреймов на pandas.

Разбираемся с API

run_recbole

from recbole.quick_start import run_recbole

run_recbole(
    model='LightGCN',          # любая из 90+ моделей
    dataset='ml-1m',           # или путь к своему набору
    config_dict={              # приоритет над YAML и CLI
        'epochs': 50,
        'topk': 10,
        'neg_sampling': {'uniform': 1},
        'seed': 42,            # чтобы метрики не «плавали»
    }
)

Что происходит под капотом

Шаг

Вызов

Под капотом

Конфигурация

Config

собирает всё из recbole.yaml, аргументов CLI и config_dict, давая приоритет последнему. Можно вызвать config.save() и получить итоговый YAML для репродюса.

Дата

create_dataset

читает Atomic Files, авто-инференсит типы полей (int/float/token/sequence), пишет мета-JSON в processed/*.json.

Семплеры

create_sampler

строит Sampler (point-wise, pair-wise, full-sort). Хотите динамический негатив — передайте neg_sampling.dynamic: 1 и получите новый семплер без правки кода.

Лоадеры

create_dataloader

лениво подгружает батчи; для огромных данных ставьте lazy_loading: True, чтобы не держать всё в памяти.

Модель

Model

вытягивается рефлексией из recbole.model. Хотите кастом — наследуйтесь от BaseModel, регайте через register_model.

Тренер

Trainer

инициализирует оптимизатор/скедьюлер, early-stopping, логгер. Для knowledge distillation есть KnowledgeDistillationTrainer.

Эвалар

Evaluator

считает HR@K, NDCG@K, MRR, MAP; full_sort_topk ранжирует весь каталог, а не sampled-негативы.

Вывод

~

сохраняет лучший чек-пойнт + лог в /saved/LightGCN-<timestamp>/

Хотите логировать в W&B — добавьте wandb: True в YAML. Нужен mixed-precision — train_stage: fp16. Гиперпараметры через CLI: python run_recbole.py --learning_rate=5e-4 --dropout_prob=0.3.

Гранулярный контроль

Иногда однострочник — роскошь, и нужен доступ к каждому объекту. Тогда:

from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.utils import init_seed
from recbole.model.general_recommender import LightGCN
from recbole.trainer import Trainer

# 1. Конфиг из файла + CLI
config = Config(model='LightGCN', dataset='ml-1m')      # читает recbole.yaml
config['epochs'] = 30                                   # оверрайд «на лету»

# 2. Dataset
init_seed(config['seed'])
dataset = create_dataset(config)                        # <RecDataset 1 1000209>

# 3. Sampler / Dataloader
train_data, valid_data, test_data = data_preparation(config, dataset)

# 4. Модель
model = LightGCN(config, dataset).to(config['device'])

# 5. Тренер
trainer = Trainer(config, model)
best_valid_score, best_valid_result = trainer.fit(
    train_data, valid_data, saved=True, show_progress=True)

score, result = trainer.evaluate(test_data, load_best_model=True)
print(result)     # {'Recall@10': 0.1627, 'NDCG@10': 0.0894, ...}

Config

# recbole.yaml (кусочек)
MODEL_TYPE: Sequential     # автоматически подскажет, что у модели есть max_seq_length
epochs: 40
neg_sampling:
  dynamic: 1
eval_args:
  mode: full                # full-sort evaluation
  order: RO                 # рейтинг -> онлайн
  split: {'RS': [0.8,0.1,0.1]}
checkpoint_dir: ./saved/
wandb: True

Переопределение при импорте:

cfg = Config(model='SASRec', dataset='ml-1m',
             config_dict={'epochs': 10, 'dropout_prob': 0.2})

Доступ к параметрам — по ключу: cfg['topk'], cfg.final_config_dict — готовый словарь для логирования.

Dataset и друзья

from recbole.data import Dataset
dset = Dataset(config)      # наследник torch.utils.data.Dataset
len(dset.field2type)        # {'user_id': 'token', 'item_id': 'token', ...}
  • Custom поля — добавьте колонку в .inter и опишите тип в YAML:

    FIELD_TYPES: {'price': float, 'brand': token}
  • Sequence -> unrolled. Для последовательных моделей (SASRec, GRU4Rec) RecBole сам создаёт hist_seq и target_item.

  • Lazy loading для >10 GB дат:

    lazy_loading: True

Самплеры и “кормушки”

from recbole.data import (
    create_samplers, create_dataloader, data_preparation
)

samplers = create_samplers(config, dataset)       # TrainSampler / FullSortSampler
train_loader, valid_loader, test_loader = create_dataloader(
    config, dataset, samplers)
  • Популярный негатив:

    neg_sampling:
      popularity: 1

    Внутри PopularitySampler — item-frequency softmax.

  • Dynamic Sampler считает свежие негативы каждую эпоху, спасая от информации-leakage.

Пишем свою модель

from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.loss import BPRLoss
import torch.nn as nn
import torch

class MyDotMF(GeneralRecommender):
    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        self.embedding_size = config['embedding_size']
        self.user_embedding = nn.Embedding(
            dataset.num(self.USER_ID), self.embedding_size)
        self.item_embedding = nn.Embedding(
            dataset.num(self.ITEM_ID), self.embedding_size)
        self.loss_fct = BPRLoss()

    def forward(self, interaction):
        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        user_e = self.user_embedding(user)
        item_e = self.item_embedding(pos_item)
        scores = (user_e * item_e).sum(-1)
        return scores

    def calculate_loss(self, interaction):
        pos_score = self.forward(interaction)
        neg_items = interaction[self.NEG_ITEM_ID]
        neg_e = self.item_embedding(neg_items)
        neg_score = (user_e.unsqueeze(1) * neg_e).sum(-1)
        return self.loss_fct(pos_score, neg_score)

Регистрируем:

from recbole.utils import register_model
register_model('MyDotMF', MyDotMF)

Теперь в YAML достаточно model: MyDotMF.

Минимальный кейс:

mkdir -p dataset/shop
python - <<'PY'
import pandas as pd, pyarrow.parquet as pq
df = pq.read_table('orders.parquet').to_pandas()
df[['user_id','sku','ts']].to_csv(
    'dataset/shop/shop.inter', sep='\t', index=False)
PY

recbole.yaml:

field_separator: "\t"
USER_ID_FIELD: user_id
ITEM_ID_FIELD: sku
TIME_FIELD: ts

model: SASRec
epochs: 20
learning_rate: 1e-3
neg_sampling: ~
LABEL_FIELD: click
topk: 20
metrics: ['Recall', 'NDCG', 'MRR']
device: cuda

Запуск:

from recbole.quick_start import run_recbole
run_recbole(dataset='shop')

RecBole сам сделает сплит, залиогирует Recall@20, сохранит чек-пойнт и итоговый YAML в saved/SASRec-shop-<ts>/.

Фичи

Правильный neg_sampling — бесплатный буст к NDCG

neg_sampling:
  uniform: 1
# или
  popularity: 1
# или
  dynamic: 1            # поддерживается с v1.2

dynamic может давать +5 % NDCG@10 vs uniform.

Knowledge-Graph модели

Если берёте KGAT/CFKG/TransRec, добавляйте файл графа:

knowledge_graph_file: shop.kg

Формат тривиальный: head relation tail. RecBole сам построит adjacency matrix.

GPU-OOM ловушка

Параметр train_batch_size умноженный на количество GPU → ваша фактическая матрица эмбеддингов. Когда загоняете SASRec на A100 40 GB, не забывайте, что скрытая матрица self-attention растёт квадратично от max_seq_length.

train_batch_size: 512     # ок
max_seq_length: 200       # ок
n_layers: 4               # ок

Уехали в 1024×512×6 — здравствуй, CUDA OOM.

Экспорт в прод

torch.save(model.state_dict(), 'lightgcn.pt')
# inference
model = LightGCN(config, dataset)
model.load_state_dict(torch.load('lightgcn.pt', map_location='cpu'))
model.eval()

Никаких RecBole-зависимостей в рантайме: чистый PyTorch внутри Docker.

Итоги

RecBole закрывает 80 % типовых задач ресёрча и «ML-прототипов» в одном пакете: вам остаётся только решать, какую модель кормить продакшену. Да, бывают кейсы, где нужен Sparkили multi-tower архитектура под рекламу – тогда пляшем руками. Но для большинства продуктовых рекомендателей «поднять бейзлайн» быстрее RecBole сегодня мало что умеет.


Если вы работаете с рекомендательными системами или только собираетесь внедрять их в продукт, обязательно загляните в RecBole — мощный фреймворк на PyTorch, который закрывает до 80% задач ресёрча и ML‑прототипирования «из коробки». Поддержка 90+ моделей, единый YAML для всей конфигурации, автоматическая обработка данных, гибкий negative sampling и честные метрики — всё это помогает не тратить время на рутину и быстрее выходить в прод.

Чтобы разобраться в возможностях RecBole и не тратить недели на документацию, присоединяйтесь к нашему циклу открытых уроков:

Каждый урок — это практическое погружение: от запуска бейзлайна на своём датасете до кастомизации модели и экспорта в inference. Присоединяйтесь — и проверьте на практике, насколько RecBole может упростить вашу работу.

Комментарии (0)