NeuroWhAI의 잡블로그
[Python] Keras로 MNIST 학습하고 직접 그린 이미지 추측시켜보기 본문
케라스로 MNIST 데이터를 학습시키고 외부 이미지 하나를 불러와서 무슨 숫자인지 출력해보는 예제입니다.
기본적으로 CNN이고 배치정규화, 드롭아웃을 추가로 사용했습니다.
이미지는 무조건 input.png라는 이름으로 작업 경로에 존재해야 하고 28x28 크기여야 합니다.
배경은 검은색, 숫자는 흰색으로 그리세요.
input.png가 없으면 구글 드라이브에서 테스트 이미지를 다운로드하는데 번거로우니 그냥 직접 그리세요.
미리 학습된 가중치 데이터를 원하시면 아래 링크에서 받으시고 작업 경로에 넣어주세요.
코드:
# !pip install -U -q PyDrive
import sys
import os.path
import numpy as np
import keras
from keras import layers, models, datasets, backend
from keras.utils import np_utils
import matplotlib.pyplot as plt
from PIL import Image
def get_file_from_drive(file_id):
"""
구글 드라이브에서 file_id에 해당하는 파일을 가져와 읽습니다.
Colab에서 돌리기 위해 필요합니다.
"""
# Install the PyDrive wrapper & import libraries.
# This only needs to be done once per notebook.
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
# Download a file based on its file ID.
#
# A file ID looks like: laggVyWshwcyP6kEI-y_W3P8D26sz
downloaded = drive.CreateFile({'id': file_id})
return downloaded
class CNN(models.Sequential):
def __init__(self, input_shape, num_classes):
super().__init__()
self.add(layers.Conv2D(32, kernel_size=(3, 3), input_shape=input_shape))
self.add(layers.BatchNormalization(axis=1))
self.add(layers.Activation('relu'))
self.add(layers.MaxPooling2D(pool_size=(2, 2)))
self.add(layers.Conv2D(64, kernel_size=(3, 3)))
self.add(layers.BatchNormalization(axis=1))
self.add(layers.Activation('relu'))
self.add(layers.MaxPooling2D(pool_size=(2, 2)))
self.add(layers.Dropout(0.25))
self.add(layers.Flatten())
self.add(layers.Dense(128))
self.add(layers.BatchNormalization(axis=1))
self.add(layers.Activation('relu'))
self.add(layers.Dropout(0.25))
self.add(layers.Dense(num_classes, activation='softmax'))
self.compile(loss=keras.losses.categorical_crossentropy,
optimizer='adam', metrics=['accuracy'])
class MnistData():
def __init__(self):
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
img_rows, img_cols = x_train.shape[1:]
if backend.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255.0
x_test /= 255.0
self.input_shape = input_shape
self.num_classes = 10
self.x_train, self.y_train = x_train, y_train
self.x_test, self.y_test = x_test, y_test
def plot_loss(history):
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc=0)
def plot_acc(history):
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc=0)
def train():
batch_size = 100
epochs = 15
data = MnistData()
model = CNN(data.input_shape, data.num_classes)
history = model.fit(data.x_train, data.y_train, epochs=epochs,
batch_size=batch_size, validation_split=0.2, verbose=2)
performance_test = model.evaluate(data.x_test, data.y_test, batch_size=100,
verbose=0)
print('\nTest Result ->', performance_test)
model.save_weights('cnn_model_mnist.h5')
plot_acc(history)
plt.show()
plot_loss(history)
plt.show()
def predict():
img = Image.open("input.png")
img_data = np.array(img)
plt.imshow(img_data)
plt.show()
if backend.image_data_format() == 'channels_first':
input_shape = (1, 28, 28)
img_data = img_data.transpose(2, 0, 1)[1].reshape(1, 1, 28, 28)
else:
input_shape = (28, 28, 1)
img_data = img_data[:, :, 1].reshape(1, 28, 28, 1)
img_data = img_data.astype('float32') / 255.0
model = CNN(input_shape, 10)
model.load_weights('cnn_model_mnist.h5')
output = model.predict(img_data)
print("Answer :", np.argmax(output))
if __name__ == '__main__':
choice = input("train or predict")
if choice == "train":
train()
elif choice == "predict":
if os.path.exists("input.png") == False:
print("Download a input image...")
get_file_from_drive("1NaBmL3T3EPB0uByFagXVnhYRHZ0Hque8").GetContentFile("input.png")
predict()
매번 학습만 주구장창 시키고 정작 직접 써보질 않아서 가장 만만한 MNIST로 한번 해봤습니다.
시간이 더 있었으면 C#으로 그림판 비스무리하게 만들어서 숫자 그리면 바로 예측을 돌릴 수 있게 할텐데 그건 뭐 주말에 심심하면...
그나저나 이렇게 대충 만들어도 학습, 테스트 정확도 99%는 그냥 찍네요 ㄷㄷ;;
'개발 및 공부 > 라이브러리&프레임워크' 카테고리의 다른 글
[Rust] Rocket 사용해서 20줄로 정적 파일 서버 만들기 (0) | 2018.10.04 |
---|---|
[Rust] Rocket으로 웹 서버 만들어서 Heroku에 올리기 (0) | 2018.10.03 |
VRChat API를 사용해봤습니다. (0) | 2018.06.27 |
[C#] Selenium을 이용해 YouTube 추천 동영상 파싱하기 (0) | 2018.05.07 |
[Keras] 레이어 직접 만들기 (0) | 2018.04.08 |
Comments