Во время длинных новогодних выходных я решил заняться новым пет-проектом. Хотелось попробовать что-то интересное, связанное с нейронными сетями, и мне показалась новой для меня идея сделать приложение для распознавания растений, определения вида, возраста и их болезней прямо на мобильном устройстве без сервера, а потом рекомендация того, как лечить, ухаживать и прочее.

В процессе выяснилось, что подобных приложений уже достаточно, но это не уменьшало ценности эксперимента: мне хотелось понять, как можно оптимизировать нейросети их для мобильных устройств и запускать прямо на своем телефоне. В этой статье поговорим о том, что получилось: от выбора моделей до их оптимизации и интеграции в React Native, а также поделюсь своим опытом работы с различными методами классификации растений и их болезней, оценкой возраста и количества листьев.

Модель SCOLD для кросс-модального поиска болезней

Первым инструментом, с которым я решил поэкспериментировать, стала модель SCOLD с Hugging Face. Она была создана для кросс-модального поиска, то есть позволяет сопоставлять изображение с текстовым описанием. Например, можно подать на вход фото кукурузного листа и текстовое описание болезни и получить оценку того, насколько текст соответствует изображению. Модель обучена на датасете LeafNet. Несмотря на интересную концепцию, мне было не совсем ясно, как применять её на практике, да и сама модель весила около гигабайта, что сразу делает её запуск на мобильном устройстве проблемным. Я решил немного изменить модель, чтобы можно было распознать, что за болезнь растения.

Домашнее растение в рандомный момент
Домашнее растение в рандомный момент

Так как мне была важна только работа с изображениями, я убрал текстовый энкодер Roberta и сохранил только модуль обработки изображений на основе Swin Transformer [исходный код]. В результате я получил модель намного меньше в размере, которая возвращает вектор признаков изображения размерностью 512. Вектор изначально предполагалось сравнивать с выходом Roberta для текстовых описаний, но для мобильного применения я решил использовать другой подход: вместо того, чтобы работать с текстом, я сравниваю векторы изображения с векторами, заранее рассчитанными на всем датасете. Теперь дополнительно появилось два файла 248 Мб бинарного файла с embeddings, плюс JSON с метками болезней в 18 Мб. Такой подход сохраняет смысл задачи определение болезни без необходимости подбирать для изображения определение и размер модели стал меньше.

from timm import create_model
import torch
import torch.nn as nn
import numpy as np
from transformers import RobertaModel

EMBEDDING_DIM = 512


class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        # Load the Swin Transformer with features_only=True
        self.swin = create_model("swin_base_patch4_window7_224.ms_in22k", pretrained=True, features_only=True)
        for param in self.swin.parameters():
            param.requires_grad = True

        # Get the feature size of the final stage
        self.swin_output_dim = self.swin.feature_info.channels()[-1]  # Last stage: 1024 channels

        # Define FC layer
        self.fc1 = nn.Linear(self.swin_output_dim * 7 * 7, EMBEDDING_DIM)  # Flattened input size
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        for param in self.fc1.parameters():
            param.requires_grad = True

    def forward(self, x):
        # Extract features from Swin
        swin_features = self.swin(x)[-1]  # Use the last stage feature map (e.g., [B, 1024, 7, 7])

        # Flatten feature map
        swin_features = swin_features.view(swin_features.size(0), -1)  # Shape: (B, 1024*7*7)

        # Pass through FC layer
        output = self.fc1(swin_features)  # Shape: (B, embedding_dim)
        return output


class LVL(nn.Module):
    def __init__(self):
        super(LVL, self).__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = nn.Identity()
        self.t_prime = nn.Parameter(torch.ones([]) * np.log(0.07))
        self.b = nn.Parameter(torch.ones([]) * 0)

    def get_images_features(self, images):
        image_embeddings = self.image_encoder(images)  # (batch_size, EMBEDDING_DIM)
        image_embeddings = nn.functional.normalize(image_embeddings, p=2, dim=-1)
        return image_embeddings

    def get_texts_feature(self, input_ids=None, attention_mask=None):
        """
        Plug
        :param input_ids: Tensor of shape (batch_size, seq_length)
        :param attention_mask: Tensor of shape (batch_size, seq_length)
        :return:
        """
        return None

    def forward(self, images, input_ids=None, attention_mask=None):
        """
        Args:
            images: Tensor of shape (batch_size, 3, 224, 224)
            input_ids: Tensor of shape (batch_size, seq_length)
            attention_mask: Tensor of shape (batch_size, seq_length)

        Returns:
            Image and text embeddings normalized for similarity calculation
        """

        image_embeddings = self.get_images_features(images)
        return image_embeddings

Следующим шагом я перевёл модель в формат ONNX, который хорошо поддерживается на мобильных и позволяет интегрировать нейросети в React Native. В процессе пришлось учитывать особенности каждой модели: нужно правильно задать dummy input для проверки входного формата, указать имена входов и выходов, dynamic_axes для поддержки батчей, а иногда и версию opset, чтобы обеспечить совместимость с ONNX Runtime.

import torch
# pip install onnxscript onnxruntime onnxruntime-tools
from model import LVL


