NeuroWhAI의 잡블로그

[Python] Keras로 MNIST 학습하고 직접 그린 이미지 추측시켜보기 본문

개발 및 공부/라이브러리&프레임워크

[Python] Keras로 MNIST 학습하고 직접 그린 이미지 추측시켜보기

NeuroWhAI 2018. 7. 30. 20:41 ...


케라스로 MNIST 데이터를 학습시키고 외부 이미지 하나를 불러와서 무슨 숫자인지 출력해보는 예제입니다.

기본적으로 CNN이고 배치정규화, 드롭아웃을 추가로 사용했습니다.


이미지는 무조건 input.png라는 이름으로 작업 경로에 존재해야 하고 28x28 크기여야 합니다.

배경은 검은색, 숫자는 흰색으로 그리세요.

input.png가 없으면 구글 드라이브에서 테스트 이미지를 다운로드하는데 번거로우니 그냥 직접 그리세요.


미리 학습된 가중치 데이터를 원하시면 아래 링크에서 받으시고 작업 경로에 넣어주세요.

cnn_model_mnist.h5


코드:

# !pip install -U -q PyDrive import sys import os.path import numpy as np import keras from keras import layers, models, datasets, backend from keras.utils import np_utils import matplotlib.pyplot as plt from PIL import Image def get_file_from_drive(file_id): """ 구글 드라이브에서 file_id에 해당하는 파일을 가져와 읽습니다. Colab에서 돌리기 위해 필요합니다. """ # Install the PyDrive wrapper & import libraries. # This only needs to be done once per notebook. from pydrive.auth import GoogleAuth from pydrive.drive import GoogleDrive from google.colab import auth from oauth2client.client import GoogleCredentials # Authenticate and create the PyDrive client. # This only needs to be done once per notebook. auth.authenticate_user() gauth = GoogleAuth() gauth.credentials = GoogleCredentials.get_application_default() drive = GoogleDrive(gauth) # Download a file based on its file ID. # # A file ID looks like: laggVyWshwcyP6kEI-y_W3P8D26sz downloaded = drive.CreateFile({'id': file_id}) return downloaded class CNN(models.Sequential): def __init__(self, input_shape, num_classes): super().__init__() self.add(layers.Conv2D(32, kernel_size=(3, 3), input_shape=input_shape)) self.add(layers.BatchNormalization(axis=1)) self.add(layers.Activation('relu')) self.add(layers.MaxPooling2D(pool_size=(2, 2))) self.add(layers.Conv2D(64, kernel_size=(3, 3))) self.add(layers.BatchNormalization(axis=1)) self.add(layers.Activation('relu')) self.add(layers.MaxPooling2D(pool_size=(2, 2))) self.add(layers.Dropout(0.25)) self.add(layers.Flatten()) self.add(layers.Dense(128)) self.add(layers.BatchNormalization(axis=1)) self.add(layers.Activation('relu')) self.add(layers.Dropout(0.25)) self.add(layers.Dense(num_classes, activation='softmax')) self.compile(loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy']) class MnistData(): def __init__(self): (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data() y_train = np_utils.to_categorical(y_train) y_test = np_utils.to_categorical(y_test) img_rows, img_cols = x_train.shape[1:] if backend.image_data_format() == 'channels_first': x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) input_shape = (1, img_rows, img_cols) else: x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) input_shape = (img_rows, img_cols, 1) x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255.0 x_test /= 255.0 self.input_shape = input_shape self.num_classes = 10 self.x_train, self.y_train = x_train, y_train self.x_test, self.y_test = x_test, y_test def plot_loss(history): plt.plot(history.history['loss']) plt.plot(history.history['val_loss']) plt.title('Model Loss') plt.ylabel('Loss') plt.xlabel('Epoch') plt.legend(['Train', 'Test'], loc=0) def plot_acc(history): plt.plot(history.history['acc']) plt.plot(history.history['val_acc']) plt.title('Model accuracy') plt.ylabel('Accuracy') plt.xlabel('Epoch') plt.legend(['Train', 'Test'], loc=0) def train(): batch_size = 100 epochs = 15 data = MnistData() model = CNN(data.input_shape, data.num_classes) history = model.fit(data.x_train, data.y_train, epochs=epochs, batch_size=batch_size, validation_split=0.2, verbose=2) performance_test = model.evaluate(data.x_test, data.y_test, batch_size=100, verbose=0) print('\nTest Result ->', performance_test) model.save_weights('cnn_model_mnist.h5') plot_acc(history) plt.show() plot_loss(history) plt.show() def predict(): img = Image.open("input.png") img_data = np.array(img) plt.imshow(img_data) plt.show() if backend.image_data_format() == 'channels_first': input_shape = (1, 28, 28) img_data = img_data.transpose(2, 0, 1)[1].reshape(1, 1, 28, 28) else: input_shape = (28, 28, 1) img_data = img_data[:, :, 1].reshape(1, 28, 28, 1) img_data = img_data.astype('float32') / 255.0 model = CNN(input_shape, 10) model.load_weights('cnn_model_mnist.h5') output = model.predict(img_data) print("Answer :", np.argmax(output)) if __name__ == '__main__': choice = input("train or predict") if choice == "train": train() elif choice == "predict": if os.path.exists("input.png") == False: print("Download a input image...") get_file_from_drive("1NaBmL3T3EPB0uByFagXVnhYRHZ0Hque8").GetContentFile("input.png") predict()


매번 학습만 주구장창 시키고 정작 직접 써보질 않아서 가장 만만한 MNIST로 한번 해봤습니다.

시간이 더 있었으면 C#으로 그림판 비스무리하게 만들어서 숫자 그리면 바로 예측을 돌릴 수 있게 할텐데 그건 뭐 주말에 심심하면...


그나저나 이렇게 대충 만들어도 학습, 테스트 정확도 99%는 그냥 찍네요 ㄷㄷ;;


Tag
,

0 Comments
댓글쓰기 폼