• Что бы вступить в ряды "Принятый кодер" Вам нужно:
    Написать 10 полезных сообщений или тем и Получить 10 симпатий.
    Для того кто не хочет терять время,может пожертвовать средства для поддержки сервеса, и вступить в ряды VIP на месяц, дополнительная информация в лс.

  • Пользаватели которые будут спамить, уходят в бан без предупреждения. Спам сообщения определяется администрацией и модератором.

  • Гость, Что бы Вы хотели увидеть на нашем Форуме? Изложить свои идеи и пожелания по улучшению форума Вы можете поделиться с нами здесь. ----> Перейдите сюда
  • Все пользователи не прошедшие проверку электронной почты будут заблокированы. Все вопросы с разблокировкой обращайтесь по адресу электронной почте : info@guardianelinks.com . Не пришло сообщение о проверке или о сбросе также сообщите нам.

Neural Network in Rust on MNIST dataset from scratch

Lomanu4 Оффлайн

Lomanu4

Команда форума
Администратор
Регистрация
1 Мар 2015
Сообщения
1,481
Баллы
155
Implement and train a neural network from scratch on MNIST dataset in Rust without using high-level libraries like TensorFlow or PyTorch.

You can find the code at:

Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.



It demonstrates:

  • Manual forward and backward propagation
  • Use of ReLU and softmax activation functions
  • One-hot encoding
  • Gradient descent for training
  • Accuracy evaluation
  • Model parameter export to CSV using polars
? Dependencies

  • ndarray (store 2d array of data)
  • ndarray-rand (generate intial random weights(w) and biases(b))
  • polars (to read write data in csv)
? Model Overview

  • input layer, 1 hidden layer, output layer
  • Input: 784-dimensional MNIST images
  • Hidden layer: 10 neurons with ReLU as activation function
  • Output layer: 10 neurons with softmax as activation function for multi-class classification
? Structure

  • main.rs: Training loop and evaluation
  • lib.rs: Core model logic — forward, backward, update, softmax, etc.
  • mnistdata/: Contains input dataset
? Dataset


Make sure the MNIST dataset is placed in mnistdata/.

Prerequisite:



Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.



Intialization of data using polars crate:


pub fn load_training_data() -> Result<(Array2<f32>, Array2<f32>), Box<dyn Error>> {
let q = LazyCsvReader::new("./mnistdata/mnist_train.csv")
.with_has_header(true)
.finish()?;

let training_labels = q
.clone()
.with_streaming(true)
.select([col("label")])
.collect()?;

let training_data = q
.clone()
.with_streaming(true)
.drop([col("label")])
.collect()?;

let mut traning_data_ndarray = training_data
.to_ndarray::<Float32Type>(IndexOrder::Fortran)
.unwrap();
let mut training_labels_ndarray = training_labels
.to_ndarray::<Float32Type>(IndexOrder::Fortran)
.unwrap();

traning_data_ndarray = traning_data_ndarray.reversed_axes()/ 255.0;
training_labels_ndarray = training_labels_ndarray.reversed_axes();

let data_dimensions:&[usize] = traning_data_ndarray.shape();
let labels_dimensions:&[usize] = training_labels_ndarray.shape();

// println!("{}", traning_data_ndarray);
// println!("{}", training_labels_ndarray);
println!("DATA: {}, {}", data_dimensions[0], data_dimensions[1]);
println!("LABELS: {}, {}", labels_dimensions[0], labels_dimensions[1]);
Ok((traning_data_ndarray, training_labels_ndarray))
}
What is Neural Network and how to use it to recognize digits:



Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.




Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.



Our Approach for Neural Network:



Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.



Declaring intial weights and biasis using ndarray_rand crate:


pub fn init_params()->(Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>){

let w1 = Array2::random((10, 784), Uniform::new(-0.5, 0.5));
let b1 = Array2::random((10, 1), Uniform::new(-0.5, 0.5));
let w2 = Array2::random((10, 10), Uniform::new(-0.5, 0.5));
let b2 = Array2::random((10, 1), Uniform::new(-0.5, 0.5));

(w1,b1,w2,b2)
}
Forward Propagation:



Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.





pub fn relu(z:&mut Array2<f32>){
z.mapv_inplace(|x| x.max(0.0))
}

pub fn softmax(z: &mut Array2<f32>) {
for mut col in z.axis_iter_mut(Axis(1)) {
// Subtract max for numerical stability
let max = col.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
col.mapv_inplace(|x| (x - max).exp());

let sum = col.sum();
col.mapv_inplace(|x| x / sum);
}
}

pub fn forward_propagation(
w1:&mut Array2<f32>,
b1:&mut Array2<f32>,
w2:&mut Array2<f32>,
b2:&mut Array2<f32>,
x:&mut Array2<f32>
) -> (Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>)
{
let z1 = w1.dot(x) + &*b1;
let mut a1 = z1.clone();
relu(&mut a1);

let z2 = w2.dot(&a1) + &*b2;
let mut a2 = z2.clone();
softmax(&mut a2);

(z1,a1,z2,a2)
}
Back Propagation:



Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.




