NeuroWhAI의 잡블로그

[TensorFlow] Google의 Inception 모델로 꽃 분류하기 본문

개발 및 공부/라이브러리&프레임워크

[TensorFlow] Google의 Inception 모델로 꽃 분류하기

NeuroWhAI 2018. 2. 15. 16:14


※ 이 글은 '골빈해커의 3분 딥러닝 텐서플로맛'이라는 책을 보고 실습한걸 기록한 글입니다.


이론적인 설명은 저도 여기서 공부하면 될것 같고 그냥 책에 나온대로 따라만 해봤습니다.

이번엔 이미 만들어진 모델을 사용하는거라서 아래의 단계로 진행되었습니다.

  1. 학습할 꽃 데이터 다운로드
  2. Inception 모델을 학습시키는 스크립트 다운로드
  3. 꽃 데이터로 모델 학습
  4. 학습된 모델을 불러와서 사용

이 글에는 4번 단계만 적었습니다.


모델을 불러와서 사용하는 코드:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#-*- coding: utf-8 -*-
 
import sys
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
 
tf.app.flags.DEFINE_string("output_graph",
    "./workspace/flowers_graph.pb",
    "학습된 신경망이 저장된 위치")
tf.app.flags.DEFINE_string("output_labels",
    "./workspace/flowers_labels.txt",
    "학습할 레이블 데이터 파일")
tf.app.flags.DEFINE_string("show_image",
    True,
    "이미지 추론 후 이미지를 보여줍니다.")
 
FLAGS = tf.app.flags.FLAGS
 
def main(_):
    labels = [line.rstrip() for line in tf.gfile.GFile(
        FLAGS.output_labels
    )]
 
    with tf.gfile.FastGFile(FLAGS.output_graph, 'rb') as fp:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fp.read())
        tf.import_graph_def(graph_def, name='')
 
    with tf.Session() as sess:
        logits = sess.graph.get_tensor_by_name('final_result:0')
        image = tf.gfile.FastGFile(sys.argv[1], 'rb').read()
        prediction = sess.run(logits, {'DecodeJpeg/contents:0': image})
 
        print("=== 예측 결과 ===")
        for i in range(len(labels)):
            name = labels[i]
            score = prediction[0][i]
            print('%s (%.2f%%)' % (name, score * 100))
        top_result = int(np.argmax(prediction[0]))
        name = labels[top_result]
        score = prediction[0][top_result]
        print('> %s (%.2f%%)' % (name, score * 100))
 
        if FLAGS.show_image:
            img = mpimg.imread(sys.argv[1])
            plt.imshow(img)
            plt.show()
 
if __name__ == '__main__':
    tf.app.run()
cs

임의로 튤립 이미지를 하나 지정해서 실행한 결과:
=== 예측 결과 ===
tulips (91.31%)
sunflowers (3.88%)
roses (4.31%)
daisy (0.09%)
dandelion (0.41%)
> tulips (91.31%)



이번 글에선 딱히 적을게 없네요.
똑똑한 사람들이 만든 이런 모델도 정확도가 99% 막 이러는게 아닌걸 보면 확실히 실세계의 문제를 해결하는건 정말 어려운 일인것 같습니다.





Comments