NeuroWhAI의 잡블로그

[Rust] 경사 하강법 본문

개발 및 공부/알고리즘

[Rust] 경사 하강법

NeuroWhAI 2018. 1. 3. 18:41



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
use std::fmt;
 
struct LinearFunction {
    a: f64,
    b: f64,
}
 
impl LinearFunction {
    fn new(a: f64, b: f64-> LinearFunction {
        LinearFunction {
            a: a,
            b: b,
        }
    }
 
    fn feed(&self, x: f64-> f64 {
        self.a * x + self.b
    }
 
    fn learn(&mut self, input: f64, target_output: f64, learning_rate: f64-> f64 {
        let error = target_output - self.feed(input);
        
        self.a -= error * -input * learning_rate;
        self.b -= error * -1.0 * learning_rate;
        
        return error;
    }
}
 
impl fmt::Display for LinearFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "f(x) = {} * x + {}"self.a, self.b)
    }
}
 
fn main() {
    let mut fun = LinearFunction::new(0.00.0);
    
    println!("Before");
    
    let data_list = [
        (0.01.0),
        (1.03.0),
        (2.05.0),
        (3.07.0),
    ];
    
    let learning_rate = 0.01 / data_list.len() as f64;
    
    for epoch in 0..10000 {
        let mut mse = 0_f64;
    
        for data in data_list.iter() {
            let error = fun.learn(data.0, data.1, learning_rate);
            mse += error * error;
        }
        
        if epoch % 1000 == 0 {
            println!("{}", fun);
            println!("Error : {}", mse / data_list.len() as f64);
        }
    }
    
    println!("After");
    println!("{}", fun);
}
cs

실행 결과 : https://ideone.com/qzzXqL

Before
f(x) = 0.0839723100390625 * x + 0.039632452304687496
Error : 20.519144377863604
f(x) = 2.0011632736774065 * x + 0.9974990391031798
Error : 0.0000022823391159126393
f(x) = 2.000058928111989 * x + 0.9998733084856376
Error : 0.000000005856813909427799
f(x) = 2.0000029851293366 * x + 0.9999935821708257
Error : 0.000000000015029435769872653
f(x) = 2.0000001512181007 * x + 0.9999996748911606
Error : 0.00000000000003856771653562187
f(x) = 2.000000007660276 * x + 0.999999983530919
Error : 0.0000000000000000989703538043435
f(x) = 2.0000000003880474 * x + 0.9999999991657231
Error : 0.00000000000000000025397248558118256
f(x) = 2.0000000000196576 * x + 0.9999999999577371
Error : 0.0000000000000000000006517520649855388
f(x) = 2.000000000000997 * x + 0.9999999999978586
Error : 0.0000000000000000000000016736991178990983
f(x) = 2.0000000000000444 * x + 0.9999999999998948
Error : 0.000000000000000000000000003977942048141749
After

f(x) = 2.000000000000026 * x + 0.9999999999999513


가장 기초적인걸로 해봤습니다.

f(x) = a*x + b 형태의 방정식에서 {f(x), x}를 알때 a, b를 구하는 겁니다.

위 식을 a에 대해 편미분하고 b에 대해 편미분해서 얻은 기울기 값을 이용하면 쉽게 계산할 수 있습니다.




Comments