NeuroWhAI의 잡블로그

[Keras] Seq2Seq에 Attention 매커니즘 적용 성공! (+코드) 본문

개발 및 공부/라이브러리&프레임워크

[Keras] Seq2Seq에 Attention 매커니즘 적용 성공! (+코드)

NeuroWhAI 2018. 12. 4. 18:45


드디어!
성공했습니다!
와! PPAP!!

간단하게 입력 문장을 문자 단위로 나누고 그걸 그대로 출력하는게 정답인 데이터 세트를 사용했고
학습을 시킨 후 어텐션 매트릭스(Attention Matrix)를 출력해서 제대로 학습이 되었는지 검증을 했습니다.


짠!
성공한게 틀림없다구요! (아마도요...)
X축이 입력 문장, Y축이 출력 문장입니다.

만들면서 가장 힘들었던 건 텐서를 생각대로 다루는 것과 어텐션 스코어 계산이었습니다.
텐서를 막 늘리고 돌리고 연산하는데 차원을 잘 맞춰야 하면서도 무작정 차원만 맞추면 연산이 올바르지 않게 되어버렸어서 힘들었습니다.
가장 멍청했던 실수는 텐서의 축을 치환하는데 Permute대신 Reshape을 썼던 것입니다.
이러면 차원은 맞춰지지만 요소들이 이상하게 배치되죠... ㅠㅠ
그리고 어텐션 스코어 계산은 구현에서가 아니라 스코어 함수가 뭔질 몰라서 힘들었습니다.
검색하면 쉽게 찾을 수 있을 줄 알았는데 스코어 함수라는게 있다는 것만 알려주고 어떤 연산인지를 알려주지 않더라고요..
그래서 적당히 제가 생각해서 구현을 했었는데 이전에는 실패해서 제가 틀린 줄 알았죠.
페이스북 텐서플로우 코리아 그룹에 스코어 함수 관련 질문을 올렸는데 답변으로 논문 링크를 받았었습니다.
저는 제가 논문을 볼 수준이 아니라고 생각해서 당황스러웠지만 막상 열어보니


띠용???
바로 있는겁니다;;
그런데 보다보니 제가 직접 구현했던 스코어 함수랑 concat 부분의 공식이 비슷해서 계산법은 이전 방법을 그대로 썼습니다.

그나저나 데이터가 단어 단위가 아니라 문자 단위라서 불안정한 현상은 있는 것 같습니다.
또 시간이 된다면 단어 단위로, 임베딩까지 적용해서 실험해보도록 하겠습니다 ㅎㅎ

마지막으로 아래는 코드입니다.
코드는 예전에 올린 번역기 만들기를 기반으로 하였고 많이 더럽습니다 ㅎㅎ;;
그냥 참고용으로만 봐주세요.
from keras import layers, models
from keras import datasets
from keras import backend as K
from keras.utils import plot_model
import matplotlib
from matplotlib import ticker
import matplotlib.pyplot as plt
import numpy as np
#from IPython.display import Image, display
  
def get_data(file_id):
  """
  구글 드라이브에서 file_id에 해당하는 파일을 가져와 읽습니다.
  Colab에서 돌리기 위해 필요합니다.
  """
  
  # Install the PyDrive wrapper & import libraries.
  # This only needs to be done once per notebook.
  from pydrive.auth import GoogleAuth
  from pydrive.drive import GoogleDrive
  from google.colab import auth
  from oauth2client.client import GoogleCredentials
 
  # Authenticate and create the PyDrive client.
  # This only needs to be done once per notebook.
  auth.authenticate_user()
  gauth = GoogleAuth()
  gauth.credentials = GoogleCredentials.get_application_default()
  drive = GoogleDrive(gauth)
 
  # Download a file based on its file ID.
  #
  # A file ID looks like: laggVyWshwcyP6kEI-y_W3P8D26sz
  downloaded = drive.CreateFile({'id': file_id})
  return downloaded.GetContentString()

