NeuroWhAI의 잡블로그

[Keras] Seq2Seq에 Attention 매커니즘 적용 실패 본문

개발 및 공부/라이브러리&프레임워크

[Keras] Seq2Seq에 Attention 매커니즘 적용 실패

NeuroWhAI 2018. 12. 2. 21:01


실패했다.

적용 방법도 떠오르지 않았던 것과 다르게 이번엔 확실히 적용할 수 있는 방법이 생각나서 진행했다.

그러나 결과적으로 실패한 것 같다.

번역 데이터 세트를 사용했고 실패했지만 번역 품질이 못봐줄 수준은 아니었다.

그러나 최종 검증인 어텐션 매트릭스에서 실패가 보였다.

어텐션 매트릭스라고 함은 입력 단어들과 출력 단어들간에 관련성을 보여주는 매트릭스인데

예를 들면 영어로 'I'가 한국어의 '나'와 관련성이 높게 나와야 한다는 뜻이다.

하지만 직접 출력해본 매트릭스는 그러한 특성이 전혀 나타나지 않았고 무언가 이상한 벡터의 반복이었다.

그렇다는건 연산을 잘못 적용했다는 소리인데 차원만 맞추느라 연산의 연결을 제대로 생각하지 못한게 실패의 원인인듯 하다.


힘들다..

이렇게 오래, 많이 도전했지만 실패해본게 정말 오랜만인 것 같다.

사실 attention도 한물 간 기술이고 요즘은 transformer를 쓴다고 한다.

물론 attention을 기반으로 한 기술이지만 아직 이것도 구현하지 못하겠는데 너무 길이 멀다.

나에게 벅찬 분야가 아닐까 하는 생각이 들지만 접어두었다.

정말 마지막으로 계획을 하나 세웠다.

번역 문제는 내가 검증하기 힘드니 간단한 seq2seq 문제를 정의해서 attention 매커니즘을 적용해볼 예정이다.

뭐 문장에서 숫자만 골라 출력하기 같은?

이것도 실패하면 어텐션 구현은 접고 이론적인 이해로 만족해야겠다.


아래는 코드이다.

정리가 안되어서 더럽다.

from keras import layers, models
from keras import datasets
from keras import backend as K
from keras.utils import plot_model
import matplotlib
from matplotlib import ticker
import matplotlib.pyplot as plt
import numpy as np
#from IPython.display import Image, display
  
def get_data(file_id):
  """
  구글 드라이브에서 file_id에 해당하는 파일을 가져와 읽습니다.
  Colab에서 돌리기 위해 필요합니다.
  """
  
  # Install the PyDrive wrapper & import libraries.
  # This only needs to be done once per notebook.
  from pydrive.auth import GoogleAuth
  from pydrive.drive import GoogleDrive
  from google.colab import auth
  from oauth2client.client import GoogleCredentials
 
  # Authenticate and create the PyDrive client.
  # This only needs to be done once per notebook.
  auth.authenticate_user()
  gauth = GoogleAuth()
  gauth.credentials = GoogleCredentials.get_application_default()
  drive = GoogleDrive(gauth)
 
  # Download a file based on its file ID.
  #
  # A file ID looks like: laggVyWshwcyP6kEI-y_W3P8D26sz
  downloaded = drive.CreateFile({'id': file_id})
  return downloaded.GetContentString()

# 학습 정보
batch_size = 100
epochs = 32
latent_dim = 256
num_samples = 4096
 
# 문장 벡터화
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
 
data_set = {
    'kor': '12UL7MH38rxRFskhYuj1B9zU7wUXgv8Q7',
    'jpn': '1ULdUnurb_DEDTzJFaycdWHWkYWqPu4Bo',
    'deu': '1OhP8vPwTGrLyweo0HuLhhGpe0VRWn4k6',
}
 
lines = get_data(data_set['jpn']).split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
  input_text, target_text = line.split('\t')
  # "\t"문자를 시작 문자로, "\n"문자를 종료 문자로 사용.
  target_text = '\t' + target_text + '\n'
  input_texts.append(input_text)
  target_texts.append(target_text)
 
  # 문자 집합 생성
  for char in input_text:
    if char not in input_characters:
      input_characters.add(char)
  for char in target_text:
    if char not in target_characters:
      target_characters.add(char)

# 학습 데이터 개수
num_samples = len(input_texts)
            
input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])
 
print('Number of samples:', num_samples)
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

# 문자 -> 숫자 변환용 사전
input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

