NeuroWhAI의 잡블로그

[Rust] 경사 하강법 - '밑바닥부터 시작하는 딥러닝' 4장 본문

개발 및 공부/알고리즘

[Rust] 경사 하강법 - '밑바닥부터 시작하는 딥러닝' 4장

NeuroWhAI 2018. 7. 15. 12:35


※ 실제로 동작하는 전체 소스코드는 GitHub에서 보실 수 있습니다.


이전에 신경망의 학습은 손실 함수의 값을 최소화하는 방향으로 파라미터를 조정한다고 했습니다.
그럼 파라미터를 어떻게 조절해야 손실이 줄어드는지 알 수 있을까요?
저는 미적분을 제대로 배우지 않았지만 미분은 값의 변화와 관련이 있다는건 압니다.
그러니까 L = f(x)이고 x가 d만큼 변했을때 L은 d의 몇 배만큼 변하는지가 미분이라고 대충 알고 있습니다.
여기서 L을 손실, f를 손실 함수 + 신경망, x를 파라미터라고 하면 대충 윤곽이 보이죠.
우리는 f 함수를 미분하여 기울기를 구해 x를 올바른 방향으로 조절할 수 있습니다.
그런데 저 같은 수포자에게 함수를 미분하라는건 힘든 일입니다.
그러나 수치 미분이란 방법을 쓰면 해석적 미분보다 쉽게 기울기를 구할 수 있습니다.

아래 코드는 수치 미분을 사용해서 2개의 파라미터를 가진 함수(function_2)의 값을 최소화하는 예제입니다.

코드:
use rulinalg::matrix::{Matrix, BaseMatrix, BaseMatrixMut};

pub fn numerical_gradient<F>(f: F, x: &mut Matrix<f32>) -> Matrix<f32>
    where F: Fn(&Matrix<f32>) -> f32 {
    
    let h = 1e-4;
    let mut grad = Matrix::<f32>::zeros(x.rows(), x.cols());
    
    for it in x.iter_mut().zip(grad.iter_mut()) {
        let (v, g) = it;
    
        let bak = *v;
        
        *v = bak + h;
        let fxh1 = f(x);
        
        *v = bak - h;
        let fxh2 = f(x);
        
        *g = (fxh1 - fxh2) / (2.0 * h);
        *v = bak;
    }
    
    grad
}

pub fn gradient_descent(f: fn(&Matrix<f32>) -> f32, init_x: &Matrix<f32>,
    lr: f32, step_num: usize) -> (Matrix<f32>, Vec<f32>) {
    
    let mut x = Matrix::new(init_x.rows(), init_x.cols(),
        init_x.iter().map(|v| *v).collect::<Vec<_>>());
    let mut history = Vec::new();
        
    for _ in 0..step_num {
        let grad = numerical_gradient(f, &mut x);
        x -= grad * lr;
        
        history.push(f(&x));
    }
    
    (x, history)
}
fn function_2(x: &Matrix<f32>) -> f32 {
    let x1 = x[[0, 0]];
    let x2 = x[[0, 1]];
    return x1 * x1 + x2 * x2;
}

fn test_gradient() {
    let init_x = matrix![-3.0, 4.0f32];
    
    println!("x = {}", init_x);
    println!("f(x) = {}", function_2(&init_x));
    
    let lr = 0.1;
    let step_num = 20;
    
    let (x, history) = gradient::gradient_descent(function_2, &init_x, lr, step_num);
    
    plot::print_graph(&history[..], 50, 20);
    
    println!("x = {}", x);
    println!("f(x) = {}", function_2(&x));
}

결과:
x = [-3  4]
f(x) = 25
-





 -




  -


   -
    -
     -
      -
       ------------
                   -
x = [-0.034610715   0.04615057]
f(x) = 0.0033277767

보시다시피 기울기로 x를 조절하여 f(x1, x2)의 값이 25에서 0에 가까운 수로 줄어들었습니다.




Comments