Привет, Хабр!

Сегодня у нас на повестке дня rust-bert — одна из самых мощных библиотек для обработки естественного языка в экосистеме Rust. Если вы уже знакомы с Hugging Face и их библиотекой Transformers на Python, то rust-bert для вас. Эта библиотека переносит state-of-the-art модели прямо в проект на Rust.

Главная фича rust-bert в том, что она идеально вписывается в Rust.

Настройка окружения

Чтобы использовать rust-bert потребуется установить Libtorch — это C++ API от PyTorch, который библиотека использует для глубокого обучения. Хотя можно выбрать автоматическую установку, ручная настройка даст вам полный контроль, особенно если нужно задействовать GPU.

Ручная установка Libtorch:

  1. Загрузка Libtorch:

    wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu124.zip
  2. Распаковка:

    unzip libtorch-cxx11-abi-shared-with-deps-2.4.0+cu124.zip -d /path/to/libtorch
  3. Настройка переменных окружения:

    Linux/macOS:

    export LIBTORCH=/path/to/libtorch
    export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH

    Windows (PowerShell):

    $Env:LIBTORCH="C:\path\to\libtorch"
    $Env:Path += ";C:\path\to\libtorch\lib"

Если не хочется заморачиваться с ручной настройкой, можно использовать автоматическое скачивание Libtorch через флаг download-libtorch. Удобно для CPU-only версий, но для CUDA потребуется указать версию через переменную TORCH_CUDA_VERSION.

Настройка автоматической загрузки:

[dependencies]
rust-bert = { version = "0.23", features = ["download-libtorch"] }
export TORCH_CUDA_VERSION=cu124
cargo build

Rust-bert также поддерживает кэширование моделей, загружаемых и сохраняемых локально в ~/.cache/.rustbert. Если нужно настроить другой путь, используйте переменную RUSTBERT_CACHE.

Основные возможности

Модели для извлечения ответов на вопросы

Одна из самых полезных возможностей — извлечение ответов на вопросы. Вы даете вопрос и контекст, модель находит ответ с указанием точных координат в тексте.

use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
fn main() -> anyhow::Result<()> {
    let qa_model = QuestionAnsweringModel::new(Default::default())?;
    let question = "Где живет Эми?";
    let context = "Эми живет в Амстердаме.";
    let answers = qa_model.predict(&[QaInput { question, context }], 1, 32);
    println!("{:?}", answers);
    Ok(())
}

Модель находит «Амстердам» с высокой точностью.

Модели перевода текста

Rust-bert поддерживает модели перевода, такие как Marian и M2M100.

use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};
fn main() -> anyhow::Result<()> {
    let model = TranslationModelBuilder::new()
        .with_source_languages(vec![Language::English])
        .with_target_languages(vec![Language::Russian])
        .create_model()?;
    
    let input_text = "This is a test sentence.";
    let output = model.translate(&[input_text], None, Language::Russian)?;
    for sentence in output {
        println!("{}", sentence);
    }
    Ok(())
}

Абстрактивное суммирование

Абстрактивное суммирование позволяет сократить длинные тексты до краткого, но емкого содержания.

use rust_bert::pipelines::summarization::SummarizationModel;
fn main() -> anyhow::Result<()> {
    let summarization_model = SummarizationModel::new(Default::default())?;
    let input = ["Ученые обнаружили воду в атмосфере планеты K2-18b..."];
    let output = summarization_model.summarize(&input);
    println!("{:?}", output);
    Ok(())
}

Генерация текста

Библиотека поддерживает GPT-2 и GPT.

use rust_bert::pipelines::text_generation::TextGenerationModel;
fn main() -> anyhow::Result<()> {
    let model = TextGenerationModel::new(Default::default())?;
    let input_context = "В один прекрасный день,";
    let output = model.generate(&[input_context], None);
    println!("{:?}", output);
    Ok(())
}

Нейронные сетевые классификаторы

Zero-shot классификация позволяет моделям классифицировать текст без необходимости дополнительного обучения для конкретных классов.

use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
fn main() -> anyhow::Result<()> {
    let model = ZeroShotClassificationModel::new(Default::default())?;
    let input_sentence = "Сегодня солнечная погода.";
    let candidate_labels = vec!["weather", "sports", "politics"];
    let output = model.predict(&[input_sentence], &candidate_labels, None, 128);
    println!("{:?}", output);
    Ok(())
}

Работа с кастомными моделями и ONNX

Если стандартные модели не подходят, rust-bert позволяет загружать кастомные модели из PyTorch и ONNX. Экспорт из PyTorch осуществляется с помощью torch.save, после чего можно конвертировать веса для использования в rust-bert.

Пример загрузки кастомной модели:

use rust_bert::bert::BertModel;
use tch::nn::VarStore;

fn main() -> anyhow::Result<()> {
    let mut vs = VarStore::new(tch::Device::Cpu);
    let model = BertModel::new(&vs.root(), Default::default());
    vs.load("path/to/bert_model_weights.bin")?;
    Ok(())
}

Экспорт в ONNX через Hugging Face Optimum:

from transformers import BertModel
from optimum.onnxruntime import ORTModelForQuestionAnswering
from optimum.exporters import TasksManager

model = BertModel.from_pretrained("bert-base-uncased")
TasksManager.export(model, "onnx", save_dir="path_to_save_onnx")

ONNX позволяет оптимизировать инференс на GPU и CPU, делая модель быстрее и эффективнее для высоконагруженных приложений.

Подробнее с rust-bert можно ознакомиться здесь.

23 октября пройдёт открытый урок «TSMixter современная архитектура разложения временных рядов от Google». На этом занятии мы рассмотрим современную модель TSMixter от Google, которая умеет автоматически раскладывать временной ряд на сложные компоненты и строить прогноз на их основе. Мы научимся загружать и работать с этой моделью, а также сравним ее с более сложными трансформенными моделями, такими как NBEATS, NHITS, iTransformers, PatchTST и TimesNet. Записаться на урок можно на странице курса "Machine Learning. Professional".

Комментарии (1)


  1. Stonuml
    09.10.2024 07:00

    Можно попросить все же поправить форматирование? fn main() -> anyhow::Result<()> {