def export(model_path: str, output_name: str):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Create a dummy input with the same dimensions as the real data.
    dummy_input = torch.randn(1, 3, 512, 512, device=device)

    # Load model
    model = LVL()
    model.to(device)
    model.eval()

    #  git clone https://huggingface.co/enalis/scold
    state_dict = torch.load(model_path, map_location=device)

    # Leave only the keys that are in the new model (only image_encoder)
    model_dict = model.state_dict()
    filtered_dict = {k: v for k, v in state_dict.items() if k in model_dict}

    # Load only image_encoder weights
    model.load_state_dict(filtered_dict, strict=False)  # <- strict=False important!

    # Export
    torch.onnx.export(
        model,                      # Model
        dummy_input,                # Input
        output_name,           # Name
        export_params=True,         # Save weights
        opset_version=18,           # Version
        do_constant_folding=True,   # Optim
        input_names=['images'],     # Name of inputs
        output_names=['image_embeddings'],  # Name of outputs
        dynamic_axes={'images': {0: 'batch_size'},  # Dynamic support batch size
                      'image_embeddings': {0: 'batch_size'}}
    )
    print("Finish!")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to model pth")
    parser.add_argument("--output_name", type=str, default="disease_detector.onnx", help="Path to save onnx model with name model.onnx")
    args = parser.parse_args()

    export(model_path=args.model_path, output_name=args.output_name)

После этого я дополнительно уменьшил размер модели, конвертировав её из float32 в float16, сохранив при этом то, что на вход приходит float32. В итоге удалось снизить вес модели с одного гигабайта до 231 Мб, что делает её удобной для мобильного приложения.

import onnx
from onnxconverter_common import float16


def export(model_path: str, output_name: str):
    # Load model
    model_fp32 = onnx.load(model_path)
    # Export to float16
    model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)
    # Save
    onnx.save(model_fp16, output_name)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path model onnx")
    parser.add_argument("--output_name", type=str, default="fp16.onnx", help="Path to save fp16 mode with name model_fp16.onnx")
    args = parser.parse_args()

    export(model_path=args.model_path, output_name=args.output_name)

Сама модель disease_detection.onnx и код для её запуска можно посмотреть на моём GitHub.

Классификация растений по изображению

Обойтись без нейросети для распознавания растений по фотографии в таком проекте было бы странно, поэтому следующим шагом я занялся именно этой задачей. В процессе поиска подходящих данных и моделей я наткнулся на соревнование PlantCLEF 2024, из которого вышел на датасет Pl@ntNet. Сам датасет огромный, а обучение модели на полном наборе данных требует много времени и ресурсов. Но вместо того, чтобы грузить свой ноутбук, я обнаружил, на Zenodo уже была выложена готовая модель, обученная на этих данных, вместе с кодом для её запуска.

Вид: Подсолнух понтовый
Вид: Подсолнух понтовый

На вход подаётся изображение растения, а на выходе получается набор вероятностей соответствия конкретным видам. По сути, это классическая задача классификации изображений, без дополнительных усложнений. Для моих целей этого было вполне достаточно, поэтому я снова проделал знакомый путь: экспортировал модель в ONNX и затем сконвертировал её в float16, чтобы уменьшить размер и сделать более пригодной для запуска на мобильных устройствах. Готовая модель и код для её использования доступны в репозитории проекта.

import torch
import timm
# pip install onnxscript onnxruntime onnxruntime-tools
import argparse


def load_class_mapping(class_list_file):
    with open(class_list_file) as f:
        class_index_to_class_name = {i: line.strip() for i, line in enumerate(f)}
    return class_index_to_class_name


