TensorFlow 图像分类模型 inception_resnet_v2 模型导出、冻结与使用

1. 背景

作为一名深度学习萌新,项目突然需要使用图像分类模型去作分类,因此找到了TensorFlow的模型库,使用它的框架进行训练和后续的操作,项目地址:https://github.com/tensorflow/models/tree/master/research/slim

在使用真正的数据集之前,我首先使用的是它提供的flowers的数据集,用的模型是inception_resnet_v2,因为top-5 Accuracy比较高嘛。

然后我安装flowers的目录结构,将我的数据按照类似的结构进行组织;

仿照download_and_convert_flowers.py增加了自己的数据处理文件convert_normal_data.py;

仿照数据集读取文件flowers.py增加了自己的文件normal.py;

然后使用项目的教程,一步步的进行fine-tuning,直到准确率到了百分之九十以上,停止训练。

但是这个时候在导出模型的时候遇到了坑。

2. 导出Inference Graph

实际上教程写得很简单,就是先导出模型的框架:

Saves out a GraphDef containing the architecture of the model.

然后再往框架里把训练好的checkpoints写到graph中:

If you then want to use the resulting model with your own or pretrained checkpoints as part of a mobile model, you can run freeze_graph to get a graph def with the variables inlined

它放出来的教程是这样的:

$ python export_inference_graph.py   --alsologtostderr   --model_name=inception_v3   --output_file=/tmp/inception_v3_inf_graph.pb

我安装这个格式去把模型改成inception_resnet_v2,然后把checkpoint导进去,总是会报:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [1001] rhs shape= [2]
[[{{node save/Assign_916}}]]

找了个群问了一下,说是模型最后一层输出的数目没有改变,于是重新理了思路,去看了export_inference_graph.py的源码,发现里面有个num_classes的参数,是用来决定最后输出层的数量的,于是最后增加了一下导出参数,最后的命令为:

python export_inference_graph.py   --alsologtostderr   --model_name=${MODEL_NAME}   --dataset_name=normal   --dataset_dir=${DATASET_DIR}   --output_file=/you/path/to/sava/${MODEL_NAME}_inf_graph.pb

最后获得我的graph.pb。

3. 冻结Graph

冻结是个大坑,为什么呢,因为官方给出的教程是使用bazel先编译freeze_graph,然后再使用它进行模型冻结。麻烦来了,首先Ubuntu 18.04无法使用apt进行安装,所以一番折腾,使用它放出的install脚本进行了安装。

然后是需要git clone TensorFlow的源码进行编译,这个编译期间又报了很多错,而且我编译失败后,conda环境的TensorFlow GPU版本还不能用了。。。

最后发现,如果你已经使用conda或者git安装了TensorFlow,直接使用

find / -name freeze_graph.py

找出这个python文件的位置就行了,最后使用命令:

python tensorflow/python/tools/freeze_graph.py   --input_graph=/you/path/to/sava/${MODEL_NAME}_inf_graph.pb   --input_checkpoint=/you/trained/checkpoints/model.ckpt-10000   --input_binary=true   --output_node_names=InceptionResnetV2/Logits/Predictions   --output_graph=/your/path/to/save/frozen_graph.pb

最后终于导出了模型。

4. 使用模型进行预测

主要参考了博文【深度学习-模型eval+模型导出】使用Tensorflow Slim对训练的模型进行评估+导出模型,进行微调:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

FLAGS = None

class NodeLookup(object):
  def __init__(self, label_lookup_path=None):
    self.node_lookup = self.load(label_lookup_path)

  def load(self, label_lookup_path):
    node_id_to_name = {}
    with open(label_lookup_path) as f:
      for line in f:
        line_list = line.strip().split(":")
        node_id_to_name[int(line_list[0])] = line_list[1]
    return node_id_to_name

  def id_to_string(self, node_id):
    if node_id not in self.node_lookup:
      return ‘‘
    return self.node_lookup[node_id]

def create_graph():
  """Creates a graph from saved GraphDef file and returns a saver."""
  # Creates graph from saved graph_def.pb.
  with tf.gfile.FastGFile(FLAGS.model_path, ‘rb‘) as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name=‘‘)

def preprocess_for_eval(image, height, width,
                        central_fraction=0.875, scope=None):
  with tf.name_scope(scope, ‘eval_image‘, [image, height, width]):
    if image.dtype != tf.float32:
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Crop the central region of the image with an area containing 87.5% of
    # the original image.
    if central_fraction:
      image = tf.image.central_crop(image, central_fraction=central_fraction)

    if height and width:
      # Resize the image to the specified height and width.
      image = tf.expand_dims(image, 0)
      image = tf.image.resize_bilinear(image, [height, width],
                                       align_corners=False)
      image = tf.squeeze(image, [0])
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    return image

def run_inference_on_image(image):
  """Runs inference on an image.
  Args:
    image: Image file name.
  Returns:
    Nothing
  """
  with tf.Graph().as_default():
    image_data = tf.gfile.FastGFile(image, ‘rb‘).read()
    image_data = tf.image.decode_jpeg(image_data)
    image_data = preprocess_for_eval(image_data, 299, 299)
    image_data = tf.expand_dims(image_data, 0)
    with tf.Session() as sess:
      image_data = sess.run(image_data)

  # Creates graph from saved GraphDef.
  create_graph()

  with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name(‘InceptionResnetV2/Logits/Predictions:0‘)
    predictions = sess.run(softmax_tensor,
                           {‘input:0‘: image_data})
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup(FLAGS.label_path)

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      score = predictions[node_id]
      print(‘%s (score = %.5f)‘ % (human_string, score))

def main(_):
  image = FLAGS.image_file
  run_inference_on_image(image)

