Не скрою, что являюсь поклонником этого языка. Именно в Rust парадигма RAII нашла своё логическое завершение, а понятия владения и заимствования обеспечили безопасность памяти и параллелизма, гарантируемую на этапе компиляции. Это, действительно, серьёзные достижения.

Увы, данный пост не будет очередным восхвалением Rust. В нём я опишу трудности, с которыми пришлось столкнуться при проектировании API для нейронных сетей. Не будучи уверенным, что эти проблемы связанны с самим языком, а не с недостатком моих знаний о нём, я хотел бы подчеркнуть дискуссионность всего нижеизложенного.

Итак, моей задачей является написание API для полносвязной нейронной сети. API должен поддерживать как f32, так и f64 (первый тип выгоднее использовать при наличии гарантированной векторизации на CPU или GPU, второй — в остальных случаях).

Для вычисления значений выходов сети мне понадобится функция умножения матрицы на вектор (обобщение суммирования произведений значений на входах и соответствующих весовых коэффициентов для каждого слоя полносвязной сети). Поскольку это самая затратная операция, здесь стоит воспользоваться сторонним оптимизированным кодом. Данная функциональность входит в стандартный API по работе с матрицами, называемый BLAS (Basic Linear Algebra Subprograms). Есть несколько реализаций BLAS, в том числе OpenBLAS и cuBLAS (последняя оптимизирована под выполнение на CUDA-совместимых GPU).

Для начала попробуем описать типаж Blas с параметрической функцией умножения матрицы на вектор:

pub trait Blas {
    fn gemv<T>(&self, ...);
    ...
}

Здесь мы и столкнёмся с первой проблемой параметризации. В нынешнем виде Rust не поддерживает специализацию функций для отдельных типов (эта возможность в настоящий момент проходит стадию обсуждения). В нашем случае эта специализация необходима, поскольку варианты для f32 и f64 требуют различных вызовов соответствующих функций из разделяемой библиотеки.

Пробуем пойти другим путём. Можно, ведь, сделать наш типаж параметрическим и реализовать оба варианта (Blas<f32> и Blas<f64>) нашем типом (например, CudaBlas — реализацией BLAS для CUDA):

trait Blas<T> {
    fn gemv(&self, arg: T);
}

struct CudaBlas {
    field: usize,
}

impl Blas<f32> for CudaBlas {
    fn gemv(&self, arg: f32) {
        println!("f32");
    }
}

impl Blas<f64> for CudaBlas {
    fn gemv(&self, arg: f64) {
        println!("f64");
    }
}

Этот код прекрасно скомпилируется и даже будет работать так, как от него ожидается… пока не поместить его в контекст другого параметрического типа (нашей нейронной сети):

struct NeuralNet<T> {
    field: T,
}

impl<T> NeuralNet<T> {
    fn process(&self, arg: T) {
        let cblas = CudaBlas{field:0};
        cblas.gemv(arg);
    }
}

В результате получаем следующую ошибку компиляции:

error: the trait `Blas<T>` is not implemented for the type `CudaBlas` [E0277]

Другими словами компилятор хочет, чтобы наш CudaBlas реализовывал Blas для произвольного типа. Если же добавить пустую реализацию для общего случая, то компилятор начнёт ругаться на «избыточные» специализации для f32 и f64.

Попытка переписать тип CudaBlas как параметрический (костыль, искусственно сужающий функциональность, поскольку для операций с разными типами понадобятся разные экземпляры данного типа) тоже не приводит к успеху. В этом случае компилятор жалуется на то, что не может найти функцию gemv для CudaBlas<T> (по-сути, та же ситуация, что и в прошлом случае).

Странное ощущение. Начиная новый проект, хочу использовать Rust, но наталкиваюсь на удручающие ограничения по параметризации. Хочется верить, что это детские болезни, которые будут преодолены в ближайшее время. Вопрос только в том, на чём писать сейчас. Неужели опять на C++?

Комментарии (0)