使用Inception v3实现图像识别
官方相关模型及文件下载:http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
下载的文件列表:
里面有训练好的pb模型,以及标签特征对应文件
载入相关库:
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt
code:
class NodeLookup(object):
    """Translate integer class ids of the Inception model into label text.

    Two files are combined: a ``.pbtxt`` label map that pairs each integer
    ``target_class`` with a ``target_class_string`` uid (e.g. ``n01494475``),
    and a tab-separated text file that pairs each uid with a human-readable
    name. The resulting ``{class id: name}`` dict is stored in
    ``self.node_lookup``.
    """

    def __init__(self,
                 label_path="inception_model/imagenet_2012_challenge_label_map_proto.pbtxt",
                 uid_path="inception_model/imagenet_synset_to_human_label_map.txt"):
        """Build the lookup table.

        Args:
            label_path: Path of the ``.pbtxt`` file mapping class id -> uid.
            uid_path: Path of the tab-separated file mapping uid -> name.
        """
        self.node_lookup = self.load(label_path, uid_path)

    def load(self, label_path, uid_path):
        """Read both mapping files and join them into ``{class id: name}``.

        Args:
            label_path: Path of the ``.pbtxt`` file mapping class id -> uid.
            uid_path: Path of the tab-separated file mapping uid -> name.

        Returns:
            dict mapping integer class id to the human-readable label.

        Raises:
            KeyError: If a uid from the label map has no entry in the uid file.
        """
        # Context managers make sure both file handles are closed.
        with tf.gfile.GFile(label_path) as label_file:
            node_to_label_uid = self._parse_label_map(label_file.readlines())
        with tf.gfile.GFile(uid_path) as uid_file:
            uid_to_label_human = self._parse_uid_map(uid_file.readlines())

        # Join the two tables: class id -> uid -> human-readable name.
        return {
            node_id: uid_to_label_human[uid]
            for node_id, uid in node_to_label_uid.items()
        }

    @staticmethod
    def _parse_label_map(lines):
        """Extract ``{target_class: uid}`` from the lines of the pbtxt file."""
        node_to_label_uid = {}
        target_class = None
        for raw_line in lines:
            # Strip surrounding whitespace so matching does not depend on how
            # deeply the entry fields are indented in the file.
            line = raw_line.strip()
            if line.startswith("target_class:"):
                # e.g. `target_class: 442`
                target_class = int(line.split(":", 1)[1])
            elif line.startswith("target_class_string:"):
                # e.g. `target_class_string: "n01494475"`; drop the quotes.
                if target_class is None:
                    # No class id seen for this entry; nothing to key on.
                    continue
                uid = line.split(":", 1)[1].strip().strip('"')
                node_to_label_uid[target_class] = uid
                target_class = None
        return node_to_label_uid

    @staticmethod
    def _parse_uid_map(lines):
        """Extract ``{uid: name}`` from tab-separated ``uid<TAB>name`` lines."""
        uid_to_label_human = {}
        for raw_line in lines:
            parse_items = raw_line.strip("\r\n").split("\t")
            if len(parse_items) < 2:
                # Blank or malformed line: skip instead of raising IndexError.
                continue
            uid_to_label_human[parse_items[0]] = parse_items[1]
        return uid_to_label_human

    def id_to_string(self, node_id):
        """Return the label for ``node_id``, or ``''`` when it is unknown."""
        return self.node_lookup.get(node_id, '')
# Load the frozen model and rebuild its graph in the default graph
# (the usual TF1 boilerplate). The .pb file must be opened in binary mode.
with tf.gfile.GFile("inception_model/classify_image_graph_def.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name="")

# Build the id -> label table once; it does not depend on the image, so
# re-reading both label files for every picture would be wasted work.
node_lookup = NodeLookup()

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name("softmax:0")
    for root, dirs, files in os.walk("images/"):
        for file in files:
            image_path = os.path.join(root, file)
            # Read the raw encoded bytes and close the handle right away.
            with tf.gfile.GFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            # The bytes are fed to the graph's JPEG decoder input,
            # so the images are expected to be in JPEG format.
            predict = sess.run(softmax_tensor, {"DecodeJpeg/contents:0": image_data})
            predict = np.squeeze(predict)  # flatten the result to one dimension
            print(predict.shape)
            print(image_path)
            img = Image.open(image_path)
            plt.imshow(img)
            plt.axis("off")
            plt.show()
            # argsort is ascending: take the last three indices, then reverse
            # them so the highest score comes first.
            top_pre = predict.argsort()[-3:][::-1]
            for node_id in top_pre:
                human_string = node_lookup.id_to_string(node_id)
                score = predict[node_id]
                print("%s (score=%.5f)" % (human_string, score))
            print("*************************")
结果:
bug
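1、readline() 与 readlines() 的区别:readline() 只读取文件的一行并返回一个字符串,对这个字符串做 for 遍历得到的是单个字符,split("\t") 之后的列表只有一个元素,于是 parsed_items[1] 抛出下面的 IndexError;这里应使用 readlines(),它返回由所有行组成的列表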
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-21-74081abf60d9> in <module>
74 #对预测结果排序,从小到大排序,取最后几个,[::-1]:对最后的五个概率值取倒叙
75 top_k = predictions.argsort()[-5:][: :-1]
---> 76 node_lookup = NodeLookUp()
77 for node_id in top_k:
78 #获取分类名称
<ipython-input-21-74081abf60d9> in __init__(self)
4 label_lookup_path = "inception_model/imagenet_2012_challenge_label_map_proto.pbtxt"
5 uid_lookup_path = "inception_model/imagenet_synset_to_human_label_map.txt"
----> 6 self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
7
8 #载入数据进行处理
<ipython-input-21-74081abf60d9> in load(self, label_lookup_path, uid_lookup_path)
19 #用键值对存储数据
20 uid = parsed_items[0]
---> 21 human_string = parsed_items[1]
22 #保存id和字符串为对应的映射关系
23 uid_to_human[uid] = human_string
IndexError: list index out of range
2、载入模型时没有以二进制模式(‘rb’)读取:pb 模型文件是二进制数据,按默认的文本模式读取时会被当作 UTF-8 文本解码,从而抛出下面的 UnicodeDecodeError
---------------------------------------------------------------------------
UnicodeDecodeError Traceback (most recent call last)
<ipython-input-27-052cb1c2f17f> in <module>
40 with tf.gfile.GFile("inception_model/classify_image_graph_def.pb") as f:
41 graph_def = tf.GraphDef()
---> 42 graph_def.ParseFromString(f.read())
43 tf.import_graph_def(graph_def, name="")
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xbb in position 1: invalid start byte
‘r’:默认值,表示从文件读取数据。
‘w’:表示要向文件写入数据,并截断以前的内容
‘a’:表示要向文件写入数据,添加到当前内容尾部
‘r+’:表示对文件进行可读写操作(文件必须已存在,不会清空原有内容;‘w+’ 才会先清空文件再进行读写)
‘a+’:表示对文件可进行读写操作(写入的内容添加到当前文件尾部)
‘b’:表示要读写二进制数据
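3、变量未定义(UnboundLocalError):由 tab/缩进等代码格式问题引起。给 node_to_label_uid 赋值的语句必须缩进在 if 代码块内部(写法如下),否则在尚未遇到 target_class_string 行时就会引用还没有赋值的变量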
if line.startswith(" target_class_string:"):
target_class_string = line.split(": ")[1]
node_to_label_uid[target_class] = target_class_string[1:-1]
<ipython-input-28-ecfdf82b86d1> in load(self, label_path, uid_path)
14 if line.startswith(" target_class_string:"):
15 target_class_string = line.split(": ")[1]
---> 16 node_to_label_uid[target_class] = target_class_string
17
18 proto_uid_line = tf.gfile.GFile(uid_path).readlines()
UnboundLocalError: local variable 'target_class_string' referenced before assignment
还没有评论,来说两句吧...