if __name__ == ‘__main__‘:
  parser = argparse.ArgumentParser()
  parser.add_argument(
      ‘--model_path‘,
      type=str,
  )
  parser.add_argument(
      ‘--label_path‘,
      type=str,
  )
  parser.add_argument(
      ‘--image_file‘,
      type=str,
      default=‘‘,
      help=‘Absolute path to image file.‘
  )
  parser.add_argument(
      ‘--num_top_predictions‘,
      type=int,
      default=5,
      help=‘Display this many predictions.‘
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

最后使用一张图片进行测试:

python classify_image_inception_resnet_v2.py   --model_path /your/saved/path/frozen_graph.pb   --label_path /your/path/labels.txt   --image_file /your/path/test.jpg

最后输出:

unsuited (score = 0.94713)
suited (score = 0.05287)

虽然有点高兴,但是蓦然回首,还是很心累,然后现在conda的TensorFlow GPU版本跪了,需要修复。

5. 参考

(1) 【深度学习-模型eval+模型导出】使用Tensorflow Slim对训练的模型进行评估+导出模型

(2) 【Tensorflow系列】使用Inception_resnet_v2训练自己的数据集并用Tensorboard监控

(完)

原文地址:https://www.cnblogs.com/harrymore/p/12149756.html

时间: 2024-08-27 03:55:04

TensorFlow 图像分类模型 inception_resnet_v2 模型导出、冻结与使用的相关文章

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-in solution for a very common Tensorflow use-case: keeping track of the best model checkpoints during training. The BestCheckpointSaver is a wrapper arou

osg fbx模型删除模型中的某几个节点,实现编辑模型的功能

fbx model element count:80 三维视图: {三维} 4294967295 osg::MatrixTransform1 基本墙 wall_240 [361750] 4294967295 osg::MatrixTransform2 基本墙 wall_240 [361813] 4294967295 osg::MatrixTransform3 基本墙 wall_240 [361889] 4294967295 osg::MatrixTransform4 基本墙 wall_240 [

多路复用I/O模型poll() 模型 代码实现

多路复用I/O模型poll() 模型 代码实现 poll()机制和select()机制是相似的,都是对多个描述符进行轮询的方式. 不同的是poll()没有描述符数目的限制. 是通过struct pollfd结构体,对每个描述符进行轮询的 struct pollfd fdarray { int fd;    /*文件描述符*/ short events; /*表示等待的事件*/ short revents;/*表示返回事件即实际发生的事件*/ }; data.h #ifndef DATA_H #d

KVC简单介绍 -字典转模型,模型转字典

// 以下两个方法,都属于 KVC 的方法 // KVC 是 cocoa 的大招!间接给对象属性设置数值 // 程序执行过程中,动态给对象属性设置数值,不关心 .h 中是如何定义的 //      只要对象有属性(无论是在.h中还是在.m中定义的属性),就能够读取/设置! //      这种方式,有点违背程序的开发原则! // 字典转模型 setValuesForKeysWithDictionary // 模型转字典 dictionaryWithValuesForKeys //假设self.p

行为型模型 解释模型

行为型模型 解释模型 /** * 行为型模型 解释模型 * 给定一个语言,定义它的文法表示,并定义一个解释器,这个解释器使用该标识来解释语言中的句子. * */ #define _CRT_SECURE_NO_WARNINGS #include <iostream> #include <string> class Context { public: Context(int num) { m_num = num; } void setNum(int num) { m_num = num

iOS 自定义对象及子类及模型套模型的拷贝、归档存储的通用代码

一.runtime实现通用copy 如果自定义类的子类,模型套模型你真的会copy吗,小心有坑. copy需要自定义类继承NSCopying协议 #import <objc/runtime.h> - (id)copyWithZone:(NSZone *)zone { id obj = [[[self class] allocWithZone:zone] init]; Class class = [self class]; while (class != [NSObject class]) {

概念辨析-生成模型/产生模型

机器学习的任务是从属性X预测标记Y,即求概率P(Y|X): 有监督学习 training data给了正确的答案即label,任务就是建立相应的模型,训练样本集外的数据进行分类预测. 生成式模型 生成模型学习一个联合概率分布P(x,y) 常见的判别方法有 k近邻法.感知机.决策树.逻辑回归.线性回归.最大熵模型.支持向量机(SVM).提升方法.条件随机场(CRF) 判别式模型 判别模型学习一个条件概率分布P(y|x) 常见的生成方法有混合高斯模型.朴素贝叶斯法和隐形马尔科夫模型 判别式模型举例:

django 模型-----定义模型

定义模型 在模型中定义属性,会生成表中的字段 django根据属性的类型确定以下信息: 当前选择的数据库支持字段的类型 渲染管理表单时使用的默认html控件 在管理站点最低限度的验证 django会为表增加自动增长的主键列,每个模型只能有一个主键列,如果使用选项设置某属性为主键列后,则django不会再生成默认的主键列 属性命名限制 不能是python的保留关键字 由于django的查询方式,不允许使用连续的下划线 定义属性 定义属性时,需要字段类型 字段类型被定义在django.db.mode

[模型优化]模型欠拟合及过拟合判断、优化方法

[模型优化]模型欠拟合及过拟合判断.优化方法 一.模型欠拟合及过拟合简介 模型应用时发现效果不理想,有多种优化方法,包含: 添加新特征 增加模型复杂度 减小正则项权重 获取更多训练样本 减少特征数目 增加正则项权重 具体采用哪种方法,才能够有效地提高模型精度,我们需要先判断模型是欠拟合,还是过拟合,才能确定下一步优化方向. 图1 模型欠拟合,即高偏差(high bias),是指模型未训练出数据集的特征,导致模型在训练集.测试集上的精度都很低.如图1左图所示. 模型过拟合,即高方差(high va