pub fn one_hot_encoded(y:&mut Array2<f32>, num_classes:usize) -> Array2<f32> {
let ydash= y.flatten();
let label_dimensions:&[usize] = ydash.shape();
let mut one_hot_y = Array2::<f32>::zeros((label_dimensions[0], num_classes));

for (row, &label) in ydash.iter().enumerate(){
let class_index = label as usize;
one_hot_y[(row, class_index)] = 1.0;
}
one_hot_y.reversed_axes()
}

pub fn deriv_relu(z:&mut Array2<f32>){
z.mapv_inplace(|x| if x > 0.0 { 1.0 } else { 0.0 });
}

pub fn backward_propagation(
z1:&mut Array2<f32>,
a1:&mut Array2<f32>,
a2:&mut Array2<f32>,
w2:&mut Array2<f32>,
x:&mut Array2<f32>,
y:&mut Array2<f32>,
)->(Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>){
let m = y.len() as f32;
let a1t = a1.view().reversed_axes();
let w2t = w2.view().reversed_axes();
let xt = x.view().reversed_axes();
let one_hot_y = one_hot_encoded(y, 10);

let dz2 = &*a2 - &one_hot_y;
let dw2 = (1.0/m)*(dz2.dot(&a1t));
let db2 = dz2.sum_axis(Axis(1)).insert_axis(Axis(1)) * (1.0 / m);

let mut z1_deriv = z1.clone();
deriv_relu(&mut z1_deriv);
let dz1 = w2t.dot(&dz2)*z1_deriv;
let dw1 = (1.0/m)*(dz1.dot(&xt));
let db1 = dz1.sum_axis(Axis(1)).insert_axis(Axis(1)) * (1.0 / m);

(dw1, db1, dw2, db2)
}
Update weights and biasis:



Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.




pub fn update_params(
w1: &mut Array2<f32>,
b1: &mut Array2<f32>,
w2: &mut Array2<f32>,
b2: &mut Array2<f32>,
dw1: &Array2<f32>,
db1: &Array2<f32>,
dw2: &Array2<f32>,
db2: &Array2<f32>,
alpha: f32,
) {
*w1 -= &(alpha * dw1);
*b1 -= &(alpha * db1);
*w2 -= &(alpha * dw2);
*b2 -= &(alpha * db2);
}
Function to find accuracy of our model:


pub fn get_accuracy(predictions: &Array2<f32>, labels: &Array2<f32>) -> f32 {
let pred_classes: Array1<usize> = predictions
.axis_iter(Axis(1))
.map(|col| {
col.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap()
.0
})
.collect();

let true_classes: Array1<usize> = labels.iter().map(|x| *x as usize).collect();

let correct = pred_classes
.iter()
.zip(true_classes.iter())
.filter(|(pred, truth)| pred == truth)
.count();

correct as f32 / labels.len() as f32
}
Finally using all functions in main:


fn main() -> Result<(), Box<dyn Error>> {
let (mut training_data, mut training_label) = load_training_data()?;
let (mut w1,mut b1, mut w2, mut b2) = init_params();

let iterations = 501;
let alpha = 0.1;
println!("{}", training_label);

for i in 0..iterations{
let (mut z1, mut a1, mut z2, mut a2) = forward_propagation(&mut w1, &mut b1, &mut w2, &mut b2, &mut training_data);
let (dw1, db1, dw2, db2) = backward_propagation(&mut z1, &mut a1, &mut a2, &mut w2, &mut training_data, &mut training_label);
update_params(&mut w1, &mut b1, &mut w2, &mut b2, &dw1, &db1, &dw2, &db2, alpha);
if i%50 == 0{
println!("Iteration: {}", i);
let acc = get_accuracy(&a2, &training_label);
println!("Accuracy: {:.2}%", acc * 100.0);
}
}

Ok(())
}
Results for 200 iterations and learning rate = 0.1


DATA: 784, 60000
LABELS: 1, 60000
[[5, 0, 4, 1, 9, ..., 8, 3, 5, 6, 8]]

Iteration: 0
Accuracy: 10.86%

Iteration: 50
Accuracy: 56.97%

Iteration: 100
Accuracy: 69.91%

Iteration: 150
Accuracy: 75.45%

Iteration: 200
Accuracy: 78.56%
Results for 500 iterations and learning rate = 0.1


DATA: 784, 60000
LABELS: 1, 60000
[[5, 0, 4, 1, 9, ..., 8, 3, 5, 6, 8]]

Iteration: 0
Accuracy: 12.46%

Iteration: 50
Accuracy: 47.05%

Iteration: 100
Accuracy: 61.53%

Iteration: 150
Accuracy: 69.01%

Iteration: 200
Accuracy: 73.28%

Iteration: 250
Accuracy: 76.48%

Iteration: 300
Accuracy: 78.93%

Iteration: 350
Accuracy: 80.81%

Iteration: 400
Accuracy: 82.38%

Iteration: 450
Accuracy: 83.53%

Iteration: 500
Accuracy: 84.48%

Medium Blog:

Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.





Пожалуйста Авторизируйтесь или Зарегистрируйтесь для просмотра скрытого текста.

 
Вверх Снизу