Я недавно изучал примеры использования нейронных сетей из библиотеки Candle от Hugging Face и обратил внимание, что они довольно сложны для понимания людей, которые только начинают знакомство с нейросетями. Поэтому я решил написать максимально упрощенный пример кода на Rust, который демонстрирует обучение и использование простейшей нейросети.

В моем примере нейросеть пытается предсказать, выиграет ли первый кандидат на втором этапе голосования, исходя из результатов первого тура.

Конечно, предсказательная сила такой простой модели невелика, но я считаю, что этот пример хорошо демонстрирует основные принципы обучения и использования нейронных сетей в самом упрощенном виде. Надеюсь, эта статья поможет новичкам в области машинного обучения лучше разобраться в работе нейросетей!

  1. Мной используется многослойный перцептрон с двумя скрытыми слоями. Первый скрытый слой имеет 4 нейрона, второй - 2 нейрона.

  2. На вход подается вектор из 2 чисел - проценты голосов за первого и второго кандидатов на первом этапе.

  3. На выходе - число 0 или 1, где 1 означает, что первый кандидат выиграет на втором этапе, 0 - что проиграет.

  4. Для обучения используются выборки с "реальными" данными о результатах первого и второго этапов разных выборов.

  5. Модель обучается методом обратного распространения ошибки с использованием градиентного спуска и функции потерь cross-entropy.

  6. Параметры модели (веса нейронов) инициализируются случайно, затем оптимизируются в процессе обучения.

  7. После обучения модель тестируется на отложенной выборке для оценки точности.

  8. Если точность на тестовой выборке ниже 100%, модель считается недообученной и процесс обучения повторяется.

Таким образом, эта нейросеть учится находить скрытые зависимости между результатами первого и второго туров голосования, чтобы делать прогнозы для новых данных.

Как это выглядит в коде:

const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 20;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;

#[derive(Clone)]
pub struct Dataset {
    pub train_votes: Tensor,
    pub train_results: Tensor,
    pub test_votes: Tensor,
    pub test_results: Tensor,
}

struct MultiLevelPerceptron {
    ln1: Linear,
    ln2: Linear,
    ln3: Linear,
}

impl MultiLevelPerceptron {
    fn new(vs: VarBuilder) -> Result<Self> {
        let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
        let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
        let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
        Ok(Self { ln1, ln2, ln3 })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.ln1.forward(xs)?;
        let xs = xs.relu()?;
        let xs = self.ln2.forward(&xs)?;
        let xs = xs.relu()?;
        self.ln3.forward(&xs)
    }
}


pub fn main() -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    let train_votes_vec: Vec<u32> = vec![
        15, 10,
        10, 15,
        5, 12,
        30, 20,
        16, 12,
        13, 25,
        6, 14,
        31, 21,
    ];
    let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let train_results_vec: Vec<u32> = vec![
        1,
        0,
        0,
        1,
        1,
        0,
        0,
        1,
    ];
    let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;

    let test_votes_vec: Vec<u32> = vec![
        13, 9,
        8, 14,
        3, 10,
    ];
    let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let test_results_vec: Vec<u32> = vec![
        1,
        0,
        0,
    ];
    let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;

    let m = Dataset {
        train_votes: train_votes_tensor,
        train_results: train_results_tensor,
        test_votes: test_votes_tensor,
        test_results: test_results_tensor,
    };

    let trained_model: MultiLevelPerceptron;
    loop {
        println!("Trying to train neural network.");
        match train(m.clone(), &dev) {
            Ok(model) => {
                trained_model = model;
                break;
            },
            Err(e) => {
                println!("Error: {:?}", e);
                continue;
            }
        }

    }

    let real_world_votes: Vec<u32> = vec![
        13, 22,
    ];

    let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;

    let final_result = trained_model.forward(&tensor_test_votes)?;

    let result = final_result
        .argmax(D::Minus1)?
        .to_dtype(DType::F32)?
        .get(0).map(|x| x.to_scalar::<f32>())??;
    println!("real_life_votes: {:?}", real_world_votes);
    println!("neural_network_prediction_result: {:?}", result);

    Ok(())
}

Результат работы программы:

Trying to train neural network.
Epoch:   1 Train loss:  4.42555 Test accuracy:  0.00%
Epoch:   2 Train loss:  0.84677 Test accuracy: 33.33%
Epoch:   3 Train loss:  2.54335 Test accuracy: 33.33%
Epoch:   4 Train loss:  0.37806 Test accuracy: 33.33%
Epoch:   5 Train loss:  0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0

Ну и конечно код функции где происходит обучение:

fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
    let train_results = m.train_results.to_device(dev)?;
    let train_votes = m.train_votes.to_device(dev)?;
    let varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
    let model = MultiLevelPerceptron::new(vs.clone())?;
    let sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE);
    let test_votes = m.test_votes.to_device(dev)?;
    let test_results = m.test_results.to_device(dev)?;
    let mut final_accuracy: f32 = 0.0;
    for epoch in 1..EPOCHS+1 {
        let logits = model.forward(&train_votes)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        let loss = loss::nll(&log_sm, &train_results)?;
        sgd.backward_step(&loss)?;

        let test_logits = model.forward(&test_votes)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_results)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_results.dims1()? as f32;
        final_accuracy = 100. * test_accuracy;
        println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
                 loss.to_scalar::<f32>()?,
                 final_accuracy
        );
        if final_accuracy == 100.0 {
            break;
        }
    }
    if final_accuracy < 100.0 {
        Err(anyhow::Error::msg("The model is not trained well enough."))
    } else {
        Ok(model)
    }
}

В заключение хотелось бы отметить, что данный пример действительно наглядно демонстрирует основные принципы работы нейронных сетей. Несмотря на простоту модели, код показывает пошагово процесс инициализации, обучения, тестирования и использования нейросети для решения практической задачи предсказания результатов голосования.

Конечно, для реальных прикладных задач потребуются более сложные архитектуры нейронных сетей. Однако этот пример отлично подходит для изучения основ, после чего можно будет переходить к более продвинутым моделям и решениям.

В целом данная статья достигает поставленной цели - дать понимание базовых принципов нейросетей разработчикам, которые только начинают знакомство с машинным обучением. Код написан максимально просто и наглядно. Этот пример может послужить хорошей отправной точкой для дальнейшего изучения возможностей нейронных сетей.

Исходный код, который компилируется и запускается доступен по ссылке.

Комментарии (4)


  1. Gorthauer87
    03.09.2023 20:18

    А в чем сакральный смысл писать модель именно на Расте? Ведь большинство mlщиков его даже не знают, тем более тут в коде сильные стороны раста вообще будто бы не используются.


    1. programmerjava
      03.09.2023 20:18
      +1

      Как не используются ? Память за него сама освобождается : )


    1. Draugdor
      03.09.2023 20:18
      +4

      Понятно что целевая аудитория не большая будет у такой статьи, но в меня она попала, пишу на rust системщину, но и тема ml интересна, а на python не люблю писать, фреймворком заинтересовался буду изучать.


    1. Shado_vi
      03.09.2023 20:18
      +1

      насчет конкретно писать не знаю, но при использовании моделей может быть профит.


      есть статья https://guillaume-be.github.io/2020-11-21/generation_benchmarks


      Rust может значительно ускорить вспомогательные операции

      это относится в общем к более производительным языкам а не только к rust.