The object_detection_tutorial.ipynb file in the TensorFlow source

Today I noticed that the test code in the object detection part of the TensorFlow source I downloaded earlier had gone missing somewhere, so I'm recording it here...

Imports

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from utils import ops as utils_ops

from distutils.version import LooseVersion  # a plain string compare misorders e.g. '1.10.0'
if LooseVersion(tf.__version__) < LooseVersion('1.4.0'):
  raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')
# This is needed to display the images.
%matplotlib inline

Object detection imports

Here are the imports from the object detection module.

from utils import label_map_util

from utils import visualization_utils as vis_util

Model preparation

Variables

Any model exported using the export_inference_graph.py tool can be loaded here simply by changing PATH_TO_CKPT to point to a new .pb file.

By default we use an "SSD with Mobilenet" model here. See the detection model zoo for a list of other models that can be run out-of-the-box with varying speeds and accuracies.

# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of strings used to add the correct label to each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90

Download Model

opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
  file_name = os.path.basename(file.name)
  if 'frozen_inference_graph.pb' in file_name:
    tar_file.extract(file, os.getcwd())
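
Note: as written, this cell re-downloads and re-extracts the model archive on every run. A small guard (my addition, not part of the original notebook) skips the work when the frozen graph is already on disk:

# Optional: only download and extract when the frozen graph is missing.
if not os.path.exists(PATH_TO_CKPT):
  opener = urllib.request.URLopener()
  opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
  tar_file = tarfile.open(MODEL_FILE)
  for file in tar_file.getmembers():
    file_name = os.path.basename(file.name)
    if 'frozen_inference_graph.pb' in file_name:
      tar_file.extract(file, os.getcwd())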

Load a (frozen) Tensorflow model into memory.

detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

Loading label map

Label maps map indices to category names, so that when our convolutional network predicts 5, we know that this corresponds to airplane. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine.

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
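
As a concrete illustration of the point above, a minimal hand-built category_index works just as well as the utility-generated one; the two entries below use ids from the standard COCO label map:

# Any {id: {'id': ..., 'name': ...}} dictionary can stand in for
# category_index; ids here follow mscoco_label_map.pbtxt.
category_index_manual = {
    1: {'id': 1, 'name': 'person'},
    5: {'id': 5, 'name': 'airplane'},
}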

Helper code

def load_image_into_numpy_array(image):
  # Convert a PIL image into a (height, width, 3) uint8 numpy array.
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)
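
One caveat: this helper assumes a 3-channel RGB input, so a grayscale or RGBA image (some PNGs, for example) will break the reshape. A defensive variant, offered here as an optional sketch rather than part of the original notebook, converts to RGB first:

def load_image_into_numpy_array_rgb(image):
  # Force 3 channels so grayscale/RGBA inputs don't break the reshape.
  image = image.convert('RGB')
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)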

Detection

# For the sake of simplicity we will use only 2 images:
# image1.jpg
# image2.jpg
# If you want to test the code with your own images, just add their paths to TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]

# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)
def run_inference_for_single_image(image, graph):
  with graph.as_default():
    with tf.Session() as sess:
      # Get handles to input and output tensors
      ops = tf.get_default_graph().get_operations()
      all_tensor_names = {output.name for op in ops for output in op.outputs}
      tensor_dict = {}
      for key in [
          'num_detections', 'detection_boxes', 'detection_scores',
          'detection_classes', 'detection_masks'
      ]:
        tensor_name = key + ':0'
        if tensor_name in all_tensor_names:
          tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
              tensor_name)
      if 'detection_masks' in tensor_dict:
        # The following processing is only for single image
        detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
        detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
        # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
        real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
        detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
        detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
        detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            detection_masks, detection_boxes, image.shape[0], image.shape[1])
        detection_masks_reframed = tf.cast(
            tf.greater(detection_masks_reframed, 0.5), tf.uint8)
        # Follow the convention by adding back the batch dimension
        tensor_dict['detection_masks'] = tf.expand_dims(
            detection_masks_reframed, 0)
      image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

      # Run inference
      output_dict = sess.run(tensor_dict,
                             feed_dict={image_tensor: np.expand_dims(image, 0)})

      # all outputs are float32 numpy arrays, so convert types as appropriate
      output_dict['num_detections'] = int(output_dict['num_detections'][0])
      output_dict['detection_classes'] = output_dict[
          'detection_classes'][0].astype(np.uint8)
      output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
      output_dict['detection_scores'] = output_dict['detection_scores'][0]
      if 'detection_masks' in output_dict:
        output_dict['detection_masks'] = output_dict['detection_masks'][0]
  return output_dict
for image_path in TEST_IMAGE_PATHS:
  image = Image.open(image_path)
  # the array based representation of the image will be used later in order to prepare the
  # result image with boxes and labels on it.
  image_np = load_image_into_numpy_array(image)
  # Expand dimensions since the model expects images to have shape: [1, None, None, 3].
  # (Note: run_inference_for_single_image adds this batch dimension itself,
  # so image_np_expanded is actually unused in this version of the loop.)
  image_np_expanded = np.expand_dims(image_np, axis=0)
  # Actual detection.
  output_dict = run_inference_for_single_image(image_np, detection_graph)
  # Visualization of the results of a detection.
  vis_util.visualize_boxes_and_labels_on_image_array(
      image_np,
      output_dict['detection_boxes'],
      output_dict['detection_classes'],
      output_dict['detection_scores'],
      category_index,
      instance_masks=output_dict.get('detection_masks'),
      use_normalized_coordinates=True,
      line_thickness=8)
  plt.figure(figsize=IMAGE_SIZE)
  plt.imshow(image_np)
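
Note that %matplotlib inline only takes effect inside a notebook; if you run this loop as a plain Python script, add one line after the loop so the figures actually appear:

plt.show()  # needed when running as a script; the notebook renders inline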

Summary: for real-world testing, prefer the glob module (or os) for reading image files, and OpenCV (drawing rectangles) for displaying the detection results.
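
As a concrete version of that suggestion, here is a minimal sketch (the glob pattern, 0.5 score threshold, and window name are placeholders of my choosing) that reads every JPEG in a directory and draws the high-scoring boxes with OpenCV, reusing the helpers defined above:

import glob
import cv2

for image_path in glob.glob('test_images/*.jpg'):  # placeholder pattern
  image_np = load_image_into_numpy_array(Image.open(image_path))
  output_dict = run_inference_for_single_image(image_np, detection_graph)
  h, w = image_np.shape[:2]
  for box, score in zip(output_dict['detection_boxes'],
                        output_dict['detection_scores']):
    if score < 0.5:  # arbitrary confidence threshold
      continue
    # Boxes are [ymin, xmin, ymax, xmax] in normalized coordinates.
    ymin, xmin, ymax, xmax = box
    cv2.rectangle(image_np, (int(xmin * w), int(ymin * h)),
                  (int(xmax * w), int(ymax * h)), (0, 255, 0), 2)
  # OpenCV expects BGR, so convert the RGB array before display.
  cv2.imshow('detections', cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
  cv2.waitKey(0)
cv2.destroyAllWindows()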

Original article: https://www.cnblogs.com/kongweisi/p/12318144.html
