Что важнее: создать продукт, или доставить его до пользователя? Оба этапа необходимы. Сегодня обсудим второй. Как нам построить поисковую e-com систему.
Покажем, что в слово логистика товара входят сложные задачи не только: перевезти наушники из Китая в Америку, но и настройка поисковой выдачи по запросу.
Быстро соберем поисковой MVP-сервис. Дообучим модель E5 на реальных данных от Amazon. Определим метрики качества и сравним BM25, pretrain E5 и fine-tune E5. Так же взглянем глазами с отладочной информацией и проанализируем изменения поисковых выдач.
И под конец обсудим каких технологий еще не хватает и можно добавить, если возникают соответствующие трудности.

⚠️ Дисклеймер
В первую очередь цель - построить MVP работающее решение, чтобы сразу начать проверять гипотезы. Предположим мы на хакатоне и хотим его выйграть, тогда нам нужно реализовать полноценную ML систему и у нас мало времени!
? Почему важно уметь самому быстро писать MVP решения?
Вы умеете бросать камень в воду так, чтобы пошло много блинчиков? На словах — просто: кинул и готово. А на деле — совсем нет.
Когда трогаешь задачу сам — сталкиваешься с трудностями, проходишь их а, самое главное, лучше всего запоминаешь и учишься!
Так устроена голова: она хуже учится на чужом опыте и гораздо лучше — на своём.
Также, когда сроки сжаты — ты действительно выбираешь самые нужные и рабочие решения в текущих условиях. Это не хаос, а практика. Даже есть подход — фреймворк Cynefin, который говорит: в разных ситуациях решения принимаются по-разному.
Например, тебе могут сказать, что инференс модели обязательно нужно делать на GPU. Но это не всегда так: на CPU можно проще и дешевле вырасти горизонтально.
Или вот ещё пример: не всегда нужен индекс по базе товаров. Иногда прямой проход по всем товарам работает не хуже, а проблем с перестроением и лишним кодом — меньше.
Именно при создании работающего решения — под ограничениями, с фидбеком и ошибками — вырабатывается интуиция, что, где и когда применять.
Та самая инженерная чуйка — она не из книжек. Она из практики.
? На скорю руку получается плохой код. Как этот навык поможет мне в работе?
На самом деле — скорее наоборот. С каждым завершённым проектом, с каждым неверным решением ты отбрасываешь лишнее и оставляешь только нужное. Ты уже знаешь, что тут срезал и писал некачественный код и потом спотыкался и тут теперь просто эффективней написать качественно!
Хочешь не хочешь — среднее качество твоего кода монотонно растёт.
Это касается и самого кода: со временем ты начинаешь писать «рукой набитой» — не гадаешь, норм или не норм, ты просто знаешь.
И архитектуры: появляется интуиция, что заложить сейчас, чтобы потом не соединять всё костылями.
Вдруг вы увидели костыль. Жёсткий, странный, специфичный. Это плохой программист?
Я бы не спешил кидаться помидорами.
Сделать систему, которая охватывает десятки аспектов — сложно, особенно в одиночку.
Если человек в каждом шаге пишет костыли и не понимает, что делает — скорее всего, он просто, скорее всего, не дойдёт до финала.
Он споткнётся на каком-нибудь gpt
-запросе, не сможет разгрести баг — и всё закончится.
Но если продукт работает, решает задачу, и в нём жёсткие, но точечные костыли — скорее всего, программист сделал это осознанно. Чтобы побыстрее собрать MVP.
И, вполне возможно, он уже держит в голове, как это потом исправить.
? Настоящий навык — это находить трейд-офф:
между "красивой архитектурой" и "работающим, понятным, поддерживаемым кодом".
Потому что иначе может получиться вот это: пример, как автор книги по чистому коду написал Android-приложение, которое падает на первом юзкейсе
Приступим к реализации!
Постановка задачи
Нужно собрать MVP-сервис с нашей обученной моделью для поиска товаров по текстовому запросу (E-commerce Product Search), общий вид задачи - IR (Information retrieval ). Цель — не просто совпадение по словам, а понимание e-com интента запроса. Проблема в том, что:
Названия товаров часто не говорят прямо, что за продукт (например, если продается книга - не всегда есть информация в названии об этом)
Формулировки запросов нестабильны и субъективны
Релевантность — не бинарная, есть «почти то» и «вообще мимо»
Одинаковое описание товаров не означает их одинаковую релевантность
Коротко опишем решение: будем дообучать DSSM (в качестве единой одной башни и стартовой точки возьмем pretrain E5) на данных от Amazon. Web сервис с отладочной информацией напишем на Streamlit.
Датасет: Amazon ESCI
Берём датасет Amazon ESCI с HuggingFace. Он содержит пользовательские запросы и связанные с ними товары, размеченные по степени релевантности: exact
, substitute
, irrelevant
.
Почему он хорош:
Данные из реального e-commerce каталога,
Много шума — названия товаров неполные или обобщённые,
Запросы разноплановые: от технических до бытовых,
Релевантность размечена вручную.
Подходит для проверки качества поиска и дообучения моделей.
Пример данных
Каждая строка в датасете — это тройка: запрос пользователя, название товара, и оценка релевантности.
{
"query": "usb c charger for iphone",
"product_title": "USB C Charger, 20W PD Wall Charger Block Compatible with iPhone 13 12 11",
"product_description": "...",
"product_id": "B08K2S1NP5",
"query_id": "A1",
"product_locale": "us",
"esci_label": "exact"
}
Возможные значения esci_label
:
exact
— полностью соответствует запросу,substitute
— близкий аналог, но не точное совпадение,complement
— сопутствующий товар (только в тестовой выборке),irrelevant
— не подходит.
В тренировочной выборке используются только: exact
, substitute
, irrelevant
.
История возникновения датасета
Он был создан для KDD Cup 2022 / Amazon Shopping Queries Challenge, цель — улучшить поиск товаров в реальной e‑commerce среде.
Охватывает запросы на трёх языках — английский, японский и испанский.
Каждый запрос имеет до 40 товаров с ручной разметкой: Exact, Substitute, Complement, Irrelevant.
Всего ~130 000 уникальных запросов и ≈ 2.6 млн помеченных пар (query, product) для полного датасета, в уменьшенной версии — ~48 000 запросов и ~1.1 млн пар. Пример статистики — это отражено в описании репозитория Amazon ESCI на GitHub.
? Релевантность размечена не по наличию ключевых слов, а по фактическому смыслу запроса и товара, что добавляет реалистичный «шум»: описания часто неполные, могут быть маркетинговыми, есть синонимы, обобщения и пр.
Простенький EDA (разведочный анализ)
Качаем датасет
from datasets import load_dataset
# https://huggingface.co/datasets/tasksource/esci
# Загружаем тренировочную часть
dataset = load_dataset("tasksource/esci", split="train")
Смотрим на имеющиеся поля
import pandas as pd
df = dataset.to_pandas()
# Смотрим основные поля
print(df.columns)
# Пример строки
print(df[['query', 'product_title', 'esci_label']].sample(5))
# Распределение по меткам
print("Метки релевантности:")
print(df['esci_label'].value_counts())
# Распределение по языку
print("Локали:")
print(df['product_locale'].value_counts())
Рисуем распределение длинн запросов/названий товаров/описаний
# Распределение длин
df['query_len'] = df['query'].str.split().str.len()
df['title_len'] = df['product_title'].str.split().str.len()
df['desc_len'] = df['product_description'].str.split().str.len()
df[['query_len', 'title_len', 'desc_len']].hist(bins=30, figsize=(12, 4))

