NeuroWhAI의 잡블로그
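[TensorFlow] Understanding RNNs Through Code
I took the explanation and code from here and ran it myself.
The example builds an RNN by hand using only TensorFlow's basic building blocks and trains it on MNIST.
When I had only studied the theory, I couldn't really picture how an RNN operates,
but implementing it directly like this made it click right away.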
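To make the recurrence concrete outside of TensorFlow, here is a minimal NumPy sketch of the forward pass: each of the 28 image rows is fed in as one time step, and the hidden state is updated as s_t = tanh(x_t U + s_{t-1} W). The function name `rnn_forward`, the zero initial state, and the random stand-in images are assumptions made for illustration; the TensorFlow code in this post uses a small random initial state and real MNIST data.

```python
import numpy as np

def rnn_forward(images, U, W, V, s_init):
    """Run a vanilla RNN over image rows and return class probabilities.

    images: [batch, height, width]; each of the `height` rows is one time step.
    """
    s = s_init
    for t in range(images.shape[1]):
        x_t = images[:, t, :]            # row t of every image: [batch, width]
        s = np.tanh(x_t @ U + s @ W)     # new hidden state: [batch, hidden]
    logits = s @ V                       # only the last state is used: [batch, classes]
    logits = logits - logits.max(axis=1, keepdims=True)  # shift for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
batch, height, width, hidden, classes = 4, 28, 28, 100, 10
images = rng.random((batch, height, width))      # stand-in for MNIST images
U = rng.normal(0.0, 0.01, (width, hidden))       # input -> hidden
W = rng.normal(0.0, 0.01, (hidden, hidden))      # hidden -> hidden (always square)
V = rng.normal(0.0, 0.01, (hidden, classes))     # hidden -> output
s0 = np.zeros((batch, hidden))                   # assumption: zero initial state

probs = rnn_forward(images, U, W, V, s0)
print(probs.shape)
```

The same three matrices are reused at every time step, which is the part that was hard to see from the theory alone.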
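One curious thing I noticed while testing: when I switched the optimizer to Adam, training didn't go well?
Another odd thing, as you can see in the results: why does the accuracy suddenly plunge partway through...
Is there a bug in the code?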
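On the Adam observation, one possible explanation (not verified against the original run) is the learning rate. The post does not say which value was used with Adam; the sketch below assumes the SGD value of 0.1 was kept. With the textbook Adam update, the very first step has a magnitude of roughly the learning rate regardless of how small the gradient is, whereas an SGD step scales with the gradient. The helper names `sgd_step` and `adam_first_step` are made up for this illustration.

```python
import math

def sgd_step(grad, lr):
    """Plain gradient descent update for a single scalar weight."""
    return -lr * grad

def adam_first_step(grad, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """Textbook Adam update at step t = 1 for a single scalar weight."""
    m = (1 - beta1) * grad               # first-moment estimate
    v = (1 - beta2) * grad * grad        # second-moment estimate
    m_hat = m / (1 - beta1)              # bias correction at t = 1
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (math.sqrt(v_hat) + eps)

for grad in (1e-4, 1e-2, 1.0):
    print("grad={:g}  sgd(0.1)={:g}  adam(0.1)={:g}  adam(0.001)={:g}".format(
        grad,
        sgd_step(grad, 0.1),
        adam_first_step(grad, 0.1),
        adam_first_step(grad, 0.001),
    ))
```

Since the weights here are initialized with stddev=0.01, a step of about 0.1 per weight would be large compared to the weights themselves. The default learning rate of tf.train.AdamOptimizer is 0.001, so that value may be worth trying.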
Code:

    # Requires TensorFlow 1.x (tf.placeholder, tf.train and the bundled MNIST loader).
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets("./mnist/data/", one_hot=True, reshape=False)

    total_epoch = 20
    batch_size = 100
    total_batch = mnist.train.num_examples // batch_size

    h_size = 28        # image height = number of time steps
    w_size = 28        # image width = input size per time step
    c_size = 1         # channels
    hidden_size = 100

    x_raw = tf.placeholder(tf.float32, shape=[batch_size, h_size, w_size, c_size])  # [100, 28, 28, 1]
    x_split = tf.split(x_raw, h_size, axis=1)  # [100, 28, 28, 1] -> list of [100, 1, 28, 1]
    y = tf.placeholder(tf.float32, shape=[batch_size, 10])

    U = tf.Variable(tf.random_normal([w_size, hidden_size], stddev=0.01))       # input -> hidden
    W = tf.Variable(tf.random_normal([hidden_size, hidden_size], stddev=0.01))  # hidden -> hidden, always square
    V = tf.Variable(tf.random_normal([hidden_size, 10], stddev=0.01))           # hidden -> output

    # Hidden states, keyed by time step. s[-1] is the initial state;
    # it is a random op, so it is re-sampled on every run.
    s = {}
    s_init = tf.random_normal(shape=[batch_size, hidden_size], stddev=0.01)
    s[-1] = s_init

    # Unroll the RNN: one image row per time step.
    # (The loop variable is named x_t so it no longer shadows the x_split list.)
    for t, x_t in enumerate(x_split):
        x = tf.reshape(x_t, [batch_size, w_size])  # [100, 1, 28, 1] -> [100, 28]
        s[t] = tf.nn.tanh(tf.matmul(x, U) + tf.matmul(s[t-1], W))

    # Classify from the last hidden state only.
    o = tf.nn.softmax(tf.matmul(s[h_size-1], V))
    cost = -tf.reduce_mean(tf.log(tf.reduce_sum(o*y, axis=1)))

    learning_rate = 0.1
    trainer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

    sess = tf.InteractiveSession()
    init = tf.global_variables_initializer()
    init.run()

    def accuracy(network, t):
        t_predict = tf.argmax(network, axis=1)
        t_actual = tf.argmax(t, axis=1)
        return tf.reduce_mean(tf.cast(tf.equal(t_predict, t_actual), tf.float32))

    acc = accuracy(o, y)

    for epoch in range(total_epoch):
        print("Epoch: {}".format(epoch + 1))

        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            trainer.run({x_raw: batch_xs, y: batch_ys})

        # Evaluated on the first batch_size (100) test images only.
        # (Result stored in test_acc so the accuracy() function is not overwritten.)
        test_acc = acc.eval({
            x_raw: mnist.test.images[:batch_size],
            y: mnist.test.labels[:batch_size]
        })
        print("Test accuracy: {}%".format(test_acc * 100.0))

    val_acc = acc.eval({
        x_raw: mnist.validation.images[:batch_size],
        y: mnist.validation.labels[:batch_size]
    })
    print("Validation accuracy: {}%".format(val_acc * 100.0))

    sess.close()

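The cost line in the code above, -mean(log(sum(o*y))), may look different from the usual cross-entropy formula, but for one-hot labels the two are the same thing: multiplying by y just picks out the probability of the correct class. Here is a small NumPy check of that claim; the function names and the random logits are made up for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cost_as_in_post(o, y):
    # Pick the predicted probability of the true class, then take -log and average.
    return -np.mean(np.log(np.sum(o * y, axis=1)))

def cross_entropy(o, y):
    # Standard cross-entropy: -sum_k y_k * log(o_k), averaged over the batch.
    return -np.mean(np.sum(y * np.log(o), axis=1))

rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 10))          # stand-in for s[h_size-1] @ V
labels = rng.integers(0, 10, size=5)
y = np.eye(10)[labels]                     # one-hot labels
o = softmax(logits)

print(cost_as_in_post(o, y), cross_entropy(o, y))
```

One caveat: if the predicted probability of the true class ever underflows to 0, the log becomes -inf. In TensorFlow 1.x, tf.nn.softmax_cross_entropy_with_logits_v2 computes the loss from the logits directly and avoids that; whether this matters for the run in this post is not something I have checked.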
Results:
Epoch: 1
Test accuracy: 15.999999642372131%
Epoch: 2
Test accuracy: 15.000000596046448%
Epoch: 3
Test accuracy: 15.000000596046448%
Epoch: 4
Test accuracy: 34.00000035762787%
Epoch: 5
Test accuracy: 49.000000953674316%
Epoch: 6
Test accuracy: 63.999998569488525%
Epoch: 7
Test accuracy: 73.00000190734863%
Epoch: 8
Test accuracy: 77.99999713897705%
Epoch: 9
Test accuracy: 87.00000047683716%
Epoch: 10
Test accuracy: 40.99999964237213%
Epoch: 11
Test accuracy: 75.99999904632568%
Epoch: 12
Test accuracy: 93.00000071525574%
Epoch: 13
Test accuracy: 93.00000071525574%
Epoch: 14
Test accuracy: 89.99999761581421%
Epoch: 15
Test accuracy: 95.99999785423279%
Epoch: 16
Test accuracy: 92.00000166893005%
Epoch: 17
Test accuracy: 93.00000071525574%
Epoch: 18
Test accuracy: 92.00000166893005%
Epoch: 19
Test accuracy: 95.99999785423279%
Epoch: 20
Test accuracy: 93.00000071525574%
Validation accuracy: 92.00000166893005%
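One thing worth keeping in mind when reading these numbers: the code evaluates on only the first 100 test images, so every image is worth one percentage point and the reported accuracy is fairly noisy. The rough sketch below estimates that noise with a binomial standard error. Treating the 100 images as independent draws is an assumption, and `accuracy_std_error` is a helper name made up for this illustration.

```python
import math

def accuracy_std_error(p, n):
    """Binomial standard error of an accuracy p measured on n samples."""
    return math.sqrt(p * (1.0 - p) / n)

n_eval = 100  # the code evaluates on mnist.test.images[:batch_size]

for p in (0.87, 0.93):
    se = accuracy_std_error(p, n_eval)
    print("accuracy {:.0f}%: standard error about {:.1f} points".format(p * 100, se * 100))

# Size of the epoch 9 -> epoch 10 drop (87% -> 41%) in units of that standard error.
drop = 0.87 - 0.41
drop_in_se = drop / accuracy_std_error(0.87, n_eval)
print("epoch-10 drop is about {:.0f} standard errors".format(drop_in_se))
```

Under that assumption, swings of a few points in the later epochs are about what a 100-image evaluation would produce on its own, while the drop at epoch 10 is far too large to be evaluation noise, so it most likely reflects something that actually happened during training. This does not tell us what the cause was.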