NeuroWhAI의 잡블로그

[Keras] GAN으로 입력 데이터의 확률분포 변환하기 본문

개발 및 공부/라이브러리&프레임워크

[Keras] GAN으로 입력 데이터의 확률분포 변환하기

NeuroWhAI 2018. 3. 24. 14:23


※ 이 글은 '코딩셰프의 3분 딥러닝 케라스맛'이라는 책을 보고 실습한걸 기록한 글입니다.


다른 강좌나 텐서플로 책에서는 2D 이미지를 가지고 GAN를 실습했었는데
여기서는 단순한 수의 나열인 1D 데이터를 가지고 GAN를 쓰더라고요. (다음 챕터에서 2D 이미지 쓰는 것도 나오지만)

생성망의 입력 데이터는 균등분포의 랜덤한 데이터인데 출력은 정규분포로 내도록 학습시키는 예제입니다.
이번 코드는 책의 코드와 좀 많이 다를 수 있습니다.

코드:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from keras import layers, models
from keras import datasets
from keras import backend as K
from keras import optimizers
import matplotlib.pyplot as plt
import numpy as np
 
def add_decorate(x):
  m = K.mean(x, axis=-1, keepdims=True)
  d = K.square(x - m)
  return K.concatenate([x, d], axis=-1)
 