# 학습에 사용할 데이터를 담을 3차원 배열
encoder_input_data = np.zeros(
    (num_samples, max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (num_samples, max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (num_samples, max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')

# 문장을 문자 단위로 원 핫 인코딩하면서 학습용 데이터를 만듬
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
  for t, char in enumerate(input_text):
    encoder_input_data[i, t, input_token_index[char]] = 1.
  for t, char in enumerate(target_text):
    decoder_input_data[i, t, target_token_index[char]] = 1.
    if t > 0:
      decoder_target_data[i, t - 1, target_token_index[char]] = 1.

# 숫자 -> 문자 변환용 사전
reverse_input_char_index = dict(
    (i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())

# 인코더 생성
encoder_inputs = layers.Input(shape=(max_encoder_seq_length, num_encoder_tokens))
encoder = layers.GRU(latent_dim, return_sequences=True, return_state=True)
encoder_outputs, state_h = encoder(encoder_inputs)

# 디코더 생성.
decoder_inputs = layers.Input(shape=(max_decoder_seq_length, num_decoder_tokens))
decoder = layers.GRU(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _ = decoder(decoder_inputs, initial_state=state_h)

# 어텐션 매커니즘.
flatten_h = layers.Reshape((max_encoder_seq_length * latent_dim,))(encoder_outputs)
repeat_h = layers.RepeatVector(max_decoder_seq_length)(flatten_h)
repeat_h = layers.Reshape((max_encoder_seq_length * max_decoder_seq_length, latent_dim))(repeat_h)

repeat_d = layers.Lambda(lambda x: K.concatenate([K.repeat(x[:, i, :], max_encoder_seq_length) for i in range(0, max_decoder_seq_length)], axis=-2),
                        lambda x: tuple([x[0], max_encoder_seq_length * max_decoder_seq_length, latent_dim]))
repeat_d = repeat_d(decoder_outputs)

score_input = layers.Concatenate(axis=-1)([repeat_h, repeat_d])
score_dense = layers.Dense(1)
#score_time_dense = layers.TimeDistributed(score_dense)
#score = score_time_dense(score_input)
score = score_dense(score_input)

score_softmax = layers.Lambda(lambda x: K.concatenate([K.softmax(x[:, i*max_encoder_seq_length:(i+1)*max_encoder_seq_length, 0], axis=-1) for i in range(0, max_decoder_seq_length)], axis=-1),
                             lambda x: tuple([x[0], max_encoder_seq_length * max_decoder_seq_length]))
score_softmax_reshape = layers.Reshape((-1, 1))
score_softmax = score_softmax_reshape(score_softmax(score))

score_mul = layers.Lambda(lambda x: K.repeat_elements(x, latent_dim, axis=-1),
                         lambda x: tuple(x[:-1] + (latent_dim,)))
score_mul_t = score_mul(score_softmax)

scored_h = layers.Multiply()([repeat_h, score_mul_t])

context = layers.Lambda(lambda x: K.sum(K.reshape(x, (-1, max_decoder_seq_length, max_encoder_seq_length, latent_dim)), axis=-2),
                       lambda x: tuple([x[0], max_decoder_seq_length, latent_dim]))
context = context(scored_h)

attention_output = layers.Concatenate(axis=-1)([context, decoder_outputs])

decoder_dense = layers.Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(attention_output)

# 모델 생성
model = models.Model([encoder_inputs, decoder_inputs], decoder_outputs)

choice = input("Load weights?")
if choice == 'y' or choice == 'Y':
  model.load_weights('att_seq2seq_weights.h5')

model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model.summary()
#plot_model(model, show_shapes=True, to_file='model.png')
#display(Image(filename='model.png'))

choice = input("Train?")
if choice == 'y' or choice == 'Y':
  # 학습
  history = model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
                      batch_size=batch_size,
                      epochs=epochs,
                      validation_split=0.2,
                      verbose=2)

  model.save_weights('att_seq2seq_weights.h5')

  # 손실 그래프
  plt.plot(history.history['loss'], 'y', label='train loss')
  plt.plot(history.history['val_loss'], 'r', label='val loss')
  plt.legend(loc='upper left')
  plt.show()

# 어텐션 검증
test_data_num = 0
test_max_len = 0
for i, s in enumerate(input_texts):
  if len(s) > test_max_len:
    test_max_len = len(s)
    test_data_num = i
test_data_num = 60
    
test_enc_input = encoder_input_data[test_data_num].reshape(
    (1, max_encoder_seq_length, num_encoder_tokens))
test_dec_input = decoder_input_data[test_data_num].reshape(
    (1, max_decoder_seq_length, num_decoder_tokens))

attention_layer = score_softmax_reshape
func = K.function([encoder_inputs, decoder_inputs] + [K.learning_phase()], [attention_layer.output])
score_values = func([test_enc_input, test_dec_input, 1.0])[0]
#score_values = score_values.reshape((max_encoder_seq_length, max_decoder_seq_length))
score_values = score_values.reshape((max_decoder_seq_length, max_encoder_seq_length))

#score_values = score_values[:len(input_texts[test_data_num]), :len(target_texts[test_data_num])]
score_values = score_values[:len(target_texts[test_data_num]), :len(input_texts[test_data_num])]

fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(score_values, interpolation='nearest')
fig.colorbar(cax)

test_enc_names = []
for vec in test_enc_input[0]:
  sampled_token_index = np.argmax(vec)
  sampled_char = reverse_input_char_index[sampled_token_index]
  test_enc_names.append(sampled_char)
test_dec_names = []
for vec in test_dec_input[0]:
  sampled_token_index = np.argmax(vec)
  sampled_char = reverse_target_char_index[sampled_token_index]
  test_dec_names.append(sampled_char)

#print(test_enc_names[:len(input_texts[test_data_num])])
print(test_dec_names[:len(target_texts[test_data_num])])

ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
#ax.set_xticklabels(['']+test_dec_names)
#ax.set_yticklabels(['']+test_enc_names)
ax.set_yticklabels(['']+test_dec_names)
ax.set_xticklabels(['']+test_enc_names)

plt.show()
  
# 추론(테스트)

# 추론 모델 생성
encoder_model = models.Model(encoder_inputs, [encoder_outputs, state_h])

decoder_inputs = layers.Input(shape=(1, num_decoder_tokens))
decoder_state_input_h = layers.Input(shape=(latent_dim,))
decoder_outputs, decoder_h = decoder(decoder_inputs, initial_state=decoder_state_input_h)

repeat_d = layers.Reshape((latent_dim,))(decoder_outputs)
repeat_d = layers.RepeatVector(max_encoder_seq_length)(repeat_d)

encoder_output_input = layers.Input(shape=(max_encoder_seq_length, latent_dim))
score_input = layers.Concatenate(axis=-1)([encoder_output_input, repeat_d])
#score = layers.TimeDistributed(score_dense)(score_input)
score = score_dense(score_input)

score_softmax = layers.Lambda(lambda x: K.softmax(x, axis=-2))
score_softmax = score_softmax(score)

score_mul_t = score_mul(score_softmax)
scored_h = layers.Multiply()([encoder_output_input, score_mul_t])

context = layers.Lambda(lambda x: K.sum(x, axis=-2, keepdims=True),
                       lambda x: tuple([x[0], 1, latent_dim]))
context = context(scored_h)

attention_output = layers.Concatenate(axis=-1)([context, decoder_outputs])

decoder_att_outputs = decoder_dense(attention_output)

decoder_model = models.Model([decoder_inputs, decoder_state_input_h, encoder_output_input],
                            [decoder_outputs, decoder_h, decoder_att_outputs])
decoder_model.summary()
#plot_model(decoder_model, show_shapes=True, to_file='decoder_model.png')
#display(Image(filename='decoder_model.png'))

def decode_sequence(input_seq):
  # 입력 문장을 인코딩
  enc_outputs, states_value = encoder_model.predict(input_seq)
 
  # 디코더의 입력으로 쓸 단일 문자
  target_seq = np.zeros((1, 1, num_decoder_tokens))
  # 첫 입력은 시작 문자인 '\t'로 설정
  target_seq[0, 0, target_token_index['\t']] = 1.
 
  # 문장 생성
  stop_condition = False
  decoded_sentence = ''
  while not stop_condition:
    # 이전의 출력, 상태를 디코더에 넣어서 새로운 출력, 상태를 얻음
    # 이전 문자와 상태로 다음 문자와 상태를 얻는다고 보면 됨.
    dec_outputs, h, output_tokens = decoder_model.predict(
        [target_seq, states_value, enc_outputs])
 
    # 사전을 사용해서 원 핫 인코딩 출력을 실제 문자로 변환
    sampled_token_index = np.argmax(output_tokens[0, -1, :])
    sampled_char = reverse_target_char_index[sampled_token_index]
    decoded_sentence += sampled_char
 
    # 종료 문자가 나왔거나 문장 길이가 한계를 넘으면 종료
    if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length):
      stop_condition = True
 
    # 디코더의 다음 입력으로 쓸 데이터 갱신
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, sampled_token_index] = 1.
    
    states_value = h
 
  return decoded_sentence

for seq_index in range(30):
  input_seq = encoder_input_data[seq_index: seq_index + 1]
  decoded_sentence = decode_sequence(input_seq)
  print('"{}" -> "{}"'.format(input_texts[seq_index], decoded_sentence.strip()))




Comments