tensorflow中关于vgg16的项目

转载请注明链接:http://www.cnblogs.com/SSSR/p/5630534.html

tflearn中的例子训练vgg16项目:https://github.com/tflearn/tflearn/blob/master/examples/images/vgg_network.py 尚未测试成功。

下面的项目是使用别人已经训练好的模型进行预测,测试效果非常好。

github:https://github.com/ry/tensorflow-vgg16 此项目已经测试成功,效果非常好,

如果在Ubuntu中的terminal中运行出现问题,可以参照以下部分解决(解决skimage读取图片的问题)。

#coding:utf-8

import skimage
import skimage.io
import skimage.transform
a=skimage.io.imread(‘cat.jpg‘)
import PIL
import numpy as np
import tensorflow as tf
synset = [l.strip() for l in open(‘/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/synset.txt‘).readlines()]

def load_image(path):
  # load image
  img = skimage.io.imread(path)
  #img1=PIL.Image.open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/pic/pig.jpg")
  #img=np.array(PIL.Image.open(path))
  #imgx=np.array(img)
  #print type(imgx),imgx.shape
  img = img/ 255.0
  assert (0 <= img).all() and (img <= 1.0).all()
  #print "Original Image Shape: ", img.shape
  # we crop image from center
  short_edge = min(img.shape[:2])
  yy = int((img.shape[0] - short_edge) / 2)
  xx = int((img.shape[1] - short_edge) / 2)
  crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
  # resize to 224, 224
  resized_img = skimage.transform.resize(crop_img, (224, 224))
  return resized_img

# returns the top1 string
def print_prob(prob):
  #print prob
  print "prob shape", prob.shape
  pred = np.argsort(prob)[::-1]
  # Get top1 label
  top1 = synset[pred[0]]
  #print "Top1: ", top1
  # Get top5 label
  top5 = [synset[pred[i]] for i in range(5)]
  #print "Top5: ", top5
  return top1

print u‘加载模型文件‘
with open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/vgg16.tfmodel", mode=‘rb‘) as f:
  fileContent = f.read()

print u‘创建图‘
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

images = tf.placeholder("float", [None, 224, 224, 3])

tf.import_graph_def(graph_def, input_map={ "images": images })
print "graph loaded from disk"

graph = tf.get_default_graph()
print u‘加载图片‘
#img=np.array(PIL.Image.open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/pic/pig.jpg"))
#cat = load_image(path)
print u‘进入sess执行‘

sess=tf.Session()
result=[]
for i in [‘cat.jpg‘,‘airplane.jpg‘,‘zebra.jpg‘,‘pig.jpg‘,‘12.jpg‘,‘23.jpg‘]:
  img=load_image(‘pic/‘+i)
  init = tf.initialize_all_variables()
  sess.run(init)
  print "variables initialized"
  batch = img.reshape((1, 224, 224, 3))
  assert batch.shape == (1, 224, 224, 3)
  feed_dict = { images: batch }
  print u‘开始执行‘
  prob_tensor = graph.get_tensor_by_name("import/prob:0")
  prob = sess.run(prob_tensor, feed_dict=feed_dict)
  print u‘输出结果‘
  #print_prob(prob[0])
  result.append(print_prob(prob[0]))

print result
sess.close()

‘‘‘
with tf.Session() as sess:
  init = tf.initialize_all_variables()
  sess.run(init)
  print "variables initialized"
  batch = cat.reshape((1, 224, 224, 3))
  assert batch.shape == (1, 224, 224, 3)
  feed_dict = { images: batch }
  print u‘开始执行‘
  prob_tensor = graph.get_tensor_by_name("import/prob:0")
  prob = sess.run(prob_tensor, feed_dict=feed_dict)

print u‘输出结果‘
print_prob(prob[0])

‘‘‘

  

时间: 2024-11-08 23:44:33

tensorflow中关于vgg16的项目的相关文章

TensorFlow中的并行执行引擎——StreamExecutor框架

背景 [作者:DeepLearningStack,阿里巴巴算法工程师] 在前一篇文章中,我们梳理了TensorFlow中各种异构Device的添加和注册机制,通过使用预先定义好的宏,各种自定义好的Device能够将自己注册到全局表中.TensorFlow期望通过这种模式,能够让Device的添加和注册于系统本身更好的解耦,从而体现了较好的模块化特性.在这篇文章中,我们选择直接去窥探TensorFlow底层架构较为复杂的一个部分--StreamExecutor框架.我们已经知道TensorFlow

TensorFlow中的通信机制——Rendezvous(二)gRPC传输

