TFlearn——(2)SVHN

1,数据集简介

  SVHN(Street View House Number)Dateset 来源于谷歌街景门牌号码,原生的数据集1也就是官网的 Format 1 是一些原始的未经处理的彩色图片,如下图所示(不含有蓝色的边框),下载的数据集含有 PNG 的图像和 digitStruct.mat  的文件,其中包含了边框的位置信息,这个数据集每张图片上有好几个数字,适用于 OCR 相关方向。

  这里采用 Format2, Format2 将这些数字裁剪成32x32的大小,如图所示,并且数据是 .mat 文件。

    

2,数据处理

  数据集含有两个变量 X 代表图像, 训练集 X 的 shape 是  (32,32,3,73257) 也就是(width, height, channels, samples),  tensorflow 的张量需要 (samples, width, height, channels),所以需要转换一下,由于直接调用 cifar 10 的网络模型,数据只需要先做个归一化,所有像素除于255就 OK,另外原始数据 0 的标签是 10,这里要转化成 0,并提供 one_hot 编码。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 19 09:55:36 2017

@author: cheers
"""

import scipy.io as sio
import matplotlib.pyplot as plt
import numpy as np

image_size = 32
num_labels = 10

def display_data():
    print ‘loading Matlab data...‘
    train = sio.loadmat(‘train_32x32.mat‘)
    data=train[‘X‘]
    label=train[‘y‘]
    for i in range(10):
        plt.subplot(2,5,i+1)
        plt.title(label[i][0])
        plt.imshow(data[...,i])
        plt.axis(‘off‘)
    plt.show()

def load_data(one_hot = False):

    train = sio.loadmat(‘train_32x32.mat‘)
    test = sio.loadmat(‘test_32x32.mat‘)

    train_data=train[‘X‘]
    train_label=train[‘y‘]
    test_data=test[‘X‘]
    test_label=test[‘y‘]

    train_data = np.swapaxes(train_data, 0, 3)
    train_data = np.swapaxes(train_data, 2, 3)
    train_data = np.swapaxes(train_data, 1, 2)
    test_data = np.swapaxes(test_data, 0, 3)
    test_data = np.swapaxes(test_data, 2, 3)
    test_data = np.swapaxes(test_data, 1, 2)

    test_data = test_data / 255.
    train_data =train_data / 255.

    for i in range(train_label.shape[0]):
         if train_label[i][0] == 10:
             train_label[i][0] = 0

    for i in range(test_label.shape[0]):
         if test_label[i][0] == 10:
             test_label[i][0] = 0

    if one_hot:
        train_label = (np.arange(num_labels) == train_label[:,]).astype(np.float32)
        test_label = (np.arange(num_labels) == test_label[:,]).astype(np.float32)

    return train_data,train_label, test_data,test_label

if __name__ == ‘__main__‘:
    load_data(one_hot = True)
    display_data()

3,TFearn 训练

注意 ImagePreprocessing 对数据做了 0 均值化。网络结构也比较简单,直接调用 TFlearn 的 cifar10 例子。

from __future__ import division, print_function, absolute_import

import tflearn
from tflearn.data_utils import shuffle, to_categorical
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_augmentation import ImageAugmentation

# Data loading and preprocessing
import svhn_data as SVHN
X, Y, X_test, Y_test = SVHN.load_data(one_hot = True)
X, Y = shuffle(X, Y)

# Real-time data preprocessing
img_prep = ImagePreprocessing()
img_prep.add_featurewise_zero_center()
img_prep.add_featurewise_stdnorm()

# Convolutional network building
network = input_data(shape=[None, 32, 32, 3],
                     data_preprocessing=img_prep)
network = conv_2d(network, 32, 3, activation=‘relu‘)
network = max_pool_2d(network, 2)
network = conv_2d(network, 64, 3, activation=‘relu‘)
network = conv_2d(network, 64, 3, activation=‘relu‘)
network = max_pool_2d(network, 2)
network = fully_connected(network, 512, activation=‘relu‘)
network = dropout(network, 0.5)
network = fully_connected(network, 10, activation=‘softmax‘)
network = regression(network, optimizer=‘adam‘,
                     loss=‘categorical_crossentropy‘,
                     learning_rate=0.001)

# Train using classifier
model = tflearn.DNN(network, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=15, shuffle=True, validation_set=(X_test, Y_test),
          show_metric=True, batch_size=96, run_id=‘svhn_cnn‘)

训练结果:

Training Step: 11452  | total loss: 0.68217 | time: 7.973s
| Adam | epoch: 015 | loss: 0.68217 - acc: 0.9329 -- iter: 72576/73257
Training Step: 11453  | total loss: 0.62980 | time: 7.983s
| Adam | epoch: 015 | loss: 0.62980 - acc: 0.9354 -- iter: 72672/73257
Training Step: 11454  | total loss: 0.58649 | time: 7.994s
| Adam | epoch: 015 | loss: 0.58649 - acc: 0.9356 -- iter: 72768/73257
Training Step: 11455  | total loss: 0.53254 | time: 8.005s
| Adam | epoch: 015 | loss: 0.53254 - acc: 0.9421 -- iter: 72864/73257
Training Step: 11456  | total loss: 0.49179 | time: 8.016s
| Adam | epoch: 015 | loss: 0.49179 - acc: 0.9416 -- iter: 72960/73257
Training Step: 11457  | total loss: 0.45679 | time: 8.027s
| Adam | epoch: 015 | loss: 0.45679 - acc: 0.9433 -- iter: 73056/73257
Training Step: 11458  | total loss: 0.42026 | time: 8.038s
| Adam | epoch: 015 | loss: 0.42026 - acc: 0.9469 -- iter: 73152/73257
Training Step: 11459  | total loss: 0.38929 | time: 8.049s
| Adam | epoch: 015 | loss: 0.38929 - acc: 0.9491 -- iter: 73248/73257
Training Step: 11460  | total loss: 0.35542 | time: 9.928s
| Adam | epoch: 015 | loss: 0.35542 - acc: 0.9542 | val_loss: 0.40315 - val_acc: 0.9085 -- iter: 73257/73257
时间: 2024-10-12 02:22:33

TFlearn——(2)SVHN的相关文章

深度学习方法(十):卷积神经网络结构变化——Maxout Networks,Network In Network,Global Average Pooling

技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 最近接下来几篇博文会回到神经网络结构的讨论上来,前面我在"深度学习方法(五):卷积神经网络CNN经典模型整理Lenet,Alexnet,Googlenet,VGG,Deep Residual Learning"一文中介绍了经典的CNN网络结构模型,这些可以说已经是家喻户晓的网络结构,在那一文结尾,我提到"是时候动一动卷积计算的形式了",原因是很多工作证明了,在基本的CNN卷积计算模式之外,很多简

(转)【重磅】无监督学习生成式对抗网络突破,OpenAI 5大项目落地

[重磅]无监督学习生成式对抗网络突破,OpenAI 5大项目落地 [新智元导读]"生成对抗网络是切片面包发明以来最令人激动的事情!"LeCun前不久在Quroa答问时毫不加掩饰对生成对抗网络的喜爱,他认为这是深度学习近期最值得期待.也最有可能取得突破的领域.生成对抗学习是无监督学习的一种,该理论由 Ian Goodfellow 提出,此人现在 OpenAI 工作.作为业内公认进行前沿基础理论研究的机构,OpenAI 不久前在博客中总结了他们的5大项目成果,结合丰富实例介绍了生成对抗网络

激活函数()(转)

神经网络之激活函数(Activation Function) 转载::http://blog.csdn.net/cyh_24 :http://blog.csdn.net/cyh_24/article/details/50593400 日常 coding 中,我们会很自然的使用一些激活函数,比如:sigmoid.ReLU等等.不过好像忘了问自己一(n)件事: 为什么需要激活函数? 激活函数都有哪些?都长什么样?有哪些优缺点? 怎么选用激活函数? 本文正是基于这些问题展开的,欢迎批评指正! (此图并

两种开源聊天机器人的性能测试(二)——基于tensorflow的chatbot

http://blog.csdn.net/hfutdog/article/details/78155676 开源项目链接:https://github.com/dennybritz/chatbot-retrieval/ 它实现一个检索式的机器人.采用检索式架构,有预定好的语料答复库.检索式模型的输入是上下文潜在的答复.模型输出对这些答复的打分,选择最高分的答案作为回复. 下面进入正题. 1.环境配置 首先此项目需要的基本条件是使用Python3(我用的是Python3.4),tensorflow

使用 IDEA 创建 Maven Web 项目 (异常)- Disconnected from the target VM, address: '127.0.0.1:59770', transport: 'socket'

运行环境: JDK 版本:1.8 Maven 版本:apache-maven-3.3.3 IDEA 版本:14 maven-jetty-plugin 配置: <plugin> <groupId>org.eclipse.jetty</groupId> <artifactId>jetty-maven-plugin</artifactId> <configuration> <webAppSourceDirectory>${pro

在深圳有娃的家长必须要懂的社保少儿医保,不然亏大了!(收藏)

在深圳有娃的家长必须要懂的社保少儿医保,不然亏大了!(收藏) 转载2016-07-26 17:21:47 标签:深圳少儿医保社保医疗保险住院 在深圳工作或生活的家长们可能还有人不清楚,其实小孩子最大的基础保障福利就是少儿医保.如果以前没重视关注的,现在您看到这篇文章还来得及!少儿医保每年政府财政补贴384元,自己只需交200元左右,就可以享受门诊报销1000元,住院报销比例90%,最高报销额度达148万,大病门诊最高报销比例90%!如何享受?有哪些待遇?接下来就详细来做一个介绍: 少儿医保投保需

彻底解决_OBJC_CLASS_$_某文件名&quot;, referenced from:问题(转)

最近在使用静态库时,总是出现这个问题.下面总结一下我得解决方法: 1. .m文件没有导入   在Build Phases里的Compile Sources 中添加报错的文件 2. .framework文件没有导入静态库编译时往往需要一些库的支持,查看你是否有没有导入的库文件同样是在Build Phases里的Link Binary With Libraries中添加 3. 重复编译,可能你之前复制过两个地方,在这里添加过两次,删除时系统没有默认删除编译引用地址在Build Settings里搜索

爱奇艺、优酷、腾讯视频竞品分析报告2016(一)

1 背景 1.1 行业背景 1.1.1 移动端网民规模过半,使用时长份额超PC端 2016年1月22日,中国互联网络信息中心 (CNNIC)发布第37次<中国互联网络发展状况统计报告>,报告显示,网民的上网设备正在向手机端集中,手机成为拉动网民规模增长的主要因素.截至2015年12月,我国手机网民规模达6.20亿,有90.1%的网民通过手机上网. 图 1  2013Q1~2015Q3在线视频移动端和PC端有效使用时长份额对比 根据艾瑞网民行为监测系统iUserTracker及mUserTrac

Android 导航条效果实现(六) TabLayout+ViewPager+Fragment

TabLayout 一.继承结构 public class TabLayout extends HorizontalScrollView java.lang.Object ? android.view.View ? android.view.ViewGroup ? android.widget.FrameLayout ? android.widget.HorizontalScrollView ? android.support.design.widget.TabLayout 二.TabLayou