Среднее число релевантных товаров на запрос
import matplotlib.pyplot as plt
# Оставим только позитивные примеры
positive_df = df[df['esci_label'].isin(['Exact', 'Substitute'])]
# Считаем количество релевантных товаров на каждый запрос
relevant_counts = positive_df.groupby('query')['product_id'].nunique()
# Статистика
print("? Среднее число релевантных товаров на запрос:", relevant_counts.mean())
print("? Распределение (включая топ-10):")
print(relevant_counts.value_counts().head(10))
# Гистограмма
plt.figure(figsize=(10, 4))
relevant_counts.hist(bins=60)
plt.title("Распределение количества релевантных товаров на запрос")
plt.xlabel("Кол-во релевантных товаров")
plt.ylabel("Частота")
plt.grid(True)
plt.show()

далее подробнее в тетрадке...
Как оцениваем качество?
Датасет состоит из строчек (query, pos, neg)
и в дальнейшим разобьется по батчам. Для каждого позитива берем все остальные pos
и neg
из батча в качестве негативных примеров.
Считаем Recall@K — ищем, попал ли pos
в топ-K предсказания модели (если да метрика равна 1
, иначе 0
) и усредняем по батчу.
BM25 как бейзлайн
BM25 — развитие идеи TF-IDF: учитывает, насколько слово важно в документе, но дополнительно нормализует по длине и логарифмически сглаживает веса.
Используем rank_bm25
— индексируем product_title
, ищем по query
, ранжируем по скору.
Считаем бейзлайн метрику (пока по всем товарам, потом будем по батчам)
Токенизируем данные
import nltk
from nltk.tokenize import word_tokenize
import os
NLTK_DATA_PATH = os.path.expanduser('~/nltk_data')
os.makedirs(NLTK_DATA_PATH, exist_ok=True)
nltk.download('punkt_tab', download_dir=NLTK_DATA_PATH)
nltk.data.path.append(NLTK_DATA_PATH)
# Простой тест — если работает, всё ок
print(word_tokenize("This is a test."))
# ['This', 'is', 'a', 'test', '.']
Считаем метрику для BM25
import pandas as pd
import numpy as np
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi
from tqdm import tqdm
import time
# ——— 0. Удобная токенизация текста ———
def tokenize(text: str) -> list[str]:
return word_tokenize(text.lower())
# ——— 1. Подготовка корпуса товаров ———
print("? Подготовка коллекции товаров...")
start = time.perf_counter()
products = df['product_title'].unique().tolist()
tokenized_products = [tokenize(p) for p in tqdm(products, desc="? Токенизация товаров")]
bm25 = BM25Okapi(tokenized_products)
print(f"✅ Индексация BM25 завершена за {time.perf_counter() - start:.2f} сек")
# ——— 2. Универсальный BM25 поисковик ———
class BM25Search:
def __init__(self, bm25_index, docs):
self.bm25 = bm25_index
self.docs = docs
def top_k(self, query: str, k: int = 10) -> list[str]:
tokens = tokenize(query)
scores = self.bm25.get_scores(tokens)
top_idx = np.argsort(scores)[-k:][::-1]
return [self.docs[i] for i in top_idx]
# ——— 3. Универсальная метрика Recall@k ———
def evaluate_recall_at_k(queries, true_items, searcher, k=10):
hits = 0
for q, true_item in tqdm(zip(queries, true_items), total=len(queries), desc="? Инференс + метрика"):
predicted = searcher.top_k(q, k)
if true_item in predicted:
hits += 1
return hits / len(queries)
# ——— 4. Пример запуска ———
print("? Формируем валидационный сет...")
eval_df = df[df['esci_label'].isin(['Exact', 'Substitute'])].sample(200, random_state=42)
queries = eval_df['query'].tolist()
true_products = eval_df['product_title'].tolist()
print("? Запуск поиска и оценка Recall@10...")
start = time.perf_counter()
bm25_searcher = BM25Search(bm25, products)
recall = evaluate_recall_at_k(queries, true_products, bm25_searcher, k=10)
print(f"✅ Метрика рассчитана за {time.perf_counter() - start:.2f} сек")
print(f"\n? Recall@10 (BM25 baseline): {recall:.4f}")
# ? Подготовка коллекции товаров...
# ? Токенизация товаров: 100%|██████████| 1423918/1423918 [01:50<00:00, 12909.02it/s]
# ✅ Индексация BM25 завершена за 120.43 сек
# ? Формируем валидационный сет...
# ? Запуск поиска и оценка Recall@10...
# ? Инференс + метрика: 100%|██████████| 200/200 [03:37<00:00, 1.09s/it]
# ✅ Метрика рассчитана за 217.02 сек
# ? Recall@10 (BM25 baseline): 0.1750
Алгоритм не учитывает смысл и тем более не смотрим на популярность от пользователей. Он лишь ищет пересечения части запроса и товара и хитро считает скор. На как бейзлайн очень легко заводится и показывает что-то адекватное!
Какую модель взять? Какой pretrain? E5 !
Используем E5 (intfloat/multilingual-e5-small) — bi-encoder на базе BERT с сиамской архитектурой: один энкодер обрабатывает запрос и документ отдельно, с префиксами query:
и passage:
.
Модель обучена на 1,2 млрд пар "запрос–документ" на 100+ языках. В обучении использовались Common Crawl, Wikipedia, MS MARCO, BEIR и другие датасеты. E5 Technical Report (2024, Microsoft)
Работает из коробки, эффективно дообучается и даёт стабильные результаты для поиска.
Сравнение с альтернативами (SBERT, T5, GTR, Contriever)
? SBERT (2019, UKPLab)
Siamese-сеть на основе BERT. Обучалась на задачах парафразов и NLI: модель учится распознавать, насколько два предложения схожи по смыслу. Проста в использовании, но требует выбора подходящей предобученной версии под конкретную задачу.
? sbert.net | Paper
? GTR (2021, Google)
Dual-encoder на архитектуре T5. Обучен на query–document парах, показывает отличные результаты на BEIR-бенчмарках. Тяжёлый по ресурсам и в основном англоязычный, но может превосходить другие модели по качеству.
? Paper
? Что такое T5? (Первая версия статьи 2019)
T5 (Text-To-Text Transfer Transformer) — seq2seq модель от Google, которая формулирует все NLP-задачи как «текст → текст».
Например:
"translate English to German: house" → "Haus"
В GTR используется только энкодер от T5 — он кодирует query и документ по отдельности в эмбеддинги, как в bi-encoder архитектуре.
? T5 paper
? Contriever (2021, Meta AI)
Dense retriever на BERT, обучен без разметки (self-supervised) — просто на соседних предложениях из Википедии. Работает быстро, не требует аннотированных данных, но хуже справляется с пониманием запроса. Подходит для англоязычных сценариев.
? GitHub | Paper
Инферим pretrain E5
Для подсчета скоров релевантности под запрос, нужен в оффлайне переобойти все продукты.
Делаем без векторного индекса. Будем под запрос считать скор со всей базой.
Считаем вектора(и статистику) для каждого товара из базы
Качаем модель и сохраняем в свою обертку.
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
# https://huggingface.co/intfloat/multilingual-e5-small
# --- Класс для инференса батчей ---
class E5InferenceModel:
def __init__(self, model_name='intfloat/e5-small', device=None):
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
def encode_batch(self, input_ids, attention_mask):
with torch.no_grad():
outputs = self.model(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device))
return outputs.last_hidden_state[:, 0].cpu().numpy()
model_name = 'intfloat/e5-small'
device = 'cuda:6' if torch.cuda.is_available() else 'cpu'
inference_model = E5InferenceModel(model_name=model_name, device=device)
# Сохраняем модель
# Путь для сохранения
SAVE_DIR = "saved_e5_model"
# Сохраняем модель и токенизатор
inference_model.model.save_pretrained(SAVE_DIR)
inference_model.tokenizer.save_pretrained(SAVE_DIR)
print(f"✅ Модель и токенизатор сохранены в: {SAVE_DIR}")
Загружаем модель в нашей обертке
from transformers import AutoModel, AutoTokenizer
import torch
# Путь к сохранённой модели
MODEL_DIR = "saved_e5_model"
# Загрузка токенизатора и модели
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModel.from_pretrained(MODEL_DIR)
model.eval() # Обязательно для инференса
device = torch.device('cuda:6' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
Написано грязно, могли бы и класс переиспользовать с инференсом батча, но как есть.
def encode_texts(texts, prefix="query", batch_size=32):
"""
Кодирует список текстов в эмбеддинги (используется [CLS] токен).
prefix: "query" или "passage" (для правильного шаблона).
"""
embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
inputs = tokenizer(
[f"{prefix}: {text}" for text in batch],
padding=True,
truncation=True,
return_tensors='pt'
).to(device)
with torch.no_grad():
output = model(**inputs)
cls_emb = output.last_hidden_state[:, 0] # [CLS] токен
embeddings.append(cls_emb.cpu())
return torch.cat(embeddings, dim=0).numpy()
# Пример запроса
queries = ["wireless bluetooth headphones", "usb-c charging cable"]
query_embs = encode_texts(queries, prefix="query")
print("✅ Эмбеддинги запросов:", query_embs.shape)
# ✅ Эмбеддинги запросов: (2, 384)
Сохраняем ембеды товаров
# Загрузка модели
print("? Загружаем сохранённую модель...")
model = AutoModel.from_pretrained(MODEL_PATH).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Кодирование и сохранение
product_embs = encode_texts(product_titles, batch_size=128)
np.save(EMBEDS_PATH, product_embs)
print("✅ Эмбеддинги и названия товаров сохранены.")
Так же статистику встречаемости в датасете посчитаем для каждого товара
import pandas as pd
import os
# добавляем статистику показов для каждого товара
# ——— Загружаем датасет, если еще не загружен ———
if not os.path.exists(PRODUCTS_PATH) or not os.path.exists("product_stats.csv"):
print("? Загружаем датасет с Hugging Face...")
dataset = load_dataset("tasksource/esci", split="train")
df = pd.DataFrame([x for x in tqdm(dataset, desc="? Преобразуем в DataFrame")])
# Уникальные товары
product_titles = df['product_title'].dropna().unique().tolist()
pd.Series(product_titles).to_csv(PRODUCTS_PATH, index=False)
# ——— Считаем количество показов каждого товара ———
print("? Считаем статистику по товарам...")
stats = df['product_title'].value_counts().reset_index()
stats.columns = ['product_title', 'views']
stats.to_csv("product_stats.csv", index=False)
print(f"✅ Сохранили статистику для {len(stats)} товаров")
Пишем поисковой web-сервис
Набрасываем простенький web на Streamlit. Указываем модель, которую векторизует запрос и продукт. Указываем запрос, получаем выдачи с скорами, общей статистико скоров и кол-во показов товара пользователям.

Код web-сервиса
import streamlit as st
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from time import perf_counter
import matplotlib.pyplot as plt
import seaborn as sns
# Это очень плохо так писать, не бейти
class E5Model(torch.nn.Module):
def __init__(self, model_name='intfloat/e5-small'):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def encode(self, input_ids, attention_mask):
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
return outputs.last_hidden_state[:, 0]
def forward(self, anchor_ids, anchor_mask, pos_ids, pos_mask, neg_ids, neg_mask):
anchor_emb = self.encode(anchor_ids, anchor_mask)
pos_emb = self.encode(pos_ids, pos_mask)
neg_emb = self.encode(neg_ids, neg_mask)
return anchor_emb, pos_emb, neg_emb
# --- Константы ---
PRETRAIN_MODEL_PATH = "saved_e5_model"
# FT_MODEL_PATH = "checkpoints/e5_train_20250703_170142.pt" # lr=2e-5
FT_MODEL_PATH = "checkpoints/e5_train_20250703_173139.pt" # lr=1e-4 + 3 epochs
# FT_MODEL_PATH = "checkpoints/e5_train_20250703_175235.pt" # lr=1e-4 + 10 epochs
PRODUCTS_PATH = "product_titles.csv"
# EMBEDS_PATH = "small_product_embeddings.npy"
EMBEDS_PATH = "product_embeddings.npy"
STATS_PATH = "product_stats.csv" # CSV с колонками ['product_title', 'views']
# --- Выбор модели ---
st.sidebar.markdown("? **Выбор модели**")
model_choice = st.sidebar.selectbox(
"Модель для поиска:",
options=["Дообученная (saved_e5_model)", "Предобученная (intfloat/e5-small)"]
)
# --- Загрузка модели и данных (кешируется) ---
@st.cache_resource
def load_model_and_data(model_choice):
if model_choice == "Дообученная (saved_e5_model)":
model = E5Model(model_name="intfloat/e5-small") # init из того же архетипа
model.load_state_dict(torch.load(FT_MODEL_PATH))
tokenizer = model.tokenizer
model = model.encoder.eval().cuda() # достаём encoder
else:
tokenizer = AutoTokenizer.from_pretrained(PRETRAIN_MODEL_PATH)
model = AutoModel.from_pretrained(PRETRAIN_MODEL_PATH).eval().cuda()
product_titles = pd.read_csv(PRODUCTS_PATH).squeeze().tolist()
product_embs = np.load(EMBEDS_PATH)
# Загрузка статистики
product_stats_df = pd.read_csv(STATS_PATH)
product_stats = dict(zip(product_stats_df['product_title'], product_stats_df['views']))
return tokenizer, model, product_titles, product_embs, product_stats
tokenizer, model, product_titles, product_embs, product_stats = load_model_and_data(model_choice=model_choice)
total_products = len(product_titles)
# --- Функция кодирования запроса ---
def encode_query(query):
inputs = tokenizer([f"query: {query}"], return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = model(
input_ids=inputs["input_ids"].cuda(),
attention_mask=inputs["attention_mask"].cuda()
)
return outputs.last_hidden_state[:, 0].cpu().numpy()[0]
# --- Интерфейс ---
st.title("? Поиск по товарам (E5)")
st.markdown(f"?️ **Всего товаров в базе:** `{total_products}`")
query = st.text_input("Введите запрос:", value="wireless headphones")
top_k = st.slider("Сколько результатов показать:", 1, 30, 10)
if query:
with st.status("? Обработка запроса... Пожалуйста, подождите.", expanded=True) as status:
# --- Этап 1: Кодирование запроса ---
t1 = perf_counter()
query_emb = encode_query(query)
query_time = perf_counter() - t1
query_emb /= np.linalg.norm(query_emb)
prod_embs_norm = product_embs / np.linalg.norm(product_embs, axis=1, keepdims=True)
# --- Этап 2: Поиск по базе ---
t2 = perf_counter()
scores = np.dot(prod_embs_norm, query_emb)
search_time = perf_counter() - t2
# --- Этап 3: Top-K ранжирование ---
top_idx = np.argsort(scores)[-top_k:][::-1]
scores_top = scores[top_idx]
# --- Интерфейс: две колонки ---
col1, col2 = st.columns([2, 1])
# --- Левая колонка: выдача товаров ---
with col1:
st.subheader("? Результаты:")
for i in top_idx:
title = product_titles[i]
score = scores[i]
views = product_stats.get(title, 0)
st.markdown(
f"""
<div style='
background: #1c1c1c;
padding: 10px 15px;
margin: 8px 0;
border-left: 4px solid #52c41a;
border-radius: 6px;
'>
<div style='font-size: 16px; font-weight: bold; color: #ffffff;'>{title}</div>
<div style='margin-top: 4px; color: #cccccc; font-size: 14px;'>
? <span style='color:#52c41a;'>score: {score:.3f}</span>
?️ <span style='color:#f4d35e;'>{views:,} показов</span>
</div>
</div>
""",
unsafe_allow_html=True
)
st.markdown("---")
st.markdown(f"⏱️ Время кодирования запроса: `{query_time:.4f} сек`")
st.markdown(f"? Время поиска по базе: `{search_time:.4f} сек`")
# --- Правая колонка: анализ ---
with col2:
st.markdown("### ? Анализ")
fig, ax = plt.subplots(figsize=(5, 3))
sns.histplot(scores, bins=50, kde=True, ax=ax, color='skyblue', label='Все товары')
ax.axvline(scores_top.min(), color='green', linestyle='--', label='Min Top-K')
ax.axvline(scores_top.max(), color='orange', linestyle='--', label='Max Top-K')
ax.set_xlabel("Score")
ax.set_title("Распределение")
ax.legend()
st.pyplot(fig)
st.markdown("### ? Статистика")
stats_dict = {
"Min": float(np.min(scores)),
"Max": float(np.max(scores)),
"Mean": float(np.mean(scores)),
"Median": float(np.median(scores)),
"Std": float(np.std(scores)),
"5%": float(np.percentile(scores, 5)),
"25%": float(np.percentile(scores, 25)),
"75%": float(np.percentile(scores, 75)),
"95%": float(np.percentile(scores, 95)),
}
stats_df = pd.DataFrame(stats_dict, index=["Score"]).T
st.dataframe(stats_df.style.format("{:.4f}"))
status.update(label="✅ Готово!", state="complete", expanded=True)
# Для запуска:
# streamlit run app.py
Приступаем к дообучению на пользовательской активности!
Учим как DSSM только одна модель векторизует две башни с префиксом query:
и passage:
на TripletMarginLoss. В качестве модели и претрайна берем E5.
Учу на одной A100
Экспериментирую с
lr
. Дляlr=0.2 - 0.3
модель разносит и уже бесповоротно портиться. Дляlr=0.4 - 0.5
модель успешно учится (и учиться дальше, не выходит на плато!)




Создаем Dataset
Варим датасет с триплетами (query, pos, neg)
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import pytorch_lightning as pl
from torch.nn import TripletMarginLoss
from torch.nn.functional import normalize
from sklearn.model_selection import train_test_split
import numpy as np
import faiss
from tqdm import tqdm
from datetime import datetime
import os
import random
# ——— 1. Токенизация и подготовка данных ———
# def dummy_tokenize(text: str):
# return text.lower()
class TripletDataset(Dataset):
def __init__(self, df, tokenizer, max_length=64, num_negatives=10):
self.samples = []
self.tokenizer = tokenizer
self.max_length = max_length
self.num_negatives = num_negatives
self.all_products = df['product_title'].unique()
self._build_triplets(df)
def _build_triplets(self, df):
n = len(df)
for i in range(self.num_negatives):
negs = random.choices(self.all_products, k=n)
for idx, row in enumerate(df.itertuples(index=False)):
query = row.query
pos = row.product_title
neg = negs[idx]
self.samples.append((query, pos, neg))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
q, pos, neg = self.samples[idx]
# Токенизация сразу для модели
anchor_enc = self.tokenizer(
f"query: {q}", padding='max_length', truncation=True,
max_length=self.max_length, return_tensors='pt'
)
pos_enc = self.tokenizer(
f"passage: {pos}", padding='max_length', truncation=True,
max_length=self.max_length, return_tensors='pt'
)
neg_enc = self.tokenizer(
f"passage: {neg}", padding='max_length', truncation=True,
max_length=self.max_length, return_tensors='pt'
)
# Возвращаем только input_ids и attention_mask, как нужно для forward
return (
(anchor_enc['input_ids'].squeeze(0), anchor_enc['attention_mask'].squeeze(0)),
(pos_enc['input_ids'].squeeze(0), pos_enc['attention_mask'].squeeze(0)),
(neg_enc['input_ids'].squeeze(0), neg_enc['attention_mask'].squeeze(0)),
f"query: {q}", # для дебага
f"passage: {pos}",
f"passage: {neg}"
)
Варим Dataloader и сплитим данные
from datasets import load_dataset
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import pytorch_lightning as pl
from datetime import datetime
# Предполагается, что TripletDataset и E5Model уже определены ранее
# ——— 1. Подготовка данных ———
def prepare_data(model_name='intfloat/e5-small', batch_size=16, sample_rate=1.0):
print("? Загружаем датасет tasksource/esci...")
dataset = load_dataset("tasksource/esci", split="train")
dataset_len = len(dataset)
# семплим чтобы быстрее отдебажить!
dataset_current_len = int(dataset_len * sample_rate)
dataset = dataset.shuffle(seed=42).select(range(dataset_current_len))
print("? Конвертируем в pandas DataFrame...")
df = pd.DataFrame([x for x in tqdm(dataset, desc="→ Преобразование строк")])
print("? Фильтруем классы: Exact / Substitute / Irrelevant...")
df = df[df['esci_label'].isin(['Exact', 'Substitute', 'Irrelevant'])]
print("? Удаляем запросы с < 2 примерами...")
query_counts = df['query'].value_counts()
df = df[df['query'].isin(query_counts[query_counts >= 2].index)]
print("✂️ Разбиваем на train/val...")
train_df, val_df = train_test_split(
df, test_size=0.1, random_state=42
# , stratify=df['query']
)
print(f"✅ Train size: {len(train_df)} / Val size: {len(val_df)}")
print(f"? Загружаем токенизатор: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("? Создаём TripletDataset'ы...")
train_dataset = TripletDataset(train_df, tokenizer)
val_dataset = TripletDataset(val_df, tokenizer)
print(f"? Train triplets: {len(train_dataset)} / Val triplets: {len(val_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=True)
print(f"Train batches: {len(train_loader)} / Val batches: {len(val_loader)} | {batch_size=}")
return train_loader, val_loader, df
train_loader, val_loader, df = prepare_data(sample_rate=0.01, batch_size=batch_size)
# ? Загружаем датасет tasksource/esci...
# ? Конвертируем в pandas DataFrame...
# → Преобразование строк: 100%|██████████| 20278/20278 [00:03<00:00, 6354.40it/s]
# ? Фильтруем классы: Exact / Substitute / Irrelevant...
# ? Удаляем запросы с < 2 примерами...
# ✂️ Разбиваем на train/val...
# ✅ Train size: 3657 / Val size: 407
# ? Загружаем токенизатор: intfloat/e5-small
# ? Создаём TripletDataset'ы...
# ? Train triplets: 36570 / Val triplets: 4070
# Train batches: 285 / Val batches: 31 | batch_size=128
Учим и валидируем модель
Определяем обертку для обучения модели
class E5Model(torch.nn.Module):
def __init__(self, model_name='intfloat/e5-small'):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def encode(self, input_ids, attention_mask):
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
return outputs.last_hidden_state[:, 0]
def forward(self, anchor_ids, anchor_mask, pos_ids, pos_mask, neg_ids, neg_mask):
anchor_emb = self.encode(anchor_ids, anchor_mask)
pos_emb = self.encode(pos_ids, pos_mask)
neg_emb = self.encode(neg_ids, neg_mask)
return anchor_emb, pos_emb, neg_emb
Определяем eval в время train
def eval_model(model, val_loader, device='cuda'):
model.eval()
model.to(device)
all_recalls = {k: [] for k in [1, 5, 10, 30]}
with torch.no_grad():
# for batch in tqdm(val_loader, desc="? Eval"):
for batch in val_loader:
(a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask), _, _, _ = batch
# Переносим на нужное устройство
a_ids, a_mask = a_ids.to(device), a_mask.to(device)
p_ids, p_mask = p_ids.to(device), p_mask.to(device)
n_ids, n_mask = n_ids.to(device), n_mask.to(device)
# Получаем эмбеддинги
anchor_embs = model.encode(a_ids, a_mask).cpu().numpy()
pos_embs = model.encode(p_ids, p_mask).cpu().numpy()
neg_embs = model.encode(n_ids, n_mask).cpu().numpy()
# Собираем "пул" продуктов (positive + negative)
product_embs = np.concatenate([pos_embs, neg_embs], axis=0)
recalls = RetrievalMetrics.recall_at_k_batch(anchor_embs, product_embs, k_list=k_list)
for k in recalls:
all_recalls[k].append(recalls[k])
# Усреднение и печать
all_means_recalls = {}
for k in all_recalls:
all_means_recalls[k] = np.mean(all_recalls[k])
return all_means_recalls
Определяем обучение + логирования в tensorboard + сохранение модели
import os
CHECKPOINTS_DIR = "checkpoints"
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
def train_model(model, train_loader, val_loader, num_epochs=3, lr=2e-5, device='cuda', every_n_step_do_val=50):
time_suffix = str(datetime.now().strftime('%Y%m%d_%H%M%S'))
run_name = f"e5_train_{time_suffix}" # and model_name!
writer = SummaryWriter(log_dir=f"runs/{run_name}")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
loss_fn = TripletMarginLoss(margin=0.2)
global_step = 1
for epoch in range(num_epochs):
model.train()
for batch_id, batch in tqdm(enumerate(train_loader), desc=f"?️ Epoch {epoch + 1}/{num_epochs}", total=len(train_loader)):
writer.add_scalar("train/epoch_marker", epoch, global_step)
(a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask), _, _, _ = batch
a_ids, a_mask = a_ids.to(device), a_mask.to(device)
p_ids, p_mask = p_ids.to(device), p_mask.to(device)
n_ids, n_mask = n_ids.to(device), n_mask.to(device)
anchor, pos, neg = model(a_ids, a_mask, p_ids, p_mask, n_ids, n_mask)
loss = loss_fn(anchor, pos, neg)
optimizer.zero_grad()
loss.backward()
optimizer.step()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)
writer.add_scalar("train/loss", loss.item(), global_step)
writer.add_scalar("GradNorm/train", grad_norm, global_step)
writer.add_scalar("LR/train", optimizer.param_groups[0]['lr'], global_step)
train_recalls = eval_model(model, [batch], device=device)
for k, val in train_recalls.items():
writer.add_scalar(f"train/recall@{k}", val, global_step)
if batch_id % every_n_step_do_val == 0:
recalls = eval_model(model, val_loader, device=device)
for k, val in recalls.items():
writer.add_scalar(f"val/recall@{k}", val, global_step)
global_step += 1
print("✅ Обучение завершено. Сохраняем модель...")
save_path = os.path.join(CHECKPOINTS_DIR, f"{run_name}.pt")
torch.save(model.state_dict(), save_path)
print(f"? Модель сохранена в: {save_path}")
Запускаем само обучение
# train_loader, val_loader, _ = prepare_data(batch_size=16, sample_rate=0.05)
model = E5Model(model_name='intfloat/e5-small')
train_model(model, train_loader, val_loader, num_epochs=3, device='cuda:6', lr=1e-4, every_n_step_do_val=10)
# ?️ Epoch 1/10: 100%|██████████| 285/285 [04:22<00:00, 1.09it/s]
# ?️ Epoch 2/10: 100%|██████████| 285/285 [04:22<00:00, 1.08it/s]
# ...
# ?️ Epoch 9/10: 100%|██████████| 285/285 [04:24<00:00, 1.08it/s]
# ?️ Epoch 10/10: 100%|██████████| 285/285 [04:23<00:00, 1.08it/s]
# ✅ Обучение завершено. Сохраняем модель...
# ? Модель сохранена в: checkpoints/e5_train_20250703_175235.pt
Сравниваем метрики BM25 & pretrain E5 & fine-tune E5
Видим, что дообучение успешно схватывает пользовательскую активность!
Если обучать дольше, то метрики не уходят на плато! Они еще лучше начинают учитывать online сигнал. Тут уже нужно заключать трейдофф между кликами и семантической релевантностью.
Model |
Recall@1 |
Recall@5 |
Recall@10 |
Recall@30 |
---|---|---|---|---|
BM25 |
0.5731 |
0.6822 |
0.7145 |
0.7631 |
E5 (pretrain) |
0.2686 |
0.5129 |
0.5862 |
0.6971 |
E5 (fine-tuned) |
0.6413 |
0.8750 |
0.9270 |
0.9660 |
Считаем метрики на val Для BM25
Функция для подсчета метрики
# --- Класс для метрик внутри батча ---
class RetrievalMetrics:
@staticmethod
def recall_at_k_batch(anchor_embs, product_embs, k_list=[5, 10, 30]):
recalls = {k: 0 for k in k_list}
n = len(anchor_embs)
for i, a_emb in enumerate(anchor_embs):
scores = np.dot(product_embs, a_emb)
top_indices = np.argsort(scores)[-max(k_list):][::-1]
for k in k_list:
# Positive всегда на позиции i (по построению TripletDataset)
if i in top_indices[:k]:
recalls[k] += 1
for k in k_list:
recalls[k] /= n
return recalls
BM25 на батчах
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
def tokenize(text):
return word_tokenize(text.lower())
all_products = df['product_title'].dropna().unique().tolist()
all_products = [f"passage: {p}" for p in all_products]
tokenized_products_all = [tokenize(p) for p in all_products]
bm25_full = BM25Okapi(tokenized_products_all)
######
from collections import defaultdict
bm25_recalls = defaultdict(list)
k_list = [1, 5, 10, 30]
product_idx_map = {title: i for i, title in enumerate(all_products)}
for batch in tqdm(val_loader, desc="? BM25 на батчах"):
_, _, _, queries, pos_titles, neg_titles = batch
# --- 1. Подготовим список из продуктов текущего батча ---
batch_products = pos_titles + neg_titles
# --- 2. Индексы этих товаров в all_products (для score фильтрации) ---
batch_indices = [product_idx_map[p] for p in batch_products if p in product_idx_map]
for q, true_title in zip(queries, pos_titles):
q_tokens = tokenize(q)
scores_all = bm25_full.get_scores(q_tokens)
# --- 3. Оставим только скоры товаров из текущего батча ---
scores_batch = [(i, scores_all[i]) for i in batch_indices]
top_indices = sorted(scores_batch, key=lambda x: x[1], reverse=True)
top_titles = [all_products[i] for i, _ in top_indices]
for k in k_list:
bm25_recalls[k].append(int(true_title in top_titles[:k]))
# --- Усреднение ---
for k in k_list:
print(f"Recall@{k} (BM25): {np.mean(bm25_recalls[k]):.4f}")
# Recall@1 (BM25): 0.5731
# Recall@5 (BM25): 0.6822
# Recall@10 (BM25): 0.7145
# Recall@30 (BM25): 0.7631
Для pretrain E5
pretrain E5
# --- Пример использования с val_loader ---
model_name = 'intfloat/e5-small'
device = 'cuda:6' if torch.cuda.is_available() else 'cpu'
inference_model = E5InferenceModel(model_name=model_name, device=device)
all_recalls = {k: [] for k in [1, 5, 10, 30]}
k_list = [1, 5, 10, 30]
for batch in tqdm(val_loader, desc="? Pretrain E5 на батчах"):
(anchor_ids, anchor_mask), (pos_ids, pos_mask), (neg_ids, neg_mask), q, pos, neg = batch
# Собираем все продукты батча (positive + negative)
batch_product_ids = torch.cat([pos_ids, neg_ids], dim=0)
batch_product_mask = torch.cat([pos_mask, neg_mask], dim=0)
# Эмбеддинги
anchor_embs = inference_model.encode_batch(anchor_ids, anchor_mask)
product_embs = inference_model.encode_batch(batch_product_ids, batch_product_mask)
# Метрики
recalls = RetrievalMetrics.recall_at_k_batch(anchor_embs, product_embs, k_list=k_list)
for k in k_list:
all_recalls[k].append(recalls[k])
# Усреднение по всем батчам
for k in k_list:
mean_recall = np.mean(all_recalls[k])
print(f"Recall@{k}: {mean_recall:.4f}")
# Recall@1: 0.2686
# Recall@5: 0.5129
# Recall@10: 0.5862
# Recall@30: 0.6971
Для fine-tune E5
Я поленился и возьму метрику с графиков (мерятся на том же val).
# Recall@1: 0.6413
# Recall@5: 0.875
# Recall@10: 0.927
# Recall@30: 0.966
Разница поисковых выдач pretrain E5 и fine-tune E5
Видим, что при дообучении в выдаче начинают доминировать товары с бОльшим числом показов/кликов(=более одобряемые пользователями).
Скор начинает закладываться, какой-то эфемерный смысл популярности у пользователей
начинает быть мене уверенный (отдаляется от максимума равного
)
pretrain

fine-tune

✅ Итоги
Собрали минимальное, но полноценное решение:
поиск товаров по смыслу, с дообученной E5, визуализацией выдачи и реальными метриками. Всё это — в виде простого MVP-сервиса.
Решение универсально: с минимальными изменениями оно подойдёт для задач поиска по резюме, статьям, тикетам, продуктам и даже пользовательским вопросам.
Точки роста (их много ?)
Текущий стек:
HuggingFace + PyTorch + TensorBoard + Streamlit
Что можно улучшить:
Структура и кодовая база
Избавляться от ноутбуков в проде: выносить логику в модули (
train.py
,inference.py
,app/
)Универсальные сигнатуры функций и моделей (единый интерфейс)
Поддержка нескольких моделей без переписывания кода
Чёткая структура проекта, разделение обучения и сервиса
Обучение и логирование
Перейти на PyTorch Lightning для компактности и читаемости пайплайна
Добавить ClearML или W&B вместо TensorBoard для отслеживания экспериментов
Вынести параметры обучения в конфиги (
yaml
,json
,hydra
)
Оптимизация инференса
ONNX / ONNX Runtime для ускорения и портируемости модели
Квантование модели (int8/float16) для меньшего размера и CPU-инференса
Поддержка батчевой обработки и асинхронного инференса
Поиск и масштабирование
Добавить векторный индекс (например, Faiss) вместо полного перебора
CPU vs GPU: протестировать, где проще и дешевле — больше CPU с шардингом или один GPU
Работа с датасетом как с итератором, без загрузки в RAM
Весь код: данные, EDA, обучение и MVP-сервис лежит в репозитории
Так же может быть полезно пробежаться по статье: RecSys + DSSM + FPSLoss is all you need
Что дальше?
В следующих статьях покажем, как решать подобные задачи ещё компактнее и технологичнее — вплоть до полного отказа от кастомного кода с помощью NoCode-инструментов.
А пока — присоединяйтесь к нашему Telegram-сообществу @datafeeling и вдохновляйтесь современными подходами решения реальных задач.