用tensorflow神经网络实现一个简易的图片分类器

文章写的不清晰请大家原谅QAQ  

这篇文章我们将用 CIFAR-10数据集做一个很简易的图片分类器。 在 CIFAR-10数据集包含了60,000张图片。在此数据集中,有10个不同的类别,每个类别中有6,000个图像。每幅图像的大小为32 x 32像素。虽然这么小的尺寸通常给人类识别正确的类别带来了困难,但它实际上是对计算机模型的简化并且减少了分析图像所需的计算。

                                                                                     CIFAR-10数据集

我们可以通过输入模型的大量数字序列将这些图像输入到我们的模型中。每个像素由三个浮点数标识,这三个浮点数表示该像素的红色,绿色和蓝色值(RGB值)。所以每个图像有32 x 32 x 3 = 3,072 个值0.

使用非常大的卷积神经网络可以实现高质量的结果,你可以在这个连接中学习Rodrigo Benenson’s page

下载CIFAR-10数据集,网址:Python version of the dataset, 并把他安装在我们分类器代码所在的文件夹下

先上源代码

模型的源代码:

import numpy as np
import tensorflow as tf
import time
import data_helpers
beginTime = time.time()

batch_size = 100
learning_rate = 0.005
max_steps = 1000

data_sets = data_helpers.load_data()

# Define input placeholders
images_placeholder = tf.placeholder(tf.float32, shape=[None, 3072])
labels_placeholder = tf.placeholder(tf.int64, shape=[None])

# Define variables (these are the values we want to optimize)
weights = tf.Variable(tf.zeros([3072, 10]))
biases = tf.Variable(tf.zeros([10]))

# Define the classifier‘s result
logits = tf.matmul(images_placeholder, weights) + biases

# Define the loss function
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                     labels=labels_placeholder))

# Define the training operation
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Operation comparing prediction with true label
correct_prediction = tf.equal(tf.argmax(logits, 1), labels_placeholder)

# Operation calculating the accuracy of our predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    # Initialize variables
    sess.run(tf.global_variables_initializer())

    # Repeat max_steps times
    for i in range(max_steps):

        # Generate input data batch
        indices = np.random.choice(data_sets[‘images_train‘].shape[0], batch_size)
        images_batch = data_sets[‘images_train‘][indices]
        labels_batch = data_sets[‘labels_train‘][indices]

        # Periodically print out the model‘s current accuracy
        if i % 100 == 0:
            train_accuracy = sess.run(accuracy, feed_dict={
                images_placeholder: images_batch, labels_placeholder: labels_batch})
            print(‘Step {:5d}: training accuracy {:g}‘.format(i, train_accuracy))

        # Perform a single training step
        sess.run(train_step, feed_dict={images_placeholder: images_batch,
                                        labels_placeholder: labels_batch})

    # After finishing the training, evaluate on the test set
    test_accuracy = sess.run(accuracy, feed_dict={
        images_placeholder: data_sets[‘images_test‘],
        labels_placeholder: data_sets[‘labels_test‘]})
    print(‘Test accuracy {:g}‘.format(test_accuracy))

endTime = time.time()
print(‘Total time: {:5.2f}s‘.format(endTime - beginTime))

处理数据集的代码

import numpy as np
import pickle
import sys

def load_CIFAR10_batch(filename):
    ‘‘‘load data from single CIFAR-10 file‘‘‘

    with open(filename, ‘rb‘) as f:
        if sys.version_info[0] < 3:
            dict = pickle.load(f)
        else:
            dict = pickle.load(f, encoding=‘latin1‘)
        x = dict[‘data‘]
        y = dict[‘labels‘]
        x = x.astype(float)
        y = np.array(y)
    return x, y

