NeuroWhAI의 잡블로그

[Keras] U-Net으로 흑백 이미지를 컬러로 바꾸기 본문

개발 및 공부/라이브러리&프레임워크

[Keras] U-Net으로 흑백 이미지를 컬러로 바꾸기

NeuroWhAI 2018. 3. 29. 19:57 ...


※ 이 글은 '코딩셰프의 3분 딥러닝 케라스맛'이라는 책을 보고 실습한걸 기록한 글입니다.


왜 UNet인진 모르겠는데 신경망 구조를 보니까 U처럼 생겨서 UNet인가 싶네요 ㅋㅋ


오토인코더와 비슷한 구조인데 인코딩 과정에서 나온 각 층의 출력을 디코딩 과정의 각 층에서 입력으로 사용하고 있는게 차이점입니다.
뭐 이론적인건 잘 모르겠고 책에선 이렇게 함으로서 이미지 복원력이 더 뛰어나진다고 하네요.

이번에도 코드는 책의 코드를 좀 간소화해서 다를겁니다.

코드:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from keras import layers, models, optimizers
from keras import datasets
from keras import backend
import matplotlib.pyplot as plt
import numpy as np
 
class UNET(models.Model):
  def __init__(self, org_shape, n_ch):
    channel_index = 3 if backend.image_data_format() == 'channels_last' else 1
    
    def conv(x, n_f, mp_flag=True):
      x = layers.MaxPooling2D((22), padding='same')(x) if mp_flag else x
      x = layers.Conv2D(n_f, (33), padding='same')(x)
      x = layers.BatchNormalization()(x)
      x = layers.Activation('tanh')(x)
      x = layers.Conv2D(n_f, (33), padding='same')(x)
      x = layers.BatchNormalization()(x)
      x = layers.Activation('tanh')(x)
      return x
    
    def deconv_unet(x, e, n_f):
      x = layers.UpSampling2D((22))(x)
      x = layers.Concatenate(axis=channel_index)([x, e])
      x = layers.Conv2D(n_f, (33), padding='same')(x)
      x = layers.BatchNormalization()(x)
      x = layers.Activation('tanh')(x)
      x = layers.Conv2D(n_f, (33), padding='same')(x)
      x = layers.BatchNormalization()(x)
      x = layers.Activation('tanh')(x)
      return x
    
    original = layers.Input(shape=org_shape)
    
    c1 = conv(original, 16, False)
    c2 = conv(c1, 32)
    encoded = conv(c2, 64)
    
    x = deconv_unet(encoded, c2, 32)
    x = deconv_unet(x, c1, 16)
    decoded = layers.Conv2D(n_ch, (33), activation='sigmoid',
                           padding='same')(x)
    
    super().__init__(original, decoded)
    self.compile(optimizer='adadelta', loss='mse')
    
