[Keras] A Simple Attention Mechanism Example
While searching for the attention mechanism I came across some nice code, and I tidied it up a bit while studying it.
The training data consists of inputs filled with random garbage values, where only the 8th value (index 7) is related to the output (target) data.
This example checks whether the model manages to figure that out.
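For example (a minimal sketch, separate from the full code below), a single training sample could be built like this:

import numpy as np

input_dims = 32       # length of one input vector
attention_index = 7   # the only position tied to the target

y = np.random.randint(0, 2)                 # target: 0 or 1
x = np.random.standard_normal(input_dims)   # noise everywhere...
x[attention_index] = y                      # ...except index 7, which equals the target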
Code:

import numpy as np
from keras import backend as K
from keras import models
from keras import layers
import matplotlib.pyplot as plt
import pandas as pd

data_count = 10000
input_dims = 32
attention_column = 7

def make_data(batch_size, input_size, attention_index):
    """
    Generates the training data.
    Looking at a single sample: the input has length input_size and every
    position except attention_index is set to a random number.
    The target is 0 or 1, and the input value at attention_index equals the target.
    """
    train_x = np.random.standard_normal(size=(batch_size, input_size))
    train_y = np.random.randint(low=0, high=2, size=(batch_size, 1))
    train_x[:, attention_index] = train_y[:, 0]
    return (train_x, train_y)

# Input Layer
input_layer = layers.Input(shape=(input_dims,))

# Attention Layer: softmax weights over the 32 input positions,
# multiplied element-wise into the input.
attention_probs = layers.Dense(input_dims, activation='softmax')(input_layer)
attention_mul = layers.multiply([input_layer, attention_probs])

# FC Layer
y = layers.Dense(64)(attention_mul)
y = layers.Dense(1, activation='sigmoid')(y)

model = models.Model(input_layer, y)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

# Train
train_x, train_y = make_data(data_count, input_dims, attention_column)
model.fit(train_x, train_y, epochs=20, batch_size=64, validation_split=0.2, verbose=2)

# Test
test_x, test_y = make_data(data_count, input_dims, attention_column)
result = model.evaluate(test_x, test_y, batch_size=64, verbose=0)
print("Loss:", result[0])
print("Accuracy:", result[1])

# Get attention vector (output of the softmax Dense layer, averaged over the test set)
attention_layer = model.layers[1]
func = K.function([model.input] + [K.learning_phase()], [attention_layer.output])
output = func([test_x, 1.0])[0]
attention_vector = np.mean(output, axis=0)

# Show attention vector
pd.DataFrame(attention_vector, columns=['attention (%)']).plot(kind='bar', title='Attention Vector')
plt.show()
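To spell out what the "attention" part does: the softmax Dense layer assigns one weight to each of the 32 input positions, and the multiply layer scales the input element-wise by those weights, so only the positions the model considers important survive into the FC layers. As a side note (a minimal sketch reusing the layer objects above, not part of the original code), the attention weights can also be read out without K.function by wrapping the softmax layer in its own sub-model:

# Alternative way to get the attention weights: a sub-model whose output
# is the softmax layer itself, averaged over the test inputs.
attention_model = models.Model(input_layer, attention_probs)
attention_vector = attention_model.predict(test_x).mean(axis=0)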
Result:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_19 (InputLayer) (None, 32) 0
__________________________________________________________________________________________________
dense_53 (Dense) (None, 32) 1056 input_19[0][0]
__________________________________________________________________________________________________
multiply_3 (Multiply) (None, 32) 0 input_19[0][0]
dense_53[0][0]
__________________________________________________________________________________________________
dense_54 (Dense) (None, 64) 2112 multiply_3[0][0]
__________________________________________________________________________________________________
dense_55 (Dense) (None, 1) 65 dense_54[0][0]
==================================================================================================
Total params: 3,233
Trainable params: 3,233
Non-trainable params: 0
__________________________________________________________________________________________________
Train on 8000 samples, validate on 2000 samples
Epoch 1/20
- 2s - loss: 0.6825 - acc: 0.5921 - val_loss: 0.6559 - val_acc: 0.7635
Epoch 2/20
- 1s - loss: 0.5833 - acc: 0.7875 - val_loss: 0.4990 - val_acc: 0.7940
Epoch 3/20
- 1s - loss: 0.4095 - acc: 0.8460 - val_loss: 0.3260 - val_acc: 0.8840
Epoch 4/20
- 1s - loss: 0.2319 - acc: 0.9315 - val_loss: 0.1522 - val_acc: 0.9625
Epoch 5/20
- 1s - loss: 0.0874 - acc: 0.9858 - val_loss: 0.0426 - val_acc: 0.9965
Epoch 6/20
- 1s - loss: 0.0231 - acc: 0.9990 - val_loss: 0.0115 - val_acc: 1.0000
Epoch 7/20
- 1s - loss: 0.0071 - acc: 1.0000 - val_loss: 0.0046 - val_acc: 1.0000
Epoch 8/20
- 1s - loss: 0.0032 - acc: 1.0000 - val_loss: 0.0024 - val_acc: 1.0000
Epoch 9/20
- 1s - loss: 0.0018 - acc: 1.0000 - val_loss: 0.0015 - val_acc: 1.0000
Epoch 10/20
- 1s - loss: 0.0012 - acc: 1.0000 - val_loss: 0.0011 - val_acc: 1.0000
Epoch 11/20
- 1s - loss: 8.6000e-04 - acc: 1.0000 - val_loss: 7.9172e-04 - val_acc: 1.0000
Epoch 12/20
- 1s - loss: 6.4733e-04 - acc: 1.0000 - val_loss: 6.1214e-04 - val_acc: 1.0000
Epoch 13/20
- 1s - loss: 5.0670e-04 - acc: 1.0000 - val_loss: 4.8932e-04 - val_acc: 1.0000
Epoch 14/20
- 1s - loss: 4.0807e-04 - acc: 1.0000 - val_loss: 4.0024e-04 - val_acc: 1.0000
Epoch 15/20
- 1s - loss: 3.3599e-04 - acc: 1.0000 - val_loss: 3.3364e-04 - val_acc: 1.0000
Epoch 16/20
- 1s - loss: 2.8143e-04 - acc: 1.0000 - val_loss: 2.8280e-04 - val_acc: 1.0000
Epoch 17/20
- 1s - loss: 2.3915e-04 - acc: 1.0000 - val_loss: 2.4242e-04 - val_acc: 1.0000
Epoch 18/20
- 1s - loss: 2.0557e-04 - acc: 1.0000 - val_loss: 2.1001e-04 - val_acc: 1.0000
Epoch 19/20
- 1s - loss: 1.7842e-04 - acc: 1.0000 - val_loss: 1.8360e-04 - val_acc: 1.0000
Epoch 20/20
- 1s - loss: 1.5616e-04 - acc: 1.0000 - val_loss: 1.6155e-04 - val_acc: 1.0000
Loss: 0.0001790129337925464
Accuracy: 1.0
Looking at the graph (the attention-vector bar chart drawn at the end of the code), the 8th value (index 7) is clearly the highest, right?
It shows that the model learned to attend to exactly the part of the input it needed!
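Instead of eyeballing the chart, the same conclusion can be checked numerically (a small sketch reusing attention_vector from the code above):

# Each softmax row sums to 1, so the averaged attention vector also sums to 1;
# index 7 should hold by far the largest share of that weight.
top_index = int(np.argmax(attention_vector))
print("Most attended index:", top_index)                       # expected: 7
print("Weight at that index: %.2f" % attention_vector[top_index])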
Hmm...
This example itself isn't that hard; the problem is that the code applying attention to an Encoder-Decoder model was considerably more complex.
Maybe that one is a completely different algorithm that just happens to share the name.
My understanding just isn't there yet...
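For reference, the attention typically used in Encoder-Decoder (seq2seq) models computes weights over time steps rather than a fixed weight per input feature: the decoder scores every encoder state against its current state, turns the scores into softmax weights, and takes a weighted sum of the encoder states. A minimal NumPy sketch of that dot-product scoring (hypothetical shapes, not tied to the code above):

def dot_product_attention(enc_outputs, dec_state):
    # enc_outputs: (timesteps, hidden) encoder states, dec_state: (hidden,) decoder state
    scores = enc_outputs @ dec_state             # one score per time step
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                     # softmax over time steps
    context = weights @ enc_outputs              # weighted sum of encoder states
    return context, weights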