NeuroWhAI의 잡블로그

[Rust] 계산 그래프 - '밑바닥부터 시작하는 딥러닝' 5장 본문

개발 및 공부/알고리즘

[Rust] 계산 그래프 - '밑바닥부터 시작하는 딥러닝' 5장

NeuroWhAI 2018. 7. 16. 19:46


※ 실제로 동작하는 전체 소스코드는 GitHub에서 보실 수 있습니다.


5장에서는 계산 그래프라는 방법을 통해 순전파, 역전파를 가르쳐줍니다.
연쇄 법칙(Chain rule) 덕분에 각 층의 자체(?) 기울기를 곱하면서 역전파하면 각 층의 최종 기울기를 얻을 수 있습니다.
아래 코드는 간단하게 덧셈, 곱셈 레이어를 만들고 이를 사용하여 사과, 오렌지의 가격을 순전파로 계산한 다음
역전파로 각 요소가 최종 가격에 끼치는 기울기를 계산하는 예제입니다.

코드:
pub struct MulLayer {
    x: f32,
    y: f32,
}

impl MulLayer {
    pub fn new() -> Self {
        MulLayer {
            x: 0.0,
            y: 0.0
        }
    }
    
    pub fn forward(&mut self, x: f32, y: f32) -> f32 {
        self.x = x;
        self.y = y;
        
        x * y
    }
    
    pub fn backward(&self, dout: f32) -> (f32, f32) {
        (dout * self.y, dout * self.x)
    }
}


pub struct AddLayer {}

impl AddLayer {
    pub fn new() -> Self {
        AddLayer {}
    }
    
    pub fn forward(&mut self, x: f32, y: f32) -> f32 {
        x + y
    }
    
    pub fn backward(&self, dout: f32) -> (f32, f32) {
        (dout, dout)
    }
}
fn test_layer_naive() {
    use self::layer_naive::{MulLayer, AddLayer};

    let apple = 100.0;
    let apple_num = 2.0;
    let orange = 150.0;
    let orange_num = 3.0;
    let tax = 1.1;
    
    // Layer
    let mut mul_apple_layer = MulLayer::new();
    let mut mul_orange_layer = MulLayer::new();
    let mut add_all_layer = AddLayer::new();
    let mut mul_tax_layer = MulLayer::new();
    
    // Forward
    let apple_price = mul_apple_layer.forward(apple, apple_num);
    let orange_price = mul_orange_layer.forward(orange, orange_num);
    let all_price = add_all_layer.forward(apple_price, orange_price);
    let price = mul_tax_layer.forward(all_price, tax);
    
    println!("Total apple price: {}", apple_price);
    println!("Total orange price: {}", orange_price);
    println!("Total price without tax: {}", all_price);
    println!("Total price with tax: {}", price);
    
    // Backward
    let d_price = 1.0;
    let (d_all_price, d_tax) = mul_tax_layer.backward(d_price);
    let (d_apple_price, d_orange_price) = add_all_layer.backward(d_all_price);
    let (d_orange, d_orange_num) = mul_orange_layer.backward(d_orange_price);
    let (d_apple, d_apple_num) = mul_apple_layer.backward(d_apple_price);
    
    println!("dApple: {}", d_apple);
    println!("dApple_num: {}", d_apple_num);
    println!("dOrange: {}", d_orange);
    println!("dOrange_num: {}", d_orange_num);
    println!("dTax: {}", d_tax);
}

결과:
Total apple price: 200
Total orange price: 450
Total price without tax: 650
Total price with tax: 715
dApple: 2.2
dApple_num: 110
dOrange: 3.3000002
dOrange_num: 165
dTax: 650

결과에서 dTax는 부과세가 1원 증가하면 최종 가격은 650원 증가한다는 의미입니다.
같은 이야기로 dApple_num은 사과 개수가 1개 증가하면 최종 가격은 110원 증가한다는 의미입니다.
이렇게 각 요소별로 최종 가격에 대한 기울기를 계산할 수 있음을 확인했습니다!




Comments