def add_decorate_shape(input_shape):
  shape = list(input_shape)
  assert len(shape) == 2
  shape[1*= 2
  return tuple(shape)
 
class Data:
  def __init__(self, mu, sigma, nInputD):
    self.real_sample = lambda nBatch: np.random.normal(mu, sigma,
                                                       (nBatch, nInputD))
    self.in_sample = lambda nBatch: np.random.rand(nBatch, nInputD)
    
class GAN:
  def __init__(self, nInputD, nHiddenD, nHiddenG):
    self.nInputD = nInputD
    self.nHiddenD = nHiddenD
    self.nHiddenG = nHiddenG
    
    self.adam = optimizers.Adam(lr=2e-4, beta_1=0.9, beta_2=0.999)
    
    self.D = self.create_D()
    self.G = self.create_G()
    self.GD = self.create_GD()
    
  def set_data(self, data):
    self.data = data
    
  def compile_model(self, model):
    return model.compile(loss='binary_crossentropy', optimizer=self.adam,
                        metrics=['accuracy'])
    
  def create_D(self):
    model = models.Sequential()
    model.add(layers.Lambda(add_decorate, output_shape=add_decorate_shape,
                           input_shape=(self.nInputD,)))
    model.add(layers.Dense(self.nHiddenD, activation='relu'))
    model.add(layers.Dense(self.nHiddenD, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    
    self.compile_model(model)
    return model
  
  def create_G(self):
    model = models.Sequential()
    model.add(layers.Reshape((self.nInputD, 1), input_shape=(self.nInputD,)))
    model.add(layers.Conv1D(self.nHiddenG, 1, activation='relu'))
    model.add(layers.Conv1D(self.nHiddenG, 1, activation='sigmoid'))
    model.add(layers.Conv1D(11))
    model.add(layers.Flatten())
    
    self.compile_model(model)
    return model
  
  def create_GD(self):
    model = models.Sequential()
    model.add(self.G)
    model.add(self.D)
    
    self.D.trainable = False
    self.compile_model(model)
    self.D.trainable = True
    
    return model
  
  def train(self, batch_size, epochs):
    for epoch in range(epochs):
      result_D = self.train_D(batch_size)
      result_GD = self.train_GD(batch_size)
    print("Discriminator", result_D)
    print("Generator", result_GD)
  
  def train_D(self, batch_size):
    real_data = self.data.real_sample(batch_size)
    
    noise_data = self.data.in_sample(batch_size)
    gen_data = self.G.predict(noise_data)
    
    self.D.trainable = True
    return self.train_D_on_batch(real_data, gen_data)
  
  def train_GD(self, batch_size):
    noise_data = self.data.in_sample(batch_size)
    
    self.D.trainable = False
    return self.train_GD_on_batch(noise_data)
  
  def train_D_on_batch(self, real_data, gen_data):
    x = np.concatenate([real_data, gen_data], axis=0)
    y = np.array([1* real_data.shape[0+ [0* gen_data.shape[0])
    return self.D.train_on_batch(x, y)
    
  def train_GD_on_batch(self, noise_data):
    y = np.array([1* noise_data.shape[0])
    return self.GD.train_on_batch(noise_data, y)
    
  def test(self, n_test):
    noise = self.data.in_sample(n_test)
    gen = self.G.predict(noise)
    return (gen, noise)
  
  def test_and_show(self, n_test):
    gen, noise = self.test(n_test)
    real = self.data.real_sample(n_test)
    
    print('Mean and Std of Real:', (np.mean(real), np.std(real)))
    print('Mean and Std of Gen:', (np.mean(gen), np.std(gen)))
    
    show_hist(real, gen, noise)
    plt.show()
 
def show_hist(real, gen, z):
  plt.hist(real.reshape(-1), histtype='step', label='Real')
  plt.hist(gen.reshape(-1), histtype='step', label='Generated')
  plt.hist(z.reshape(-1), histtype='step', label='Noise')
  plt.legend(loc=0)
    
def main():
  nInputD = 100
  batch_size = 100
  epochs=10
  
  data = Data(41.25, nInputD)
  gan = GAN(nInputD, nHiddenD=50, nHiddenG=50)
  gan.set_data(data)
  
  for epoch in range(epochs):
    print("Epoch:", epoch)
    gan.train(batch_size, epochs=2000)
    gan.test_and_show(batch_size)
 
if __name__ == '__main__':
  main()
cs

결과:
Epoch: 0
Discriminator [0.0005219373, 1.0]
Generator [6.8659143, 0.0]
Mean and Std of Real: (3.986625544961122, 1.2553951217307804)
Mean and Std of Gen: (0.010475923, 0.004396004)


(중략)
Epoch: 9
Discriminator [0.45969188, 0.785]
Generator [1.4480966, 0.18]
Mean and Std of Real: (4.020679477053212, 1.2473444120894257)
Mean and Std of Gen: (3.8809788, 1.3134831)



그래프가 흐릿해서 잘 안보이지만
파란색은 목표 데이터고 빨간색은 생성망의 입력 데이터이며 초록색이 생성된 데이터입니다.

뭔가... 학습이 그렇게 잘 되는 것 같지는 않습니다.
그래도 생성망이 출력한 데이터의 평균, 편차가 목표 데이터의 평균, 편차와 비슷해지는 모습을 볼 수 있습니다.

코드 설명은 다음에 2D 이미지 예제에서 자세히 할 예정이고
모델의 trainable 속성만 짚고 끝내겠습니다.
코드를 보시면 모델을 컴파일하기 전과 학습을 진행하기 전에 trainable 속성을 설정하는걸 보실 수 있습니다.
공식 문서를 보시면 Layer별로도 설정할 수 있는 것 같은데
예제를 보시면 모델(G, D)이 곧 다른 모델(GD)의 레이어처럼 쓰이고 있으니 그냥 각 레이어에 설정할 수 있다고 보시면 될 것 같습니다.
모델을 컴파일하기 전에 파라미터가 변하지 않기(=학습되지 않기)를 원하는 레이어의 trainable 속성을 False로 바꾸고
compile을 하면 그 설정이 고정됩니다.
이후에 trainable 속성을 바꿔도 다시 컴파일하기 전까진 영향이 없지만
바꾼 trainable 속성과 컴파일시 설정했던 trainable 속성의 값이 다르면 경고 문구가 뜹니다.
그래서 train_on_batch로 학습을 하기 전에 trainable 속성을 또 설정해주고 있는거죠.

끝!




Comments