# 학습 정보
batch_size = 100
epochs = 24
latent_dim = 256
num_samples = 4096
 
# 문장 벡터화
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
 
data_set = {
    'kor': '12UL7MH38rxRFskhYuj1B9zU7wUXgv8Q7',
    'jpn': '1ULdUnurb_DEDTzJFaycdWHWkYWqPu4Bo',
    'deu': '1OhP8vPwTGrLyweo0HuLhhGpe0VRWn4k6',
}
 
lines = get_data(data_set['jpn']).split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
  input_text, target_text = line.split('\t')
  target_text = input_text # 임시
  
  # "\t"문자를 시작 문자로, "\n"문자를 종료 문자로 사용.
  target_text = '\t' + target_text + '\n'
  input_texts.append(input_text)
  target_texts.append(target_text)
 
  # 문자 집합 생성
  for char in input_text:
    if char not in input_characters:
      input_characters.add(char)
  for char in target_text:
    if char not in target_characters:
      target_characters.add(char)
      
# 학습 데이터 개수
num_samples = len(input_texts)
            
input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])
 
print('Number of samples:', num_samples)
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

# 문자 -> 숫자 변환용 사전
input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

# 학습에 사용할 데이터를 담을 3차원 배열
encoder_input_data = np.zeros(
    (num_samples, max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (num_samples, max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (num_samples, max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')

# 문장을 문자 단위로 원 핫 인코딩하면서 학습용 데이터를 만듬
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
  for t, char in enumerate(input_text):
    encoder_input_data[i, t, input_token_index[char]] = 1.
  for t, char in enumerate(target_text):
    decoder_input_data[i, t, target_token_index[char]] = 1.
    if t > 0:
      decoder_target_data[i, t - 1, target_token_index[char]] = 1.

# 숫자 -> 문자 변환용 사전
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())

def RepeatVectorLayer(rep, axis):
  return layers.Lambda(lambda x: K.repeat_elements(K.expand_dims(x, axis), rep, axis),
                      lambda x: tuple((x[0],) + x[1:axis] + (rep,) + x[axis:]))

# 인코더 생성
encoder_inputs = layers.Input(shape=(max_encoder_seq_length, num_encoder_tokens))
encoder = layers.GRU(latent_dim, return_sequences=True, return_state=True)
encoder_outputs, state_h = encoder(encoder_inputs)

# 디코더 생성.
decoder_inputs = layers.Input(shape=(max_decoder_seq_length, num_decoder_tokens))
decoder = layers.GRU(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _ = decoder(decoder_inputs, initial_state=state_h)

# 어텐션 매커니즘.
repeat_d_layer = RepeatVectorLayer(max_encoder_seq_length, 2)
repeat_d = repeat_d_layer(decoder_outputs)

repeat_e_layer = RepeatVectorLayer(max_decoder_seq_length, 1)
repeat_e = repeat_e_layer(encoder_outputs)

concat_for_score_layer = layers.Concatenate(axis=-1)
concat_for_score = concat_for_score_layer([repeat_d, repeat_e])

dense1_t_score_layer = layers.Dense(latent_dim // 2, activation='tanh')
dense1_score_layer = layers.TimeDistributed(dense1_t_score_layer)
dense1_score = dense1_score_layer(concat_for_score)
dense2_t_score_layer = layers.Dense(1)
dense2_score_layer = layers.TimeDistributed(dense2_t_score_layer)
dense2_score = dense2_score_layer(dense1_score)
dense2_score = layers.Reshape((max_decoder_seq_length, max_encoder_seq_length))(dense2_score)

softmax_score_layer = layers.Softmax(axis=-1)
softmax_score = softmax_score_layer(dense2_score)

repeat_score_layer = RepeatVectorLayer(latent_dim, 2)
repeat_score = repeat_score_layer(softmax_score)

permute_e = layers.Permute((2, 1))(encoder_outputs)
repeat_e_layer = RepeatVectorLayer(max_decoder_seq_length, 1)
repeat_e = repeat_e_layer(permute_e)

attended_mat_layer = layers.Multiply()
attended_mat = attended_mat_layer([repeat_score, repeat_e])

context_layer = layers.Lambda(lambda x: K.sum(x, axis=-1),
                             lambda x: tuple(x[:-1]))
context = context_layer(attended_mat)

concat_context_layer = layers.Concatenate(axis=-1)
concat_context = concat_context_layer([context, decoder_outputs])

attention_dense_output_layer = layers.Dense(latent_dim, activation='tanh')
attention_output_layer = layers.TimeDistributed(attention_dense_output_layer)
attention_output = attention_output_layer(concat_context)

decoder_dense = layers.Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(attention_output)

# 모델 생성
model = models.Model([encoder_inputs, decoder_inputs], decoder_outputs)

choice = input("Load weights?")
if choice == 'y' or choice == 'Y':
  model.load_weights('att_seq2seq_weights.h5')

model.compile(optimizer='rmsprop', loss='categorical_crossentropy',
             metrics=['acc'])
#model.summary()
#plot_model(model, show_shapes=True, to_file='model.png')
#display(Image(filename='model.png'))

choice = input("Train?")
if choice == 'y' or choice == 'Y':
  # 학습
  history = model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
                      batch_size=batch_size,
                      epochs=epochs,
                      validation_split=0.2,
                      verbose=2)

  model.save_weights('att_seq2seq_weights.h5')

  # 손실 그래프
  plt.plot(history.history['loss'], 'y', label='train loss')
  plt.plot(history.history['val_loss'], 'r', label='val loss')
  plt.legend(loc='upper left')
  plt.show()

  # 정확도 그래프
  plt.plot(history.history['acc'], 'y', label='train acc')
  plt.plot(history.history['val_acc'], 'r', label='val acc')
  plt.legend(loc='upper left')
  plt.show()

# 어텐션 검증
test_data_num = 0
test_max_len = 0
for i, s in enumerate(input_texts):
  if len(s) > test_max_len:
    test_max_len = len(s)
    test_data_num = i

test_enc_input = encoder_input_data[test_data_num].reshape(
    (1, max_encoder_seq_length, num_encoder_tokens))
test_dec_input = decoder_input_data[test_data_num].reshape(
    (1, max_decoder_seq_length, num_decoder_tokens))

attention_layer = softmax_score_layer
func = K.function([encoder_inputs, decoder_inputs] + [K.learning_phase()], [attention_layer.output])
score_values = func([test_enc_input, test_dec_input, 1.0])[0]
score_values = score_values.reshape((max_decoder_seq_length, max_encoder_seq_length))

score_values = score_values[:len(target_texts[test_data_num])-1, :len(input_texts[test_data_num])]

fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(score_values, interpolation='nearest')
fig.colorbar(cax)

test_enc_names = []
for vec in test_enc_input[0]:
  sampled_token_index = np.argmax(vec)
  sampled_char = reverse_input_char_index[sampled_token_index]
  test_enc_names.append(sampled_char)
test_dec_names = []
for vec in test_dec_input[0]:
  sampled_token_index = np.argmax(vec)
  sampled_char = reverse_target_char_index[sampled_token_index]
  test_dec_names.append(sampled_char)

print(test_dec_names[1:len(target_texts[test_data_num])])

ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
ax.set_yticklabels(['']+test_dec_names[1:-1] + ['<END>'])
ax.set_xticklabels(['']+test_enc_names)

plt.show()
  
# 추론(테스트)

# 추론 모델 생성
encoder_model = models.Model(encoder_inputs, [encoder_outputs, state_h])
encoder_outputs_input = layers.Input(shape=(max_encoder_seq_length, latent_dim))

decoder_inputs = layers.Input(shape=(1, num_decoder_tokens))
decoder_state_input_h = layers.Input(shape=(latent_dim,))
decoder_outputs, decoder_h = decoder(decoder_inputs, initial_state=decoder_state_input_h)

repeat_d_layer = RepeatVectorLayer(max_encoder_seq_length, 2)
repeat_d = repeat_d_layer(decoder_outputs)

repeat_e_layer = RepeatVectorLayer(1, axis=1)
repeat_e = repeat_e_layer(encoder_outputs_input)

concat_for_score_layer = layers.Concatenate(axis=-1)
concat_for_score = concat_for_score_layer([repeat_d, repeat_e])

dense1_score_layer = layers.TimeDistributed(dense1_t_score_layer)
dense1_score = dense1_score_layer(concat_for_score)

dense2_score_layer = layers.TimeDistributed(dense2_t_score_layer)
dense2_score = dense2_score_layer(dense1_score)
dense2_score = layers.Reshape((1, max_encoder_seq_length))(dense2_score)

softmax_score_layer = layers.Softmax(axis=-1)
softmax_score = softmax_score_layer(dense2_score)

repeat_score_layer = RepeatVectorLayer(latent_dim, 2)
repeat_score = repeat_score_layer(softmax_score)

permute_e = layers.Permute((2, 1))(encoder_outputs_input)
repeat_e_layer = RepeatVectorLayer(1, axis=1)
repeat_e = repeat_e_layer(permute_e)

attended_mat_layer = layers.Multiply()
attended_mat = attended_mat_layer([repeat_score, repeat_e])

context_layer = layers.Lambda(lambda x: K.sum(x, axis=-1),
                             lambda x: tuple(x[:-1]))
context = context_layer(attended_mat)

concat_context_layer = layers.Concatenate(axis=-1)
concat_context = concat_context_layer([context, decoder_outputs])

attention_output_layer = layers.TimeDistributed(attention_dense_output_layer)
attention_output = attention_output_layer(concat_context)

decoder_att_outputs = decoder_dense(attention_output)

decoder_model = models.Model([decoder_inputs, decoder_state_input_h, encoder_outputs_input],
                            [decoder_outputs, decoder_h, decoder_att_outputs])
#decoder_model.summary()
#plot_model(decoder_model, show_shapes=True, to_file='decoder_model.png')
#display(Image(filename='decoder_model.png'))

def decode_sequence(input_seq):
  # 입력 문장을 인코딩
  enc_outputs, states_value = encoder_model.predict(input_seq)
 
  # 디코더의 입력으로 쓸 단일 문자
  target_seq = np.zeros((1, 1, num_decoder_tokens))
  # 첫 입력은 시작 문자인 '\t'로 설정
  target_seq[0, 0, target_token_index['\t']] = 1.
 
  # 문장 생성
  stop_condition = False
  decoded_sentence = ''
  while not stop_condition:
    # 이전의 출력, 상태를 디코더에 넣어서 새로운 출력, 상태를 얻음
    # 이전 문자와 상태로 다음 문자와 상태를 얻는다고 보면 됨.
    dec_outputs, h, output_tokens = decoder_model.predict(
        [target_seq, states_value, enc_outputs])
 
    # 사전을 사용해서 원 핫 인코딩 출력을 실제 문자로 변환
    sampled_token_index = np.argmax(output_tokens[0, -1, :])
    sampled_char = reverse_target_char_index[sampled_token_index]
    decoded_sentence += sampled_char
 
    # 종료 문자가 나왔거나 문장 길이가 한계를 넘으면 종료
    if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
      stop_condition = True
 
    # 디코더의 다음 입력으로 쓸 데이터 갱신
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, sampled_token_index] = 1.
    
    states_value = h
 
  return decoded_sentence

for seq_index in range(30):
  input_seq = encoder_input_data[seq_index: seq_index + 1]
  decoded_sentence = decode_sequence(input_seq)
  print('"{}" -> "{}"'.format(input_texts[seq_index], decoded_sentence.strip()))




Comments