转载请注明链接:http://www.cnblogs.com/SSSR/p/5630534.html
tflearn 官方示例中训练 VGG16 的项目：https://github.com/tflearn/tflearn/blob/master/examples/images/vgg_network.py （尚未测试成功）。
下面的项目使用他人已经训练好的模型进行预测，测试效果非常好。
GitHub：https://github.com/ry/tensorflow-vgg16 。该项目已测试成功，效果非常好。
如果在 Ubuntu 终端（terminal）中运行时出现问题，可以参照下面的代码解决（主要解决 skimage 读取图片的问题）。
# -*- coding: utf-8 -*-
"""Classify images with the pre-trained VGG16 graph from ry/tensorflow-vgg16.

The model file is a serialized TensorFlow ``GraphDef``. Its ``images`` input
is remapped to a placeholder of shape (None, 224, 224, 3) and the class
scores are read from the ``import/prob:0`` tensor. ``synset.txt`` holds one
label per line, indexed like the score vector.

Runs on Python 2.7 and 3. The TensorFlow calls use the TF1 graph API
(``GraphDef`` / ``placeholder`` / ``Session``); when ``tf.compat.v1`` exists
(TensorFlow 1.13+ and 2.x) that namespace is used instead of top-level ``tf``.
"""
import os

import numpy as np

MODEL_DIR = "/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16"
SYNSET_PATH = os.path.join(MODEL_DIR, "synset.txt")
MODEL_PATH = os.path.join(MODEL_DIR, "vgg16.tfmodel")
IMAGE_DIR = "pic"  # relative to the current working directory
IMAGE_NAMES = ["cat.jpg", "airplane.jpg", "zebra.jpg", "pig.jpg", "12.jpg", "23.jpg"]
INPUT_SIZE = 224  # side length of the placeholder fed to the graph


def load_synset(path=SYNSET_PATH):
    """Return the class labels stored in *path*, one stripped line each."""
    with open(path) as f:
        return [line.strip() for line in f]


def load_image(path, size=INPUT_SIZE):
    """Load *path* as a float RGB array of shape (size, size, 3) in [0, 1].

    The image is center-cropped to a square on its short edge and then
    resized to ``size`` x ``size``. Grayscale images are replicated into three
    channels and an RGBA alpha channel is dropped, so the result always has
    exactly three channels (the original crashed on such images when they
    were reshaped to (1, 224, 224, 3)).
    """
    # Imported lazily so the label helpers are usable without scikit-image.
    import skimage
    import skimage.io
    import skimage.transform

    img = skimage.io.imread(path)
    if img.ndim == 2:  # grayscale -> replicate into R, G, B
        img = np.stack((img, img, img), axis=-1)
    elif img.shape[-1] == 4:  # RGBA -> drop alpha
        img = img[..., :3]
    # Scale by the dtype's range (uint8 -> /255, uint16 -> /65535) instead of
    # assuming uint8 input as a bare "/ 255.0" does.
    img = skimage.img_as_float(img)

    # Center-crop to a square before resizing so the aspect ratio is kept.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
    return skimage.transform.resize(crop_img, (size, size))


def print_prob(prob, labels=None):
    """Print the shape of *prob* and return the label of its top-1 class.

    Args:
        prob: 1-D sequence of per-class scores, indexed like *labels*.
        labels: sequence of class labels; read from SYNSET_PATH when None.

    Returns:
        The label whose score in *prob* is highest.
    """
    if labels is None:
        labels = load_synset()
    prob = np.asarray(prob)
    print("prob shape {}".format(prob.shape))
    return labels[int(np.argmax(prob))]


def _tf_v1():
    """Return the TF1 graph API: ``tf.compat.v1`` when present, else ``tf``."""
    import tensorflow as tf

    v1 = getattr(getattr(tf, "compat", None), "v1", None)
    if v1 is None:
        return tf  # old TensorFlow: the TF1 API lives at the top level
    if hasattr(v1, "disable_eager_execution"):
        # placeholder/Session.run need graph mode, which TF2 turns off by default.
        v1.disable_eager_execution()
    return v1


def classify_images(image_paths, model_path=MODEL_PATH, labels=None):
    """Run VGG16 on every image in *image_paths* and return the top-1 labels.

    Args:
        image_paths: iterable of image file paths.
        model_path: serialized GraphDef of the VGG16 model.
        labels: class labels indexed like the model output; read from
            SYNSET_PATH when None.

    Returns:
        List of top-1 label strings, in the order of *image_paths*.
    """
    tf = _tf_v1()
    if labels is None:
        labels = load_synset()

    print(u'加载模型文件')
    with open(model_path, mode='rb') as f:
        file_content = f.read()

    print(u'创建图')
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(file_content)
    images = tf.placeholder("float", [None, INPUT_SIZE, INPUT_SIZE, 3])
    tf.import_graph_def(graph_def, input_map={"images": images})
    print("graph loaded from disk")
    # Loop-invariant: look the output tensor up once, not once per image.
    prob_tensor = tf.get_default_graph().get_tensor_by_name("import/prob:0")

    print(u'加载图片')
    print(u'进入sess执行')
    result = []
    # The context manager closes the session even if an image fails to load.
    with tf.Session() as sess:
        # Hoisted out of the per-image loop. Presumably a no-op for an imported
        # frozen graph (TODO confirm); kept from the original script.
        sess.run(tf.initialize_all_variables())
        print("variables initialized")
        for path in image_paths:
            batch = load_image(path)[np.newaxis]  # (1, size, size, 3)
            print(u'开始执行')
            prob = sess.run(prob_tensor, feed_dict={images: batch})
            print(u'输出结果')
            result.append(print_prob(prob[0], labels))
    return result


def main():
    """Classify the sample images under IMAGE_DIR and print the label list."""
    paths = [os.path.join(IMAGE_DIR, name) for name in IMAGE_NAMES]
    print(classify_images(paths))


if __name__ == "__main__":
    main()
时间: 2024-11-08 23:44:33