[Rust] Gradient Descent
```rust
use std::fmt;

struct LinearFunction {
    a: f64,
    b: f64,
}

impl LinearFunction {
    fn new(a: f64, b: f64) -> LinearFunction {
        LinearFunction { a, b }
    }

    fn feed(&self, x: f64) -> f64 {
        self.a * x + self.b
    }

    /// Performs one gradient descent step on a single sample and
    /// returns the error before the update.
    fn learn(&mut self, input: f64, target_output: f64, learning_rate: f64) -> f64 {
        let error = target_output - self.feed(input);

        // Squared-error gradients (the constant factor 2 is absorbed
        // into the learning rate):
        //   dE/da = -input * error,  dE/db = -error
        self.a -= error * -input * learning_rate;
        self.b -= error * -1.0 * learning_rate;

        error
    }
}

impl fmt::Display for LinearFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "f(x) = {} * x + {}", self.a, self.b)
    }
}

fn main() {
    let mut fun = LinearFunction::new(0.0, 0.0);

    println!("Before");

    // (x, f(x)) pairs sampled from f(x) = 2x + 1.
    let data_list = [
        (0.0, 1.0),
        (1.0, 3.0),
        (2.0, 5.0),
        (3.0, 7.0),
    ];

    let learning_rate = 0.01 / data_list.len() as f64;

    for epoch in 0..10000 {
        let mut mse = 0_f64;

        for data in data_list.iter() {
            let error = fun.learn(data.0, data.1, learning_rate);
            mse += error * error;
        }

        // Report progress every 1000 epochs.
        if epoch % 1000 == 0 {
            println!("{}", fun);
            println!("Error : {}", mse / data_list.len() as f64);
        }
    }

    println!("After");
    println!("{}", fun);
}
```
Run result: https://ideone.com/qzzXqL
```
Before
f(x) = 0.0839723100390625 * x + 0.039632452304687496
Error : 20.519144377863604
f(x) = 2.0011632736774065 * x + 0.9974990391031798
Error : 0.0000022823391159126393
f(x) = 2.000058928111989 * x + 0.9998733084856376
Error : 0.000000005856813909427799
f(x) = 2.0000029851293366 * x + 0.9999935821708257
Error : 0.000000000015029435769872653
f(x) = 2.0000001512181007 * x + 0.9999996748911606
Error : 0.00000000000003856771653562187
f(x) = 2.000000007660276 * x + 0.999999983530919
Error : 0.0000000000000000989703538043435
f(x) = 2.0000000003880474 * x + 0.9999999991657231
Error : 0.00000000000000000025397248558118256
f(x) = 2.0000000000196576 * x + 0.9999999999577371
Error : 0.0000000000000000000006517520649855388
f(x) = 2.000000000000997 * x + 0.9999999999978586
Error : 0.0000000000000000000000016736991178990983
f(x) = 2.0000000000000444 * x + 0.9999999999998948
Error : 0.000000000000000000000000003977942048141749
After
f(x) = 2.000000000000026 * x + 0.9999999999999513
```
I tried the most basic case: given an equation of the form f(x) = a*x + b and known pairs {f(x), x}, the goal is to find a and b.
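The post never writes the loss down explicitly, but since `main()` accumulates squared errors into an MSE, the quantity being minimized for each sample is the squared error:

$$E(a, b) = \bigl(t - f(x)\bigr)^2 = \bigl(t - (ax + b)\bigr)^2,$$

where $t$ is the target output for input $x$.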
Partially differentiating this error with respect to a and with respect to b gives gradient values, and with those the computation becomes straightforward.
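Working those derivatives out:

$$\frac{\partial E}{\partial a} = -2x\,\bigl(t - (ax + b)\bigr), \qquad
\frac{\partial E}{\partial b} = -2\,\bigl(t - (ax + b)\bigr),$$

so with error $e = t - f(x)$ and learning rate $\eta$ (the factor of 2 folded into $\eta$), the per-sample update is

$$a \leftarrow a - \eta \cdot (-x \cdot e), \qquad b \leftarrow b - \eta \cdot (-e),$$

which is exactly what `self.a -= error * -input * learning_rate;` and the line below it compute in `learn()`.

As a quick sanity check, here is a minimal sketch (my addition, not part of the post's code) that compares those analytic gradients against a central finite difference on one training pair at the initial parameters:

```rust
fn main() {
    // One training pair from the post's data: f(1) = 3.
    let (x, t) = (1.0_f64, 3.0_f64);
    // Evaluate the gradients at the initial parameters a = b = 0.
    let (a, b) = (0.0_f64, 0.0_f64);

    // Per-sample squared error E(a, b) = (t - (a*x + b))^2.
    let loss = |a: f64, b: f64| (t - (a * x + b)).powi(2);

    // Central finite differences: (E(p + h) - E(p - h)) / (2h).
    let h = 1e-6;
    let grad_a_num = (loss(a + h, b) - loss(a - h, b)) / (2.0 * h);
    let grad_b_num = (loss(a, b + h) - loss(a, b - h)) / (2.0 * h);

    // Analytic gradients derived above.
    let error = t - (a * x + b);
    println!("dE/da: numeric {} vs analytic {}", grad_a_num, -2.0 * x * error);
    println!("dE/db: numeric {} vs analytic {}", grad_b_num, -2.0 * error);
}
```

Both comparisons should print roughly -6, matching $-2x\,(t - f(x)) = -2 \cdot 1 \cdot 3$.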