NeuroWhAI의 잡블로그

[Python] Keras DCGAN으로 포켓몬 이미지 생성 (+소스코드) 본문

개발 및 공부/라이브러리&프레임워크

[Python] Keras DCGAN으로 포켓몬 이미지 생성 (+소스코드)

NeuroWhAI 2018. 10. 31. 20:18 ...

저번에 따로 해봤는데 잘 안되었던 이유가 모델이 이상한건지 이미지가 커서 그랬던건지 모르겠어서
일단 모델을 새로 설계하고 32x32x3(RGB)의 작은 사이즈로 시도했습니다.
그랬더니 학습이 잘 진행되더군요.

아래는 대략 800장의 이미지를 900~1000 에포크 동안 학습한 결과물입니다.


코드는 아래와 같습니다.
import os.path
import numpy as np
from keras.models import *
from keras.layers import *
from keras.optimizers import *
import keras.backend as K
import matplotlib.pyplot as plt

K.set_image_data_format('channels_last')

class Gan:
  def __init__(self, img_data):
    img_size = img_data.shape[1]
    channel = img_data.shape[3] if len(img_data.shape) >= 4 else 1
    
    self.img_data = img_data
    self.input_shape = (img_size, img_size, channel)
    
    self.img_rows = img_size
    self.img_cols = img_size
    self.channel = channel
    self.noise_size = 100
    
    # Create D and G.
    self.create_d()
    self.create_g()
    
    # Build model to train D.
    optimizer = Adam(lr=0.0008)
    self.D.compile(loss='binary_crossentropy', optimizer=optimizer)
    
    # Build model to train G.
    optimizer = Adam(lr=0.0004)
    self.D.trainable = False
    self.AM = Sequential()
    self.AM.add(self.G)
    self.AM.add(self.D)
    self.AM.compile(loss='binary_crossentropy', optimizer=optimizer)
  
  def create_d(self):
    self.D = Sequential()
    depth = 64
    dropout = 0.4
    self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=self.input_shape,
                      padding='same'))
    self.D.add(LeakyReLU(alpha=0.2))
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*2, 5, strides=2, padding='same'))
    self.D.add(LeakyReLU(alpha=0.2))
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))
    self.D.add(LeakyReLU(alpha=0.2))
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*8, 5, strides=1, padding='same'))
    self.D.add(LeakyReLU(alpha=0.2))
    self.D.add(Dropout(dropout))
    self.D.add(Flatten())
    self.D.add(Dense(1))
    self.D.add(Activation('sigmoid'))
    self.D.summary()
    return self.D
  
  def create_g(self):
    self.G = Sequential()
    dropout = 0.4
    depth = 64+64+64+64
    dim = 8
    self.G.add(Dense(dim*dim*depth, input_dim=self.noise_size))
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(Reshape((dim, dim, depth)))
    self.G.add(Dropout(dropout))
    self.G.add(UpSampling2D())
    self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(UpSampling2D())
    self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(Conv2DTranspose(self.channel, 5, padding='same'))
    self.G.add(Activation('sigmoid'))
    self.G.summary()
    return self.G
  
  def train(self, batch_size=100):
    # Pick image data randomly.
    images_train = self.img_data[np.random.randint(0, self.img_data.shape[0], size=batch_size), :, :, :]
    
    # Generate images from noise.
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, self.noise_size])
    images_fake = self.G.predict(noise)
    
    # Train D.
    x = np.concatenate((images_train, images_fake))
    y = np.ones([2*batch_size, 1])
    y[batch_size:, :] = 0
    self.D.trainable = True
    d_loss = self.D.train_on_batch(x, y)
    
    # Train G.
    y = np.ones([batch_size, 1])
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, self.noise_size])
    self.D.trainable = False
    a_loss = self.AM.train_on_batch(noise, y)
    
    return d_loss, a_loss, images_fake
  
  def save(self):
    self.G.save_weights('gan_g_weights.h5')
    self.D.save_weights('gan_d_weights.h5')
    
  def load(self):
    if os.path.isfile('gan_g_weights.h5'):
      self.G.load_weights('gan_g_weights.h5')
      print("Load G from file.")
    if os.path.isfile('gan_d_weights.h5'):
      self.D.load_weights('gan_d_weights.h5')
      print("Load D from file.")

class PokemonData():
  def __init__(self):
    img_data_list = []
    images = os.listdir("pokemon_rgb")
        
    for path in images:
      img = Image.open("pokemon_rgb/" + path)
      img_data_list.append([np.array(img).astype('float32')])
    
    self.x_train = np.vstack(img_data_list) / 255.0
    print(self.x_train.shape)
  
# Load dataset.
dataset = PokemonData()
x_train = dataset.x_train

# Init network
gan = Gan(x_train)
gan.load()

# Some parameters.
epochs = 500
sample_size = 10
batch_size = 100
train_per_epoch = x_train.shape[0] // batch_size

for epoch in range(0, epochs):
  total_d_loss = 0.0
  total_a_loss = 0.0
  imgs = None
  
  for batch in range(0, train_per_epoch):
    d_loss, a_loss, t_imgs = gan.train(batch_size)
    total_d_loss += d_loss
    total_a_loss += a_loss
    if imgs is None:
      imgs = t_imgs

  if epoch % 20 == 0 or epoch == epochs - 1:
    total_d_loss /= train_per_epoch
    total_a_loss /= train_per_epoch

    print("Epoch: {}, D Loss: {}, AM Loss: {}"
          .format(epoch, total_d_loss, total_a_loss))
  
    # Show generated images.
    fig, ax = plt.subplots(1, sample_size, figsize=(sample_size, 1))
    for i in range(0, sample_size):
      ax[i].set_axis_off()
      ax[i].imshow(imgs[i].reshape((gan.img_rows, gan.img_cols, gan.channel)),
                  interpolation='nearest');
    plt.show()
    plt.close(fig);
    
    # Save weights
    gan.save()
보면 사실 어제 올린 MNIST 예제랑 코드가 별반 다르지 않습니다.
그래도 동작한다는 것은 그만큼 딥러닝이 범용성에 강하다는 것이겠죠.

학습된 모델의 가중치 파일도 업로드하였으니 바로 시험해보고 싶으신 분들은 아래 링크에서 받으시면 됩니다.

데이터셋의 이미지 크기를 32x32로 낮추는 작업은 아래 코드를 이용했습니다.
import os
from PIL import Image

images = os.listdir("pokemon")

for i, name in enumerate(images):
  png = Image.open("pokemon/" + name)
  png.load()

  background = Image.new("RGB", png.size, (0, 0, 0))
  background.paste(png, mask=png.split()[3]) # 3 is the alpha channel

  background.thumbnail((32, 32), Image.ANTIALIAS)
  background.save("pokemon_rgb/" + str(i) + ".jpg", 'JPEG', quality=80)
아, 그리고 학습에 사용한 데이터셋은 Kaggle에 올려진 것을 사용했습니다.

코드 테스트는 Colab에서 했습니다.
역시 GPU 환경에서 돌리니 빨리 끝나더라구요. 1시간 정도? 댕꿀~

다음 목표는 64x64 이미지를 생성하는 것입니다.
너무 작아서 뭐가 뭔지 모르겠거든요 ㅎㅎ...


Tag
, ,

0 Comments
댓글쓰기 폼