def load_data():
    ‘‘‘load all CIFAR-10 data and merge training batches‘‘‘

    xs = []
    ys = []
    for i in range(1, 6):
        filename = ‘cifar-10-batches-py/data_batch_‘ + str(i)
        X, Y = load_CIFAR10_batch(filename)
        xs.append(X)
        ys.append(Y)

    x_train = np.concatenate(xs)
    y_train = np.concatenate(ys)
    del xs, ys

    x_test, y_test = load_CIFAR10_batch(‘cifar-10-batches-py/test_batch‘)

    classes = [‘plane‘, ‘car‘, ‘bird‘, ‘cat‘, ‘deer‘, ‘dog‘, ‘frog‘, ‘horse‘,
               ‘ship‘, ‘truck‘]

    # Normalize Data
    mean_image = np.mean(x_train, axis=0)
    x_train -= mean_image
    x_test -= mean_image

    data_dict = {
        ‘images_train‘: x_train,
        ‘labels_train‘: y_train,
        ‘images_test‘: x_test,
        ‘labels_test‘: y_test,
        ‘classes‘: classes
    }
    return data_dict

def reshape_data(data_dict):
    im_tr = np.array(data_dict[‘images_train‘])
    im_tr = np.reshape(im_tr, (-1, 3, 32, 32))
    im_tr = np.transpose(im_tr, (0, 2, 3, 1))
    data_dict[‘images_train‘] = im_tr
    im_te = np.array(data_dict[‘images_test‘])
    im_te = np.reshape(im_te, (-1, 3, 32, 32))
    im_te = np.transpose(im_te, (0, 2, 3, 1))
    data_dict[‘images_test‘] = im_te
    return data_dict

def gen_batch(data, batch_size, num_iter):
    data = np.array(data)
    index = len(data)
    for i in range(num_iter):
        index += batch_size
        if (index + batch_size > len(data)):
            index = 0
            shuffled_indices = np.random.permutation(np.arange(len(data)))
            data = data[shuffled_indices]
        yield data[index:index + batch_size]

def main():
    data_sets = load_data()
    print(data_sets[‘images_train‘].shape)
    print(data_sets[‘labels_train‘].shape)
    print(data_sets[‘images_test‘].shape)
    print(data_sets[‘labels_test‘].shape)

if __name__ == ‘__main__‘:
    main()

首先我们导入了tensorflow numpy time 以及自己写的data_help包

time是为了计算整个代码的运行时间。 data_help是将数据集做成我们训练用的数据结构

data_help中的load_data()会把60000张的CIFAR数据集分成两块:500000张的训练集和100000张的测试集,具体来说他会返回这样的一个包含如下内容的字典

  • images_train: 训练集。一个500000张 包含3072(32x32像素点x3颜色通道)值
  • labels_train: 训练集的50,000个标签(每个标签在0到9之间,代表训练图像所属的10个类别中的哪一个)
  • images_test: 测试集(10,000 by 3,072)
  • labels_test: 测试集的10,000个标签
  • classes: 10个文本标签,用于将数字类值转换为单词(例如0代表‘plane‘,1代表‘car‘)

原文地址:https://www.cnblogs.com/francischeng/p/9833201.html

时间: 2024-10-04 17:02:41

用tensorflow神经网络实现一个简易的图片分类器的相关文章

自己来实现一个简易的OCR

来做个简易的字符识别 ,既然是简易的 那么我们就不能用任何的第三方库 .啥谷歌的 tesseract-ocr, opencv 之类的 那些玩意是叼 至少图像处理 机器视觉这类课题对我这种高中没毕业的人来说是一座高山 对于大多数程序员都应该算难度不小吧. 但是我们这里 这么简陋的功能 还用那些玩意 作为一个程序员的自我修养 你还玩个球.管他代码写得咋个low 效率咋个低 被高手嗤之以鼻也好 其实那些高手也就那样 把你的代码走起来  ,这是一件很好玩的事情. 以前一直觉着这玩意挺神奇 什么OCR o

一个简易的发布电影票的项目

原文:一个简易的发布电影票的项目 源代码下载地址:http://www.zuidaima.com/share/1601881858886656.htm 在首页(index.html)页面上,按照影片发表时间显示所有影片 点击index.html页面右边的影片类型链接,在页面左边显示对应影片 根据影片的名称进行搜索 当鼠标悬停到影片的图片上,显示影片的详细信息 发布新影片(不要求实现上传图片功能) 发布新影片时可以使用struts2组件上传图片 提供数据库脚本 分层实现 使用struts2框架和j

使用cocos制作一个简易的小闹钟