def export(model_path: str, output_name: str):
    torch.serialization.add_safe_globals([argparse.Namespace])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Class and model after train.py
    class_mapping = load_class_mapping(args.class_mapping)

    # Load model
    model = timm.create_model('vit_base_patch14_reg4_dinov2.lvd142m', pretrained=False, num_classes=len(class_mapping))
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Dummy input (ViT Base 224x224 RGB)
    dummy_input = torch.randn(1, 3, 518, 518)

    # Export
    torch.onnx.export(
        model,                      # Model
        dummy_input,                # Input
        output_name,           # Name
        export_params=True,         # Save weights
        opset_version=18,           # Version
        input_names=['input'],      # Name of inputs
        output_names=['output'],    # Name of outputs
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
    print("Finished!")


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--class_mapping", type=str, required=True, help="Path to class mapping")
    parser.add_argument("--model_path", type=str, required=True, help="Path species mapping")
    parser.add_argument("--output_name", type=str, default="plant_classificator.onnx", help="Path to save onnx model with name model.onnx")
    args = parser.parse_args()

    export(model_path=args.model_path, output_name=args.output_name)

Позже, уже более внимательно изучив датасет, я понял, что он в основном содержит редкие и экзотические растения. В нём практически нет распространённых декоративных цветов, овощей или фруктов, которые ожидаешь увидеть в подобном приложении. Либо хотя бы распознавание букетных цветов.

Из-за этого я решил пойти дальше и дописал код для дообучения модели или обучения с нуля. На практике обучение с нуля оказалось более предпочтительным вариантом, так как позволяет получить меньшую по размеру модель, что критично для мобильных устройств. Кроме того, такой подход даёт возможность пользователю самостоятельно выбирать, какую модель загружать, в зависимости от задач и доступных ресурсов устройства.

Я вынес процесс обучения в отдельный модуль и подробно описал структуру данных и шаги запуска в документации проекта. Обучение можно запустить с batch size 8 на видеокарте с 8 Гб видеопамяти, либо уменьшить размер батча, если ресурсы ограничены. Сам датасет при этом приходится собирать самостоятельно, комбинируя данные из Kaggle и других открытых источников. Такой подход оказался более гибким и лучше вписывается в идею экспериментального мобильного ML-проекта.

Регрессионная модель для оценки возраста растения и количества листьев

Задача определения возраста растения и количества листьев сама по себе выглядит довольно тривиальной, но на практике оказалось, что готовых моделей я не нашел. Есть научная работа GroMo: Plant Growth Modeling with Multiview Images и соответствующий репозиторий, однако код там написан достаточно небрежно и плохо масштабируется под прикладное использование, поэтому я решил не адаптировать его, а переписать решение с нуля, ориентируясь на свои ограничения. В оригинальной работе растения различных культур, таких как mustard, radish, wheat и okra, фотографировали с нескольких ракурсов и на разных стадиях роста, а в разметке указывали возраст растения и количество листьев. Важный момент здесь в том, что визуальные признаки сильно зависят от культуры: форма листа, его размер и даже сама структура растения могут сильно отличаться, иногда листья визуально почти неотличимы от стебля. В исходной постановке модели на вход подавались сразу четыре изображения с разных углов, тогда как в моём случае вход был упрощён до одного фото, что ближе к реальному пользовательскому сценарию на мобильном устройстве. Также возникла идея сравнивать, соответствует ли текущее количество листьев или стеблей его возрасту, или оно занижено, но её я решил отложить.

В основе этой модели лежит MobileNetV3 Large, использованный в роли универсального визуального энкодера. Из предобученной модели берётся только сверточная часть, которая извлекает компактные и информативные признаки из изображения листа. После этого признаки агрегируются с помощью Adaptive Average Pooling до фиксированного вектора размерности 960, что делает модель независимой от размера входного изображения. Далее этот общий вектор признаков используется сразу для двух регрессионных задач: оценки количества листьев и возраста растения. Для этого добавлены две независимые «головы», каждая из которых представляет собой небольшой полносвязный блок. Модель я назвал LeafNet.

import torch.nn as nn
from torchvision.models import mobilenet_v3_large

class LeafNet(nn.Module):
    def __init__(self):
        super().__init__()
        backbone = mobilenet_v3_large(weights="DEFAULT")
        self.encoder = backbone.features

        self.pool = nn.AdaptiveAvgPool2d(1)

        self.count_head = nn.Sequential(
            nn.Linear(960, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

        self.age_head = nn.Sequential(
            nn.Linear(960, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        feat = self.encoder(x)
        pooled = self.pool(feat).flatten(1)
        count = self.count_head(pooled)
        age = self.age_head(pooled)
        return feat, count, age

Обучение проводилось на объединённом датасете без разделения по культурам, чтобы учиться на более разнообразных визуальных паттернах, суммарный размер датасета составил около 380 Гб. Полное обучение заняло примерно три дня, а сама модель в процессе обучения автоматически экспортируется в формат ONNX, для мобильных устройствах. Документация по обучению и весь код для запуска и экспериментов доступны в репозитории проекта.

Интеграция и запуск нейросетей в React Native

В этой части статьи я расскажу про запуск нейронных сетей непосредственно на мобильном устройстве и про тот код, который для этого понадобился. Я намеренно опускаю все базовые вещи вроде установки Node.js, запуска Android-эмулятора или сборки окружения. Это хорошо описано в других местах и не имеет прямого отношения к задаче. Здесь речь пойдёт о коде и о тех ограничениях, с которыми пришлось столкнуться при работе с нейросетями на мобильных платформах. В репозитории проекта я добавил готовый APK-файл для тех, кто не хочет собирать приложение самостоятельно (весит много, так как внутри модели), но для разработки и экспериментов я всё же рекомендую собирать проект вручную. Тестирование проводилось на эмуляторе Pixel 8 Pro, поэтому поведение на других устройствах может отличаться. Инструкции по сборке и запуску подробно описаны в репозитории проекта в папке mobile.

Первой серьёзной проблемой, которая стоила мне примерно одного дня разработки, оказался выбор стека для мобильного приложения. Изначально я начал с Expo, так как он позволяет быстро поднять прототип, но выяснилось, что запуск ONNX-моделей в таком окружении либо нестабилен, либо вовсе невозможен. Даже сборка собственного билда не помогло. В итоге я отказался от Expo и перешёл на чистый React Native, что оказалось правильным решением с точки зрения дальнейшей интеграции нативного кода и работы с нейросетями.

После этого я столкнулся с фундаментальным ограничением мобильных платформ: JavaScript-движок на Android и iOS не предназначен для работы с большими бинарными файлами. Файлы вроде captions.json и особенно embeddings.bin, размер которого 248 Мб, невозможно просто загрузить в память из JavaScript без риска получить OutOfMemory или крайне медленную работу из-за постоянного копирования данных. При этом именно эти файлы необходимы для работы модели определения болезней, где поиск осуществляется по заранее рассчитанным embedding-векторам.

Значит, выносим всю работу с большими данными и векторным поиском в нативный слой Android, реализовав собственный нативный модуль на Kotlin. Так появился FaissSearchModule. Его основная задача — загрузка embedding-файла, хранение данных вне JavaScript heap и выполнение поиска по векторам полностью в нативной памяти. Вместо загрузки embeddings.bin в RAM используется memory-mapped файл через MappedByteBuffer. Такой подход позволяет операционной системе самой управлять памятью и подгружать данные по мере необходимости, не держа весь файл целиком в памяти приложения. По сути, embedding-файл становится частью виртуальной памяти процесса, что важно я считаю при работе с большими массивами данных на мобильных устройствах.

Поиск ближайших векторов также выполняется целиком на стороне нативного кода. В JavaScript передаётся только embedding запроса, а обратно возвращается компактный результат: индексы, расстояния и текстовые описания болезней. Это минимизирует нагрузку на bridge между JavaScript и нативным кодом и избавляет от передачи больших массивов чисел туда и обратно. Аналогичный подход используется и для captions.json. Вместо стандартного парсинга JSON файл читается потоково, что позволяет избежать загрузки всего содержимого в память. Сами файлы находятся в этой директории для Android.

package com.berkano

import com.facebook.react.bridge.*
import org.json.JSONArray
import java.io.*
import java.nio.ByteBuffer
import java.nio.ByteOrder
import kotlin.math.sqrt

/**
 * Нативный модуль для работы с векторным поиском через Faiss
 * Данные хранятся в нативной памяти, не попадая в JavaScript heap
 */
class FaissSearchModule(
    private val reactContext: ReactApplicationContext
) : ReactContextBaseJavaModule(reactContext) {

    // Храним данные в нативной памяти
    // Используем MappedByteBuffer вместо FloatArray для больших файлов
    // Это позволяет не загружать все данные в RAM сразу
    private var embeddingsBuffer: java.nio.MappedByteBuffer? = null
    private var captions: List<String>? = null
    private var numVectors: Int = 0
    private var embeddingSize: Int = 512
    private var embeddingsFile: File? = null

    override fun getName(): String {
        return "FaissSearch"
    }

    /**
     * Загружает embeddings из бинарного файла используя memory-mapped файл
     * НЕ загружает все данные в RAM, читает по требованию
     */
    @ReactMethod
    fun loadEmbeddings(embeddingsPath: String, promise: Promise) {
        try {
            val file = File(embeddingsPath)
            if (!file.exists()) {
                promise.reject("FILE_NOT_FOUND", "Embeddings file not found: $embeddingsPath")
                return
            }

            val fileSize = file.length()
            numVectors = (fileSize / 4 / embeddingSize).toInt() // 4 байта на float
            
            // Используем MappedByteBuffer - это НЕ загружает файл в RAM полностью
            // Операционная система управляет памятью и подгружает данные по требованию
            FileInputStream(file).use { fis ->
                val channel = fis.channel
                val byteBuffer = channel.map(
                    java.nio.channels.FileChannel.MapMode.READ_ONLY,
                    0,
                    fileSize
                )
                byteBuffer.order(ByteOrder.LITTLE_ENDIAN)
                embeddingsBuffer = byteBuffer
            }

            embeddingsFile = file

            val result = Arguments.createMap()
            result.putInt("numVectors", numVectors)
            result.putInt("embeddingSize", embeddingSize)
            result.putLong("fileSize", fileSize)

            promise.resolve(result)
            println("Mapped $numVectors vectors of size $embeddingSize from file (${fileSize / 1024 / 1024}MB)")
        } catch (e: OutOfMemoryError) {
            promise.reject("OUT_OF_MEMORY", "Not enough memory to map embeddings file. File too large: ${e.message}", e)
        } catch (e: Exception) {
            promise.reject("LOAD_ERROR", "Failed to load embeddings: ${e.message}", e)
        }
    }

    /**
     * Читает вектор по индексу из memory-mapped файла
     * Не загружает все данные в память
     */
    private fun getVector(index: Int): FloatArray {
        if (embeddingsBuffer == null) {
            throw IllegalStateException("Embeddings not loaded")
        }
        
        val startByte = index * embeddingSize * 4 // 4 байта на float
        val vector = FloatArray(embeddingSize)
        
        embeddingsBuffer!!.position(startByte)
        embeddingsBuffer!!.asFloatBuffer().get(vector)
        
        return vector
    }

    /**
     * Загружает captions из JSON файла
     * Использует streaming парсинг для больших файлов
     */
    @ReactMethod
    fun loadCaptions(captionsPath: String, promise: Promise) {
        try {
            val file = File(captionsPath)
            if (!file.exists()) {
                promise.reject("FILE_NOT_FOUND", "Captions file not found: $captionsPath")
                return
            }

            // Используем BufferedReader для эффективного чтения больших файлов
            val captionsList = mutableListOf<String>()
            var buffer = StringBuilder()
            var inString = false
            var escapeNext = false
            
            BufferedReader(FileReader(file), 8192).use { reader ->
                var char: Int
                while (reader.read().also { char = it } != -1) {
                    when {
                        escapeNext -> {
                            buffer.append(char.toChar())
                            escapeNext = false
                        }
                        char == '\\'.code -> {
                            escapeNext = true
                        }
                        char == '"'.code -> {
                            if (inString) {
                                // Конец строки
                                captionsList.add(buffer.toString())
                                buffer = StringBuilder()
                                inString = false
                            } else {
                                // Начало строки
                                inString = true
                            }
                        }
                        inString -> {
                            buffer.append(char.toChar())
                        }
                        char == ']'.code -> {
                            // Конец массива
                            break
                        }
                    }
                }
            }
            
            captions = captionsList
            val result = Arguments.createMap()
            result.putInt("count", captionsList.size)

            promise.resolve(result)
            println("Loaded ${captionsList.size} captions")
        } catch (e: OutOfMemoryError) {
            promise.reject("OUT_OF_MEMORY", "Not enough memory to load captions. File too large: ${e.message}", e)
        } catch (e: Exception) {
            promise.reject("LOAD_ERROR", "Failed to load captions: ${e.message}", e)
        }
    }


    /**
     * Вычисляет косинусное расстояние между двумя векторами
     */
    private fun cosineDistance(vecA: FloatArray, vecB: FloatArray): Float {
        var dotProduct = 0f
        var normA = 0f
        var normB = 0f

        for (i in vecA.indices) {
            dotProduct += vecA[i] * vecB[i]
            normA += vecA[i] * vecA[i]
            normB += vecB[i] * vecB[i]
        }

        normA = sqrt(normA)
        normB = sqrt(normB)

        if (normA == 0f || normB == 0f) {
            return 1f
        }

        val similarity = dotProduct / (normA * normB)
        return 1f - similarity
    }

    /**
     * Находит топ-K ближайших векторов к запросу
     * Поиск выполняется в нативной памяти, в JS передаются только результаты
     */
    @ReactMethod
    fun search(
        queryEmbedding: ReadableArray,
        topK: Int,
        promise: Promise
    ) {
        try {
            if (embeddingsBuffer == null) {
                promise.reject("NOT_LOADED", "Embeddings not loaded. Call loadEmbeddings first.")
                return
            }

            if (captions == null) {
                promise.reject("NOT_LOADED", "Captions not loaded. Call loadCaptions first.")
                return
            }

            // Конвертируем запрос из ReadableArray в FloatArray
            val query = FloatArray(queryEmbedding.size())
            for (i in 0 until queryEmbedding.size()) {
                query[i] = queryEmbedding.getDouble(i).toFloat()
            }

            if (query.size != embeddingSize) {
                promise.reject("INVALID_SIZE", "Query embedding size ${query.size} != $embeddingSize")
                return
            }

            // Вычисляем расстояния для всех векторов
            // Читаем векторы по требованию из memory-mapped файла
            val distances = mutableListOf<Pair<Int, Float>>()

            for (i in 0 until numVectors) {
                // Читаем вектор из файла по требованию
                val vector = getVector(i)
                val distance = cosineDistance(query, vector)
                distances.add(Pair(i, distance))
            }

            // Сортируем и берем топ-K
            distances.sortBy { it.second }
            val topResults = distances.take(topK)

            // Формируем результаты для передачи в JS
            val results = Arguments.createArray()
            for ((index, distance) in topResults) {
                val resultItem = Arguments.createMap()
                resultItem.putInt("index", index)
                resultItem.putDouble("distance", distance.toDouble())
                resultItem.putString("caption", captions!![index])
                results.pushMap(resultItem)
            }

            promise.resolve(results)
        } catch (e: Exception) {
            promise.reject("SEARCH_ERROR", "Search failed: ${e.message}", e)
        }
    }

    /**
     * Очищает загруженные данные из памяти
     */
    @ReactMethod
    fun clearCache(promise: Promise) {
        embeddingsBuffer = null
        embeddingsFile = null
        captions = null
        numVectors = 0
        promise.resolve(null)
    }

    /**
     * Проверяет, загружены ли данные
     */
    @ReactMethod
    fun isLoaded(promise: Promise) {
        val result = Arguments.createMap()
        result.putBoolean("embeddingsLoaded", embeddingsBuffer != null)
        result.putBoolean("captionsLoaded", captions != null)
        result.putInt("numVectors", numVectors)
        promise.resolve(result)
    }
}

FaissSearchPackage в этой схеме играет вспомогательную роль. Он необходим для регистрации нативного модуля в React Native и сообщает JavaScript-части приложения о наличии FaissSearchModule. Вся логика при этом сосредоточена именно в самом модуле, а package служит инфраструктурным слоем для его подключения.

package com.berkano

import com.facebook.react.ReactPackage
import com.facebook.react.bridge.NativeModule
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.uimanager.ViewManager

class FaissSearchPackage : ReactPackage {
    override fun createNativeModules(
        reactContext: ReactApplicationContext
    ): List<NativeModule> {
        return listOf(FaissSearchModule(reactContext))
    }

    override fun createViewManagers(
        reactContext: ReactApplicationContext
    ): List<ViewManager<*, *>> {
        return emptyList()
    }
}

Помимо работы с embedding-файлами, есть ещё одна проблема, которую нельзя решить на уровне JavaScript. Все ONNX-модели в проекте принимают на вход тензор фиксированной формы и формата, тогда как на стороне React Native изображение обычно приходит в виде URI или base64-строки. Преобразование изображения в числовой тензор можно было бы попытаться сделать в JavaScript, но на практике это приводит либо к медленной работе, либо к потреблению памяти, особенно если речь идёт о изображениях с камеры. Поэтому подготовка входных данных для модели также была вынесена в нативный слой.

Для этого реализован отдельный нативный модуль ImageDecoderModule. Его задача принять URI изображения, загрузить его через ContentResolver, декодировать в Bitmap, привести к нужному размеру и преобразовать в тензор в формате CHW, который ожидают модели, обученные в PyTorch и экспортированные в ONNX. На этом этапе происходит нормализация значений пикселей в диапазон от 0 до 1 и явное разделение каналов RGB, чтобы итоговый массив точно соответствовал входу модели.

Важно, что вся эта логика выполняется на стороне Android, до передачи данных в JavaScript. React Native не умеет напрямую работать с FloatArray, поэтому итоговый тензор конвертируется в WritableArray, который уже корректно передаётся через bridge. Да, это всё ещё массив чисел, но он формируется один раз, в нативном коде, без промежуточных копирований и лишних аллокаций в JS-движке, что заметно снижает нагрузку на память и ускоряет подготовку данных перед инференсом.

package com.berkano;

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import com.facebook.react.bridge.*
import java.io.InputStream

class ImageDecoderModule(
  private val reactContext: ReactApplicationContext
) : ReactContextBaseJavaModule(reactContext) {

  override fun getName(): String {
    return "ImageDecoder"
  }

  @ReactMethod
  fun decodeToTensor(
    uriString: String,
    targetWidth: Int,
    targetHeight: Int,
    promise: Promise
  ) {
    try {
      val uri = Uri.parse(uriString)
      val inputStream: InputStream? =
        reactContext.contentResolver.openInputStream(uri)

      if (inputStream == null) {
        promise.reject("ERROR", "Cannot open image stream")
        return
      }

      var bitmap: Bitmap? = null
      try {
        bitmap = BitmapFactory.decodeStream(inputStream)
        if (bitmap == null) {
          promise.reject("ERROR", "Cannot decode image")
          return
        }
        bitmap = Bitmap.createScaledBitmap(bitmap, targetWidth, targetHeight, true)
      } finally {
        inputStream.close()
      }

      val width = bitmap.width
      val height = bitmap.height

      val pixels = IntArray(width * height)
      bitmap.getPixels(pixels, 0, width, 0, 0, width, height)

      // CHW: [3, H, W] - формат для передачи в JavaScript
      val tensor = FloatArray(3 * width * height)

      for (y in 0 until height) {
        for (x in 0 until width) {
          val color = pixels[y * width + x]

          val r = (color shr 16 and 0xFF) / 255f
          val g = (color shr 8 and 0xFF) / 255f
          val b = (color and 0xFF) / 255f

          val idx = y * width + x
          tensor[idx] = r
          tensor[width * height + idx] = g
          tensor[2 * width * height + idx] = b
        }
      }

      // Конвертируем FloatArray в WritableArray для правильной передачи в JavaScript
      // React Native не может напрямую конвертировать FloatArray, поэтому используем WritableArray
      val writableArray = Arguments.createArray()
      for (value in tensor) {
        // pushDouble принимает double, поэтому конвертируем float в double
        writableArray.pushDouble(value.toDouble())
      }
      
      promise.resolve(writableArray)
    } catch (e: Exception) {
      promise.reject("ERROR_DECODING_IMAGE", e)
    }
  }
}

Как и в случае с векторным поиском, ImageDecoderModule регистрируется через отдельный ImageDecoderPackage, который подключается в MainApplication.kt. Без этого React Native просто не увидит модуль. Также необходимо явно указать соответствующие разрешения в AndroidManifest.xml, иначе приложение не сможет получать доступ к изображениям. Все эти детали есть в репозитории проекта, в android-части мобильного приложения.

package com.berkano;

import com.facebook.react.ReactPackage
import com.facebook.react.bridge.NativeModule
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.uimanager.ViewManager

class ImageDecoderPackage : ReactPackage {

  override fun createNativeModules(
    reactContext: ReactApplicationContext
  ): List<NativeModule> {
    return listOf(ImageDecoderModule(reactContext))
  }

  override fun createViewManagers(
    reactContext: ReactApplicationContext
  ): List<ViewManager<*, *>> {
    return emptyList()
  }
}

В итоге весь тяжёлый и чувствительный к памяти код остаётся на нативной стороне, а JavaScript работает только с уже подготовленными данными и результатами инференса.

Теперь запуск самих моделей на мобильном устройстве. Для удобства я создал отдельный файл, который отвечает за подготовку, копирование и инициализацию моделей, а также за выполнение инференса. Сначала при первом запуске проверяется, есть ли файлы модели и сопутствующие данные в локальной файловой системе. Если нет, то они копируются из папки assets (для Android). Логика реализована в функции copyAssetIfNeeded. На Android большие файлы лучше копировать через copyFileAssetse.

Функция prepareModelAssets собирает все необходимые пути к моделям, embeddings и вспомогательным файлам, возвращая объект с путями для последующего использования. Модель ONNX мы инициализируем через initializeModel, где создаётся сессия ONNX с приоритетом NNAPI для GPU/NPU, а CPU используется как fallback.

Инференс выполняется в runInference, где на вход подаётся Float32Array с тензором изображения, уже подготовленным нативными модулями (об этом подробно говорилось в ImageDecoderModule). Функция создаёт тензор формата под модель, например для детекции болезней [1, 3, 224, 224], запускает модель и возвращает нормализованный эмбеддинг.

import * as ort from 'onnxruntime-react-native';
import RNFS from 'react-native-fs';
import { Platform } from 'react-native';

async function copyAssetIfNeeded(assetName: string, folder: string): Promise<string> {
  const localPath = `${RNFS.DocumentDirectoryPath}/${assetName}`;

  const exists = await RNFS.exists(localPath);
  if (!exists) {
    try {
      if (Platform.OS === 'android') {
        // Android: читаем из assets
        // Путь должен быть относительно папки assets: folder/assetName
        const assetPath = `${folder}/${assetName}`;
        console.log(`Attempting to copy from assets: ${assetPath} to ${localPath}`);
        await RNFS.copyFileAssets(assetPath, localPath);
        console.log(`Successfully copied ${assetName} to local FS: ${localPath}`);
      } else {
        // iOS: читаем из main bundle
        const source = `${RNFS.MainBundlePath}/${folder}/${assetName}`;
        console.log(`Attempting to copy from bundle: ${source} to ${localPath}`);
        await RNFS.copyFile(source, localPath);
        console.log(`Successfully copied ${assetName} to local FS: ${localPath}`);
      }
    } catch (err: any) {
      console.error(`Failed to copy ${assetName} from ${folder}/${assetName}:`, err);
      console.error(`Error details:`, {
        message: err?.message,
        code: err?.code,
        platform: Platform.OS,
        localPath,
        assetPath: `${folder}/${assetName}`,
      });
      // Пробуем альтернативный путь без папки (на случай если файл в корне assets)
      if (Platform.OS === 'android') {
        try {
          console.log(`Trying alternative path: ${assetName}`);
          await RNFS.copyFileAssets(assetName, localPath);
          console.log(`Successfully copied using alternative path`);
        } catch (altErr) {
          console.error(`Alternative path also failed:`, altErr);
          throw new Error(`Не удалось скопировать ${assetName}. Проверьте, что файл находится в assets/${folder}/`);
        }
      } else {
        throw err;
      }
    }
  } else {
    console.log(`${assetName} already exists at ${localPath}`);
  }

  return localPath;
}

export async function prepareModelAssets() {
  const modelPath = await copyAssetIfNeeded('disease_detection.onnx', 'models');
  const embeddingsPath = await copyAssetIfNeeded('embeddings.bin', 'files');
  const captionsPath = await copyAssetIfNeeded('captions.json', 'files');
  const classMappingPath = await copyAssetIfNeeded('class_mapping.txt', 'files');
  const speciesMappingPath = await copyAssetIfNeeded('species_id_to_name.txt', 'files');

  return { 
    modelPath, 
    embeddingsPath, 
    captionsPath,
    classMappingPath,
    speciesMappingPath,
  };
}

export interface ModelSession {
  session: ort.InferenceSession;
  inputName: string;
  outputName: string;
}

let modelSession: ModelSession | null = null;

/**
 * Инициализирует ONNX модель с приоритетом GPU/NPU через NNAPI
 */
export async function initializeModel(): Promise<ModelSession> {
  if (modelSession) {
    return modelSession;
  }

  try {
    // Используем путь из prepareModelAssets
    const { modelPath } = await prepareModelAssets();
    
    // Используем локальный путь к модели
    const modelUri = modelPath;
    
    // Настройки сессии с приоритетом NNAPI (GPU/NPU)
    const sessionOptions: ort.InferenceSession.SessionOptions = {
      executionProviders: ['nnapi', 'cpu'], // NNAPI для GPU/NPU, CPU как fallback
      graphOptimizationLevel: 'all',
    };

    // Создаем сессию
    // onnxruntime-react-native может работать с путями к файлам напрямую
    const session = await ort.InferenceSession.create(modelUri, sessionOptions);

    // Получаем имена входов и выходов
    const inputNames = session.inputNames;
    const outputNames = session.outputNames;

    if (inputNames.length === 0 || outputNames.length === 0) {
      throw new Error('Модель не имеет входов или выходов');
    }

    modelSession = {
      session,
      inputName: inputNames[0], // Предполагаем первый вход
      outputName: outputNames[0], // Предполагаем первый выход
    };

    console.log('Модель успешно загружена');
    console.log('Input name:', modelSession.inputName);
    console.log('Output name:', modelSession.outputName);

    return modelSession;
  } catch (error) {
    console.error('Ошибка при загрузке модели:', error);
    throw error;
  }
}

/**
 * Выполняет инференс модели на изображении
 * @param imageTensor Тензор изображения формата [1, 3, 224, 224]
 * @returns Эмбеддинг изображения
 */
export async function runInference(
  imageTensor: Float32Array,
): Promise<Float32Array> {
  if (!modelSession) {
    throw new Error('Модель не инициализирована. Вызовите initializeModel() сначала.');
  }

  try {
    // Создаем тензор для входных данных
    // Формат: [batch, channels, height, width] = [1, 3, 224, 224]
    const inputTensor = new ort.Tensor('float32', imageTensor, [1, 3, 224, 224]);

    // Выполняем инференс
    const feeds = { [modelSession.inputName]: inputTensor };
    const results = await modelSession.session.run(feeds);

    // Получаем выходной тензор
    const outputTensor = results[modelSession.outputName];
    const embedding = outputTensor.data as Float32Array;

    // Нормализуем эмбеддинг
    const norm = Math.sqrt(
      Array.from(embedding).reduce((sum, val) => sum + val * val, 0),
    );
    const normalizedEmbedding = new Float32Array(embedding.length);
    for (let i = 0; i < embedding.length; i++) {
      normalizedEmbedding[i] = embedding[i] / norm;
    }

    return normalizedEmbedding;
  } catch (error) {
    console.error('Ошибка при выполнении инференса:', error);
    throw error;
  }
}

/**
 * Освобождает ресурсы модели
 */
export function disposeModel(): void {
  if (modelSession) {
    modelSession.session.release();
    modelSession = null;
  }
}

Для управления несколькими моделями я сделал modelManager.ts. Он позволяет регистрировать разные модели (disease, plant, age) и управлять их загрузкой в память. В памяти одновременно держатся только 2 модели, чтобы не перегружать устройство; если лимит достигнут, выгружается наименее используемая модель.

import * as ort from 'onnxruntime-react-native';
import RNFS from 'react-native-fs';
import { Platform } from 'react-native';

export type ModelType = 'disease' | 'plant' | 'age';

export interface ModelInfo {
  type: ModelType;
  path: string;
  session: ort.InferenceSession | null;
  inputName: string;
  outputName: string;
  isLoaded: boolean;
  lastUsed: number;
}

// Максимальное количество моделей в памяти одновременно
const MAX_MODELS_IN_MEMORY = 2;

// Кэш моделей
const modelCache: Map<ModelType, ModelInfo> = new Map();

/**
 * Копирует файл из assets в локальную файловую систему
 */
async function copyAssetIfNeeded(assetName: string, folder: string): Promise<string> {
  const localPath = `${RNFS.DocumentDirectoryPath}/${assetName}`;

  const exists = await RNFS.exists(localPath);
  if (!exists) {
    try {
      if (Platform.OS === 'android') {
        const assetPath = `${folder}/${assetName}`;
        console.log(`Copying ${assetName} from assets...`);
        await RNFS.copyFileAssets(assetPath, localPath);
        console.log(`Successfully copied ${assetName}`);
      } else {
        const source = `${RNFS.MainBundlePath}/${folder}/${assetName}`;
        await RNFS.copyFile(source, localPath);
      }
    } catch (err: any) {
      console.error(`Failed to copy ${assetName}:`, err);
      if (Platform.OS === 'android') {
        try {
          await RNFS.copyFileAssets(assetName, localPath);
        } catch (altErr) {
          throw new Error(`Не удалось скопировать ${assetName}`);
        }
      } else {
        throw err;
      }
    }
  }

  return localPath;
}

/**
 * Выгружает наименее используемую модель из памяти
 */
function unloadLeastUsedModel(): void {
  let leastUsed: ModelInfo | null = null;
  let leastUsedTime = Date.now();

  for (const model of modelCache.values()) {
    if (model.isLoaded && model.lastUsed < leastUsedTime) {
      leastUsed = model;
      leastUsedTime = model.lastUsed;
    }
  }

  if (leastUsed) {
    console.log(`Unloading model: ${leastUsed.type}`);
    if (leastUsed.session) {
      leastUsed.session.release();
      leastUsed.session = null;
      leastUsed.isLoaded = false;
    }
  }
}

/**
 * Загружает модель в память
 */
async function loadModel(modelInfo: ModelInfo): Promise<void> {
  if (modelInfo.isLoaded && modelInfo.session) {
    modelInfo.lastUsed = Date.now();
    return;
  }

  // Проверяем, не превышен ли лимит моделей в памяти
  const loadedModels = Array.from(modelCache.values()).filter(m => m.isLoaded);
  if (loadedModels.length >= MAX_MODELS_IN_MEMORY) {
    console.log('Memory limit reached, unloading least used model...');
    unloadLeastUsedModel();
  }

  try {
    console.log(`Loading model: ${modelInfo.type}`);
    
    const sessionOptions: ort.InferenceSession.SessionOptions = {
      executionProviders: ['nnapi', 'cpu'],
      graphOptimizationLevel: 'all',
    };

    const session = await ort.InferenceSession.create(modelInfo.path, sessionOptions);

    const inputNames = session.inputNames;
    const outputNames = session.outputNames;

    if (inputNames.length === 0 || outputNames.length === 0) {
      throw new Error('Модель не имеет входов или выходов');
    }

    modelInfo.session = session;
    modelInfo.inputName = inputNames[0];
    modelInfo.outputName = outputNames[0];
    modelInfo.isLoaded = true;
    modelInfo.lastUsed = Date.now();

    console.log(`Model ${modelInfo.type} loaded successfully`);
  } catch (error) {
    console.error(`Error loading model ${modelInfo.type}:`, error);
    throw error;
  }
}

/**
 * Инициализирует систему управления моделями
 */
export async function initializeModelManager(): Promise<void> {
  // Подготавливаем пути к моделям
  const diseaseModelPath = await copyAssetIfNeeded('disease_detection.onnx', 'models');
  const plantModelPath = await copyAssetIfNeeded('plant_classification.onnx', 'models');
  const ageModelPath = await copyAssetIfNeeded('plant_analysis.onnx', 'models');

  // Регистрируем модели
  modelCache.set('disease', {
    type: 'disease',
    path: diseaseModelPath,
    session: null,
    inputName: '',
    outputName: '',
    isLoaded: false,
    lastUsed: 0,
  });

  modelCache.set('plant', {
    type: 'plant',
    path: plantModelPath,
    session: null,
    inputName: '',
    outputName: '',
    isLoaded: false,
    lastUsed: 0,
  });

  modelCache.set('age', {
    type: 'age',
    path: ageModelPath,
    session: null,
    inputName: '',
    outputName: '',
    isLoaded: false,
    lastUsed: 0,
  });

  console.log('Model manager initialized');
}

/**
 * Получает модель (загружает если нужно)
 */
export async function getModel(modelType: ModelType): Promise<ModelInfo> {
  const modelInfo = modelCache.get(modelType);
  if (!modelInfo) {
    throw new Error(`Model ${modelType} not found`);
  }

  await loadModel(modelInfo);
  return modelInfo;
}

/**
 * Выгружает модель из памяти
 */
export function unloadModel(modelType: ModelType): void {
  const modelInfo = modelCache.get(modelType);
  if (modelInfo && modelInfo.session) {
    console.log(`Unloading model: ${modelType}`);
    modelInfo.session.release();
    modelInfo.session = null;
    modelInfo.isLoaded = false;
  }
}

/**
 * Выгружает все модели
 */
export function unloadAllModels(): void {
  for (const modelType of modelCache.keys()) {
    unloadModel(modelType);
  }
}

/**
 * Проверяет, загружена ли модель
 */
export function isModelLoaded(modelType: ModelType): boolean {
  const modelInfo = modelCache.get(modelType);
  return modelInfo?.isLoaded ?? false;
}

Для загрузки используется тот же подход с onnxruntime-react-native, после получения modelInfo через getModel modelInfo = await getModel('age') создаётся тензор изображения inputTensor = new ort.Tensor('float32', imageTensor, [1, 3, n, n]), подаётся на вход модели, и сразу возвращаются выходные тензоры.

const feeds = { [modelInfo.inputName]: inputTensor };
const results = await modelInfo.session.run(feeds);

В примерах, доступных в plantClassificationService.ts и diseaseSearchService.ts, есть реализация подачи эмбеддингов, нормализации и поиска по векторам с использованием Faiss, и это полезно не только для этого приложения, но и для любых проектов с большими моделями и векторным поиском на мобильных устройствах.

Рад, что вы дошли до конца этой нудятины про растения. Надеюсь, было интересно.
Рад, что вы дошли до конца этой нудятины про растения. Надеюсь, было интересно.

В конце пара слов о самом проекте и его смысле. На первый взгляд, это просто нейросети для растений, запуск на мобилке и немного оптимизаций. Но на деле это эксперимент о том, как переносить сложные модели на ограниченные устройства, как управлять памятью, как строить унифицированный код для разных типов задач, который я написал не только для того, кто прочитает, но и как заметки для себя.

Если вы дошли до конца этой статьи, значит вы любите разбираться в деталях и не боитесь сложностей. Проект открыт под лицензией MIT, весь код на GitHub, так что каждый может попробовать, дообучить модели, оптимизировать или предложить что-то новое. И наконец, приглашаю всех желающих присоединиться к контрибьютерам: каждая идея, каждая оптимизация или небольшая фича делают проект лучше и интереснее. Иногда именно небольшая попытка приводит к неожиданным открытиям, и этот проект именно про эксперименты и поиск таких открытий.

Комментарии (0)