背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 本篇是TensorFlow通信机制系列的第二篇文章,主要梳理使用gRPC网络传输部分模块的结构和源码.如果读者对TensorFlow中Rendezvous部分的基本结构和原理还不是非常了解,那么建议先从这篇文章开始阅读.TensorFlow在最初被开源时还只是个单机的异构训练框架,在迭代到0.8版本开始正式支持多机分布式训练.与其他分布式训练框架不同,Google选用了开源项

使用tensorflow中的Dataset来读取制作好的tfrecords文件

上一篇我写了如何给自己的图像集制作tfrecords文件,现在我们就来讲讲如何读取已经创建好的文件,我们使用的是Tensorflow中的Dataset来读取我们的tfrecords,网上很多帖子应该是很久之前的了,绝大多数的做法是,先将tfrecords序列化成一个队列,然后使用TFRecordReader这个函数进行解析,解析出来的每一行都是一个record,然后再将每一个record进行还原,但是这个函数你在使用的时候会报出异常,原因就是它已经被dataset中新的读取方式所替代,下个版本中

Maven入门1-在Eclipse中新建Maven Web项目

在eclipse中新建Maven Web项目 很多时候开发效率低下,大部分原因是IDE环境不熟悉.配置不会配置:因此在学习一项技能之前,有必要对基本的环境配置有所了解,正所谓磨刀不误砍柴工.这篇文章主要针对初次接触Maven,不熟悉配置的研究人员. 1.Maven配置及介绍 Maven官网:http://maven.apache.org/ 以前开发Java Web工程时,需要导入很多依赖包,但是随着工程逐渐变大,所管理的包越来越多,有必要使用工具来管理这些包,这样不需要我们手动导入:Maven就

Tensorflow中使用CNN实现Mnist手写体识别

本文参考Yann LeCun的LeNet5经典架构,稍加ps得到下面适用于本手写识别的cnn结构,构造一个两层卷积神经网络,神经网络的结构如下图所示: 输入-卷积-pooling-卷积-pooling-全连接层-Dropout-Softmax输出 第一层卷积利用5*5的patch,32个卷积核,可以计算出32个特征.然后进行maxpooling.第二层卷积利用5*5的patch,64个卷积核,可以计算出64个特征.然后进行max pooling.卷积核的个数是我们自己设定,可以增加卷积核数目提高

VS中生成、清理项目、调试、开始执行(不调试)、Debug 和 Release等之间的区别

一.生成和重新生成 "生成"的时候只对你改动过的文件重新生成没有改动过的文件不会重新生成: "重新生成"是对所有的文件都重新生成. 以cpp为例当你只改动某些.cpp之类的文件的时候可以用生成省了编译没有改动的那些些文件的时间:但是改动了某些.h之类的文件最好用重新生成,因为有可能能有些文件包含.h文件也需要重新编译 选择生成或生成解决方案,将只编译自上次生成以来更改过的那些些项目文件和组件 注意 如果解决方案中包括多个项目,则生成命令将变成生成解决方案. 选择重新

AS中导入GitHub开源项目SlidingMenu总结,此方法有效,但是太耗时间。 「我用了半个多小时」

AS中导入GitHub开源项目SlidingMenu总结,我开始AS导入SlidingMenu的时候也百度了很多文章,写的都不是很详细,所以导入成功后,写了这篇文章,希望对想用AndroidStudio导入SlidingMenu的小伙伴有所启发. 先上最终效果图动画 1,下载SlidingMenu(https://github.com/jfeinstein10/SlidingMenu) 2.新建AS项目,把SlidingMenu-master中的library文件夹(我把这个文件夹重命名为sli

tensorflow中的共享变量(sharing variables)

为什么要使用共享变量? 当训练复杂模型时,可能经常需要共享大量的变量.例如,使用测试集来测试已训练好的模型性能表现时,需要共享已训练好模型的变量,如全连接层的权值. 而且我们还会遇到以下问题: 比如,我们创建了一个简单的图像滤波器模型.如果只使用tf.Variable,那么我们的模型可能如下 def my_image_filter(input_images): conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]), name="

ios项目中引用其他开源项目

1. 将开源项目的.xcodeproj拖入项目frameworks 2. Build Phases下 Links Binary With Libraries 引入.a文件.Target Dependencies里引入开源项目文件 3. Build Setting下的 Search Paths 里 Header Search Paths 加入开源项目src目录 例:$(SOURCE_ROOT)/IBAForms/headers ,IBA放在项目根目录里,headers就是src 如果和项目根目录平