使用cocos制作一个简易的小闹钟 本文转载至学习使用Cocos制作<闹钟> 使用的引擎版本是cocos2.1 具体开发过程指导 (1)Cocos Studio部分 1.打开Cocos工具,新建一个项目: 2.设置好相关的配置,点击完成,从而发布到Cocos Studio中: 3.Cocos Studio IDE介绍: 左上角的是开发常用的游戏元素.UI控件.容器等,可以像VS2013一样拖拽,并在右边设置对应的属性:左下角是资源导入,可以导入所需的图片背景:下面是时间戳,用于设置基于时间戳的

宝塔面板+Fikker+BBR算法+CloudXNS---搭建一个简易的全球CDN缓存节点给网站加速

一.组件简介1)宝塔面板 宝塔面板是一款服务器管理软件,支持windows和linux系统,可以通过Web端轻松管理服务器,提升运维效率.例如:创建管理网站.FTP.数据库,拥有可视化文件管理器,可视化软件管理器,可视化CPU.内存.流量监控图表,计划任务等功能.我们在这里只用到它的LNMP/LAMP一键安装功能. linux(centos)版:yum install -y wget && wget -O install.sh http://download.bt.cn/install/i

如何搭建一个简易的Web框架

Web框架本质 什么是Web框架, 如何自己搭建一个简易的Web框架?其实, 只要了解了HTTP协议, 这些问题将引刃而解. 简单的理解:  所有的Web应用本质上就是一个socket服务端, 而用户的浏览器就是一个socket客户端. 用户在浏览器的地址栏输入网址, 敲下回车键便会给服务端发送数据, 这个数据是要遵守统一的规则(格式)的, 这个规则便是HTTP协议. HTTP协议主要规定了客户端和服务器之间的通信格式 浏览器收到的服务器响应的相关信息可以在浏览器调试窗口(F12键开启)的Net

使用TensorFlow 来实现一个简单的验证码识别过程

本文我们来用 TensorFlow 来实现一个深度学习模型,用来实现验证码识别的过程,这里识别的验证码是图形验证码,首先我们会用标注好的数据来训练一个模型,然后再用模型来实现这个验证码的识别. 1.验证码准备 这里我们使用 python 的 captcha 库来生成即可,这个库默认是没有安装的,所以这里我们需要先安装这个库,另外我们还需要安装 pillow 库 安装好之后,我们就可以用如下代码来生成一个简单的图形验证码 可以看到图中的文字正是我们所定义的内容,这样我们就可以得到一张图片和其对应的

iOS:制作一个简易的计算器

初步接触视图,制作了一个简易的计算器,基本上简单的计算是没有问题的,不是很完美,可能还有一些bug,再接再厉. 1 // 2 // ViewController.m 3 // 计算器 4 // 5 // Created by ma c on 15/8/25. 6 // Copyright (c) 2015年 bjsxt. All rights reserved. 7 // 8 9 #import "ViewController.h" 10 11 @interface ViewContr

Angularjs,WebAPI 搭建一个简易权限管理系统

Angularjs,WebAPI 搭建一个简易权限管理系统 Angularjs名词与概念(一) 1. 目录 前言 Angularjs名词与概念 权限系统原型 权限系统业务 数据库设计和实现 WebAPI项目主体结构 Angularjs前端主体结构 2. 前言 Angularjs开发CRUD类型的Web系统生产力惊人,与jQuery,YUI,kissy,Extjs等前端框架区别非常大,初学者在学习的过程中容易以自己以往的经验来学习Angularjs 往往走入误区,最典型的特征是在的开发过程中,使用

Socket 初识 用Socket建立一个简易Web服务器

摘自<Asp.Net 本质论>作者:郝冠军 //在.Net中.system.Net命名空间提供了网络编程的大多数数据据类型以及常用操作,其中常用的类型如下:/*IPAddress 类表示一个IP地址* IPEndPoint类用来表示一个IP地址和一个端口号的组合,成为网络的端点.* System.Net.Sockets命名空间中提供了基于Socked编程的数据类型.* Socket类封装了Socked的操作.* 常见的操作:* Listen:设置基于连接通信的Socket进入监听状态,并设置等