NeuroWhAI의 잡블로그

[TensorFlow] DCGAN으로 MNIST 이미지 생성 성공(?) 본문

개발 및 공부/라이브러리&프레임워크

[TensorFlow] DCGAN으로 MNIST 이미지 생성 성공(?)

NeuroWhAI 2018. 2. 9. 13:53


저번에 했다가 실패했다고 말씀드렸었는데 이번에 어느정도 성공했습니다! (이전 글)

코드:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#-*- coding: utf-8 -*-
 
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
 
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./mnist/data/", one_hot=True)
 
total_epoch = 100
batch_size = 100
n_noise = 100
 
D_global_step = tf.Variable(0, trainable=False, name='D_global_step')
G_global_step = tf.Variable(0, trainable=False, name='G_global_step')
 
= tf.placeholder(tf.float32, [None, 28281])
= tf.placeholder(tf.float32, [None, n_noise])
is_training = tf.placeholder(tf.bool)
 
def leaky_relu(x, leak=0.2):
    return tf.maximum(x, x * leak)
 
def generator(noise):
    with tf.variable_scope('generator'):
        output = tf.layers.dense(noise, 128*7*7)
        output = tf.reshape(output, [-177128])
        output = tf.nn.relu(tf.layers.batch_normalization(output, training=is_training))
        output = tf.layers.conv2d_transpose(output, 128, [55], strides=(22), padding='SAME')
        output = tf.nn.relu(tf.layers.batch_normalization(output, training=is_training))
        output = tf.layers.conv2d_transpose(output, 64, [55], strides=(22), padding='SAME')
        output = tf.nn.relu(tf.layers.batch_normalization(output, training=is_training))
        output = tf.layers.conv2d_transpose(output, 1, [55], strides=(11), padding='SAME')
        output = tf.tanh(output)
    return output
 
def discriminator(inputs, reuse=None):
    with tf.variable_scope('discriminator') as scope:
        if reuse:
            scope.reuse_variables()
        output = tf.layers.conv2d(inputs, 32, [55], strides=(22), padding='SAME')
        output = tf.layers.conv2d(output, 64, [55], strides=(22), padding='SAME')
        output = leaky_relu(tf.layers.batch_normalization(output, training=is_training))
        output = tf.layers.conv2d(output, 128, [55], strides=(22), padding='SAME')
        output = leaky_relu(tf.layers.batch_normalization(output, training=is_training))
        flat = tf.contrib.layers.flatten(output)
        output = tf.layers.dense(flat, 1, activation=None)
    return output
 
def get_noise(batch_size, n_noise):
    return np.random.uniform(-1.1., size=[batch_size, n_noise])
 
= generator(Z)
D_real = discriminator(X)
D_gene = discriminator(G, True)
 
loss_D_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_real, labels=tf.ones_like(D_real)
))
loss_D_gene = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_gene, labels=tf.zeros_like(D_gene)
))
 
loss_D = loss_D_real + loss_D_gene
loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_gene, labels=tf.ones_like(D_gene)
))
 
vars_D = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
    scope='discriminator')
vars_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
    scope='generator')
 
train_D = tf.train.AdamOptimizer().minimize(loss_D,
    var_list=vars_D, global_step=D_global_step)
train_G = tf.train.AdamOptimizer().minimize(loss_G,
    var_list=vars_G, global_step=G_global_step)
 
tf.summary.scalar('loss_D', loss_D)
tf.summary.scalar('loss_G', loss_G)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
 
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter('./logs', sess.graph)
 
    total_batch = int(mnist.train.num_examples / batch_size)
 
    for epoch in range(total_epoch):
        loss_val_D, loss_val_G = 00
 
        batch_xs, batch_ys = None, None
        noise = None
 
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            batch_xs = batch_xs.reshape(-128281)
            noise = get_noise(batch_size, n_noise)
 
            _, loss_val_D = sess.run([train_D, loss_D],
                feed_dict={X: batch_xs, Z: noise, is_training: True})
            _, loss_val_G = sess.run([train_G, loss_G],
                feed_dict={Z: noise, is_training: True})
 
        summary = sess.run(merged,
            feed_dict={X: batch_xs, Z: noise, is_training: True})
        writer.add_summary(summary, global_step=sess.run(G_global_step))
 
        if epoch == 0 or (epoch + 1) % 10 == 0:
            print('Epoch:''%04d' % epoch,
                'D loss: {:.4}'.format(loss_val_D),
                'G loss: {:.4}'.format(loss_val_G))
 
            sample_size = 10
            noise = get_noise(sample_size, n_noise)
            f_samples = sess.run(G, feed_dict={Z: noise, is_training: False})
            t_samples = sess.run(G, feed_dict={Z: noise, is_training: True})
 
            fig, ax = plt.subplots(2, sample_size, figsize=(sample_size, 2))
 
            for i in range(sample_size):
                ax[0][i].set_axis_off()
                ax[1][i].set_axis_off()
                ax[0][i].imshow(np.reshape(f_samples[i], (2828)))
                ax[1][i].imshow(np.reshape(t_samples[i], (2828)))
 
            plt.savefig('ft_{}.png'.format(str(epoch).zfill(3)),
                bbox_inches='tight')
            
            plt.close(fig)
 
cs

결과:
※ 사진의 첫번째 줄은 batch_normalization의 is_training 옵션을 False로하고 생성했고
    두번째 줄은 True로 하고 생성한겁니다.
※ 아래로 갈수록 학습이 더 진행된 결과입니다.










보시다시피 batch_normalization의 is_training 옵션이 False로 되어있으면 이미지가 뭔가 좀 이상하게 나옵니다 ㅠㅠ
이름대로라면 False로 하고 생성하는게 맞을텐데;;

이전 코드와의 변경점이라면 일단 중대한 오류를 하나 수정했습니다.
판별기에서 첫번째 레이어의 출력을 두번째 레이어의 입력으로 넣지 않고 입력을 또 다시 두번째 레이어로 넣어버리는
복붙으로 인한 참사가 있었습니다.
또한 판별기와 생성기의 층 깊이를 더 늘렸고 학습 세대도 100번으로 늘렸습니다.
이전엔 텐서플로 CPU버전으로 돌려서 소박하게 만들었었지만 이번엔 GPU버전을 사용했기에 이렇게 했습니다.
또한 이게 의미가 있을진 모르겠지만 다른 코드에서 그러길래 저도 판별기의 Dense 레이어를 두개에서 하나로 줄였습니다.
그리고 판별기의 입력 레이어엔 활성화 함수와 batch_normalization을 사용하지 않았습니다.
이제보니 활성화 함수는 왜 뺐나 싶네요;;

conv2d_transpose도 그렇고 아직 잘 모르는것들을 무리해서 쓰니까 잘 안되는것 같습니다 ㅠㅠ




Comments