Коротко про рекомендательные системы

Глобально существует два подхода в создании рекомендательных систем. Контентно-ориентированная и коллаборативная фильтрация. Основополагающее предположение подхода коллаборативной фильтрации заключается в том, что если А и В покупают аналогичные продукты, А, скорее всего, купит продукт, который купил В, чем продукт, который купил случайный человек. В отличие от контентно-ориентированного подхода, здесь нет признаков, соответствующих пользователям или предметам. Рекомендательная система базируется на матрице взаимодействий пользователей. Контентно-ориентированная система базируется на знаниях о предметах. Например если пользователь смотрит шелковые футболки возможно ему будет интересно посмотреть на другие шелковые футболки.

В этой статье я хочу рассказать о подходе который основан на поиске схожих изображений. Зачем подготавливать дополнительнительные данные если почти все основные характеристики некоторых товаров, например одежда, можно отобразить на изображении.

исключение softmax
исключение softmax

Суть подхода заключается в извлчении признаков из изображений товаров. С помощью сверточной сети, в своем примере я использовал Resnet50, так как вектор признаков resnet имеет относительно небольшую размерность. Извлечь вектор признаков с помощью обученой сети очень просто. Нужно просто исключить softmax классификатор именно он определяет к какому классу относится изображение и мы получим на выходе вектор признаков. Далее необходимо сравнивать векторы и искать похожие. Чем более схожи изображения тем меньше евклидово расстояние между векторами.

Код и датасет

Датасет можно скачать отсюда: ссылка на датасет.

Инициализации обученой restnet50 из библиотеки pytorch и извлечении признаков из датасета:

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
import torch
import glob
import pickle
from tqdm import tqdm
from PIL import Image

def pil_loader(path):
    # Некоторые изображения из датасета представленны не в RGB формате, необходимо их конверитровать в RGB
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# Инициализация модели обученой на датасете imagenet
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
preprocess = weights.transforms()

use_precomputed_embeddings = True
emb_filename = 'fashion_images_embs.pickle'
if use_precomputed_embeddings: 
    with open(emb_filename, 'rb') as fIn:
        img_names, img_emb_tensors = pickle.load(fIn)  
    print("Images:", len(img_names))
else:
    img_names  = list(glob.glob('images/*.jpg'))
    img_emb = []
    # извлечение признаков из изображений в датасете. У меня на CPU заняло около часа
    for image in tqdm(img_names):
        img_emb.append(
            model(preprocess(pil_loader(image)).unsqueeze(0)).squeeze(0).detach().numpy()
        )
    img_emb_tensors = torch.tensor(img_emb)
    
    with open(emb_filename, 'wb') as handle:
        pickle.dump([img_names, img_emb_tensors], handle, protocol=pickle.HIGHEST_PROTOCOL)

Функция которая создает поисковый индекс с помощью faiss и уменьшает размерность векторов признаков:

# Для сравнения векторов используется faiss
import faiss                   
from sklearn.decomposition import PCA

def build_compressed_index(n_features):
    pca = PCA(n_components=n_features)
    pca.fit(img_emb_tensors)
    compressed_features = pca.transform(img_emb_tensors)
    dataset = np.float32(compressed_features)
    d = dataset.shape[1]
    nb = dataset.shape[0]
    xb = dataset

    index_compressed = faiss.IndexFlatL2(d)
    index_compressed.add(xb)
    return [pca, index_compressed]

Хэлперы для отображения результатов:

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def main_image(img_path, desc):
    plt.imshow(mpimg.imread(img_path))
    plt.xlabel(img_path.split('.')[0] + '_Original Image',fontsize=12)
    plt.title(desc,fontsize=20)
    plt.show()

def similar_images(indices, suptitle):
    plt.figure(figsize=(15,10), facecolor='white')
    plotnumber = 1    
    for index in indices[0:4]:
        if plotnumber<=len(indices) :
            ax = plt.subplot(2,2,plotnumber)
            plt.imshow(mpimg.imread(img_names[index]))
            plt.xlabel(img_names[index],fontsize=12)
            plotnumber+=1
    plt.suptitle(suptitle,fontsize=15)
    plt.tight_layout()

Сама функция поиска. Принимает на вход количество признаков, что бы можно было поэксперементировать с достаточным количеством признаков:

import numpy as np
# поиск, можно искать по индексу из предварительно извлеченных изображений или передать новое изображение
def search(query, factors):
    if(type(query) == str):
        img_path = query
    else:
        img_path = img_names[query]
    one_img_emb = torch.tensor(model(preprocess(read_image(img_path)).unsqueeze(0)).squeeze(0).detach().numpy())
    main_image(img_path, 'Query')
    compressor, index_compressed = build_compressed_index(factors)
    D, I = index_compressed.search(np.float32(compressor.transform([one_img_emb.detach().numpy()])),5)
    similar_images(I[0][1:], "faiss compressed " + str(factors))

Виновник торжества. Вызов поиска:

search(100,300)
search("t-shirt.jpg", 500)

Выводы

В итоге за пару часов можно собрать довольно качественную рекомендательную систему основаную на схожести изображений, чего достаточно для некоторых случаев. Изображения не требуют предварительной подготовки, разметки и какой то метаинформации что значительно упрощает процесс.

Для повышения качества рекомендаций можно дообучить некторые слои сети на используемом датасете.

Комментарии (0)