class DATA():
  def __init__(self):
    (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
    
    if backend.image_data_format() == 'channels_first':
      n_ch, img_rows, img_cols = x_train.shape[1:]
      input_shape = (1, img_rows, img_cols)
    else:
      img_rows, img_cols, n_ch = x_train.shape[1:]
      input_shape = (img_rows, img_cols, 1)
      
    x_train = x_train.astype('float32'/ 255.0
    x_test = x_test.astype('float32'/ 255.0
    
    def RGB2Gray(img, fmt):
      if fmt == 'channels_first':
        R = img[:, 0:1]
        G = img[:, 1:2]
        B = img[:, 2:3]
      else:
        R = img[..., 0:1]
        G = img[..., 1:2]
        B = img[..., 2:3]
      return 0.299 * R + 0.587 * G + 0.114 * B
    
    x_train_in = RGB2Gray(x_train, backend.image_data_format())
    x_test_in = RGB2Gray(x_test, backend.image_data_format())
    
    self.input_shape = input_shape
    self.x_train_in, self.x_train_out = x_train_in, x_train
    self.x_test_in, self.x_test_out = x_test_in, x_test
    self.n_ch = n_ch
    
def plot_loss(history):
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train''Test'], loc=0)
 
def plot_acc(history):
    plt.plot(history.history['acc'])
    plt.plot(history.history['val_acc'])
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train''Test'], loc=0)
    
def show_images(in_imgs, out_imgs, unet, sample_size=10):
  x_test_in = in_imgs[:sample_size]
  x_test_out = out_imgs[:sample_size]
  decoded_imgs = unet.predict(x_test_in, batch_size=sample_size)
  
  print("Before")
  print("x_test_in:", x_test_in.shape)
  print("decoded_imgs:", decoded_imgs.shape)
  print("x_test_out:", x_test_out.shape)
  
  if backend.image_data_format() == 'channels_first':
    x_test_out = x_test_out.swapaxes(13).swapaxes(12)
    decoded_imgs = decoded_imgs.swapaxes(13).swapaxes(12)
    
    x_test_in = x_test_in[:, 0, ...]
  else:
    x_test_in = x_test_in[..., 0]
  
  print("After")
  print("x_test_in:", x_test_in.shape)
  print("decoded_imgs:", decoded_imgs.shape)
  print("x_test_out:", x_test_out.shape)
    
  plt.figure(figsize=(206))
  for i in range(sample_size):
    ax = plt.subplot(3, sample_size, i + 1)
    plt.imshow(x_test_in[i])
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    
    ax = plt.subplot(3, sample_size, i + 1 + sample_size)
    plt.imshow(decoded_imgs[i])
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    
    ax = plt.subplot(3, sample_size, i + 1 + sample_size * 2)
    plt.imshow(x_test_out[i])
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    
  plt.show()
  
def main(epochs=10, batch_size=128):
  data = DATA()
  unet = UNET(data.input_shape, data.n_ch)
  
  history = unet.fit(data.x_train_in, data.x_train_out,
                    epochs=epochs,
                    batch_size=batch_size,
                    shuffle=True,
                    validation_split=0.2,
                    verbose=2)
  
  plot_loss(history)
  show_images(data.x_test_in, data.x_test_out, unet)
  
if __name__ == '__main__':
  main()
cs

결과:
Train on 40000 samples, validate on 10000 samples
Epoch 1/10
- 35s - loss: 0.0133 - val_loss: 0.0104
Epoch 2/10
- 31s - loss: 0.0088 - val_loss: 0.0098
Epoch 3/10
- 31s - loss: 0.0079 - val_loss: 0.0076
Epoch 4/10
- 31s - loss: 0.0073 - val_loss: 0.0076
Epoch 5/10
- 31s - loss: 0.0070 - val_loss: 0.0159
Epoch 6/10
- 31s - loss: 0.0068 - val_loss: 0.0073
Epoch 7/10
- 31s - loss: 0.0067 - val_loss: 0.0071
Epoch 8/10
- 31s - loss: 0.0066 - val_loss: 0.0070
Epoch 9/10
- 31s - loss: 0.0065 - val_loss: 0.0066
Epoch 10/10
- 31s - loss: 0.0064 - val_loss: 0.0081
Before
x_test_in: (10, 32, 32, 1)
decoded_imgs: (10, 32, 32, 3)
x_test_out: (10, 32, 32, 3)
After
x_test_in: (10, 32, 32)
decoded_imgs: (10, 32, 32, 3)
x_test_out: (10, 32, 32, 3)




마지막 사진의 첫번째 줄은 복원할 흑백 이미지(인데 왜 색반전이 되어있지;;)이고
두번째 줄이 복원된 컬러 이미지, 세번째는 실제 컬러 이미지입니다.
뭐 보시면 아시겠지만 결과가 그렇게 마음에 들진 않네요.
컬러 이미지가 됬다기 보다는 그냥 소피아 필터 씌운 느낌이 되버렸습니다.


아래는 코드 설명.

신경망을 만드는 부분은 특이한게 없습니다.
새로운 레이어도 없지만 중요한건 구조가 되겠네요.
디코딩할때 인코딩 층의 중간 결과(c1, c2)를 사용하고 있습니다.

학습 데이터를 만드는 부분을 보시면 cifar10은 컬러 이미지 뿐이므로 컬러 이미지를 흑백으로 바꾸는게 반을 차지합니다.

show_images 함수를 보시면 swapaxes라는걸 사용하고 있는데
말 그대로 두 축을 swap하는 함수입니다.
channels_first이므로 (Batch, Channel, Width, Height)일텐데 이거를
먼저 1, 3 축을 바꾸면 (Batch, Height, Width, Channel)이 되고
이어서 1, 2 축을 바꾸면 (Batch, Width, Height, Channel)이 됩니다.
채널을 마지막으로 옮겨주는 작업인거죠.

끝!


코드엔 설명할게 별로 없네요.
이번엔 케라스의 새로운 요소보다는 U-Net이라는 신경망 구조에 집중한 예시라서 그런거 같습니다.
결과가 만족스럽지 않으므로 좀더 검색해서 구조를 바꿔봐야겠습니다.


검색하니까 이렇게 친절한 구조 설명도 나오던데 나중에 이거대로 한번 만들어 보겠습니다 ㅋㅋ



Tag
,

0 Comments
댓글쓰기 폼