tfrecord数据集训练验证-猫狗大战

#!/usr/bin/env python
# -*- coding:utf-8 -*-

from mk_tfrecord import *
#from model import *
from inception_v3 import *
import numpy as np
import os
import cv2

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

def training():
    N_CLASSES = 2              # 分类数目
    IMG_W = 299                # 统一图片大小,宽度
    IMG_H = 299                # 统一图片大小,高度
    BATCH_SIZE = 64            # 批次大小
    MAX_STEP = 50000           # 迭代次数
    LEARNING_RATE = 0.0001     # 学习率
    min_after_dequeue = 1000

    tfrecord_filename = ‘/home/xieqi/project/cat_dog/train.tfrecords‘   # 训练数据集
    logs_dir = ‘/home/xieqi/project/cat_dog/log_v3‘     # 检查点保存路径

    # 输入--要生成的字符串的一维字符串张量,shuffle默认为True,输出--字符串队列
    # 将字符串(例如文件名)输出到输入管道的队列,不限制num_epoch。
    filename_queue = tf.train.string_input_producer([tfrecord_filename], num_epochs=150)
    train_image, train_label = read_and_decode(filename_queue, image_W=IMG_W, image_H=IMG_H,
                                            batch_size=BATCH_SIZE,min_after_dequeue=min_after_dequeue) # 返回的为tensor

    train_labels = tf.one_hot(train_label, N_CLASSES)

    train_logits,_ = inception_v3(train_image,num_classes=N_CLASSES)
    train_loss = loss(train_logits, train_labels) # 损失函数
    train_acc = accuracy(train_logits, train_labels) # 模型精确度
    my_global_step = tf.Variable(0, name=‘global_step‘, trainable=False) # 全局步长
    train_op = optimize(train_loss, LEARNING_RATE, my_global_step) #训练模型

    summary_op = tf.summary.merge_all() # 收集模型统计信息
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())#初始化全局变量和局部变量

    # 限制GPU使用率
    # sess_config = tf.ConfigProto()
    # sess_config.gpu_options.per_process_gpu_memory_fraction = 0.70
    # sess = tf.Session(config=sess_config)

    sess = tf.Session()
    # FileWriter类提供了一个机制来创建指定目录的事件文件,并添加摘要和事件给它(异步更新,不影响训练速度)
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    # 将Save类添加OPS保存和恢复变量和检查点。对模型定期做checkpoint,通常用于模型恢复
    saver = tf.train.Saver()

    sess.run(init_op)
    coord = tf.train.Coordinator() # 线程协调员, 实现一种简单的机制来协调一组线程的终止
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) #启动图中收集的所有队列, 开始填充队列
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break

            image_batch, label_batch = sess.run([train_image, train_label]) #获取一个批次的数据及标签
            sess.run(train_op)

            #每迭代100次计算一次loss和准确率
            if step % 100 == 0:
                losses, acc = sess.run([train_loss, train_acc])
                print(‘Step: %6d, loss: %.8f, accuracy: %.2f%%‘ % (step, losses, acc))
                summary_str = sess.run(summary_op)
                train_writer.add_summary(summary_str, step)

            if step % 1000 == 0 or step == MAX_STEP - 1:  # 保存检查点
                checkpoint_path = os.path.join(logs_dir, ‘model.ckpt‘)
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print(‘Done.‘)
    finally:
        coord.request_stop()

    coord.join(threads=threads)
    sess.close()

# 测试检查点
def eval():
    N_CLASSES = 2
    IMG_W = 299
    IMG_H = 299
    BATCH_SIZE = 1
    MAX_STEP = 512
    min_after_dequeue=0

    test_dir = ‘/home/xieqi/project/cat_dog/val.tfrecords‘ #测试集数据
    logs_dir = ‘/home/xieqi/project/cat_dog/log_v3‘     # 检查点目录
    false_pic_dir = ‘/home/xieqi/project/cat_dog/false_pic/‘ #错误分类的图片存储地址

    filename_queue = tf.train.string_input_producer([test_dir], num_epochs=1)#输入要生成的字符串的一维字符张量,输出字符串队列,shuffle默认为True
    train_image, train_label = read_and_decode(filename_queue, image_W=IMG_W, image_H=IMG_H,
                                                batch_size=BATCH_SIZE,min_after_dequeue=min_after_dequeue) # 返回的为tensor

    train_labels = tf.one_hot(train_label, N_CLASSES)

    train_logits, _ = inception_v3(train_image, N_CLASSES)
    train_logits = tf.nn.softmax(train_logits)  # 用softmax转化为百分比数值

    #计算准确率
    correct_num = tf.placeholder(‘float‘)
    correct_pre = tf.div(correct_num, MAX_STEP)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)
    # 载入检查点
    saver = tf.train.Saver()
    print(‘\n载入检查点...‘)
    ckpt = tf.train.get_checkpoint_state(logs_dir) #通过checkpoint文件找到模型文件名,有两个属性:model_checkpoint_path最新的模型文件的文件名
                                                    # all_model_checkpoint_paths未被删除的所有模型文件的文件名
    if ckpt and ckpt.model_checkpoint_path:
        global_step = int(ckpt.model_checkpoint_path.split(‘/‘)[-1].split(‘-‘)[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(‘载入成功,global_step = %d\n‘ % global_step)
    else:
        print(‘没有找到检查点‘)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        correct = 0
        wrong = 0
        dt_list = []
        for step in range(MAX_STEP):

            if coord.should_stop():
                break

            st = time.time()
            image, prediction, labels = sess.run([train_image, train_logits, train_labels])
            dt = time.time() - st
            dt_list.append(dt)

            p_max_index = np.argmax(prediction)
            c_max_index = np.argmax(labels)

            if p_max_index == c_max_index:
                for i in range(BATCH_SIZE):
                    correct += 1
            else:
                for i in range(BATCH_SIZE):
                    wrong += 1
                    cv2.imwrite(false_pic_dir + ‘ture‘ + str(labels) + ‘predict‘ + str(prediction) + ‘.jpg‘, image[i])

        accuray_rate = sess.run(correct_pre,feed_dict={correct_num: correct})
        velocity = np.mean(dt_list)
        print(‘Total: %5d, correct: %5d, wrong: %5d, accuracy: %3.2f%%, each speed: %.4fs‘ %
              (MAX_STEP, correct, wrong, accuray_rate * 100, velocity))
    except tf.errors.OutOfRangeError:
        print(‘OutOfRange‘)
    finally:
        coord.request_stop()

    coord.join(threads=threads)
    sess.close()

if __name__ == ‘__main__‘:
    training()
    #eval()

原文地址:https://www.cnblogs.com/xieqi/p/9685965.html

时间: 2024-08-30 16:04:05

tfrecord数据集训练验证-猫狗大战的相关文章

Keras猫狗大战五:采用全部数据集训练,精度提高到90%

深度学习严重依赖训练数据量的大小,前面(https://www.cnblogs.com/zhengbiqing/p/11070783.html)只随机抽取猫狗图片各1000.500.200分别作为训练.验证.测试集,即使采用了数据增强,精度只达到83%. 采用kaggle 猫狗数据集全部25000张进行训练学习,随机选取猫狗图片各9000.2250.1250分别作为训练.验证.测试集,进行训练. 训练100次迭代: history = model.fit_generator( train_gen

YOLOv3自有数据集训练

YOLO的作者表明他已经放弃CV,darknet是一个C语言库.无论从哪个方面来看,YOLO都是非常奇葩的一个类库.俄罗斯人AlexeyAB也属于其中一部分,他的代码以win平台为主,有很多有趣特点. 一.基本情况 https://github.com/AlexeyAB/darknet 非常详细地讲解了AlexeyAB版的darknet的配置方法.最为直观的是可以实时显示loss和mAP图. 它的中文翻译版本(有所简化) https://zhuanlan.zhihu.com/p/10262837

从零到一:caffe-windows(CPU)配置与利用mnist数据集训练第一个caffemodel

一.前言 本文会详细地阐述caffe-windows的配置教程.由于博主自己也只是个在校学生,目前也写不了太深入的东西,所以准备从最基础的开始一步步来.个人的计划是分成配置和运行官方教程,利用自己的数据集进行训练和利用caffe来实现别人论文中的模型(目前在尝试的是轻量级的SqueezeNet)三步走.不求深度,但求详细.因为说实话caffe-windows的配置当初花了挺多时间的,目前貌似还真没有从头开始一步步讲起的教程,所以博主就争取试着每一步都讲清楚吧. 这里说些题外话:之所以选择Sque

tensorflow 2.0 学习 (十一)卷积神经网络 (一)MNIST数据集训练与预测 LeNet-5网络

网络结构如下: 代码如下: 1 # encoding: utf-8 2 3 import tensorflow as tf 4 from tensorflow import keras 5 from tensorflow.keras import layers, Sequential, losses, optimizers, datasets 6 import matplotlib.pyplot as plt 7 8 Epoch = 30 9 path = r'G:\2019\python\mn

使用自己的数据集训练和测试"caffenet"

主要步骤可参考: http://blog.csdn.net/u010194274/article/details/50575284 补充几点: 1. convert函数是ImageMagick包里面的,在使用之前要进行安装 sudo apt-get install ImageMagick 2. 在将图片大小处理为256x256的时候,这里需要注意,数字之间使用的是字母x,而不是乘号 3. shell脚本中使用到的路径,最好都使用绝对路径 4. 作者在网络定义部分说的并不明确,补充如下:solve

机器学习:验证数据集与交叉验证

# 问题:如果将所有的数据集都作为训练数据集,则对于训练出的模型是否发生了过拟合会不自知,因为过拟合情况下,模型在训练数据集上的误差非常的小,使人觉得模型效果很好,但实际上可能泛化能力不足: # 方案:将数据集分割为训练数据集和测试数据集,通过测试数据集判断模型的好坏--如果通过学习曲线发现,模型在训练数据集上效果较好,在测试数据集上效果不好,模型出现过拟合,需要调整参数来重新得到模型,然后再次进行测试:以此类推循环此过程,最终得到最佳模型. # 最佳模型:也就是在测试数据集上表现的比较好的模型

Ubuntu14.04+caffe+cuda7.5 环境搭建以及MNIST数据集的训练与测试

Ubuntu14.04+caffe+cuda 环境搭建以及MNIST数据集的训练与测试 一.ubuntu14.04的安装: ubuntu的安装是一件十分简单的事情,这里给出一个参考教程: http://jingyan.baidu.com/article/76a7e409bea83efc3b6e1507.html 二.cuda的安装: 1.首先下载nvidia cuda的仓库安装包(我的是ubuntu 14.04 64位,所以下载的是ubuntu14.04的安装包,如果你是32位的可以参看具体的地

使用faster-rcnn.pytorch训练自己数据集

引言 最近在实验室复现faster-rcnn代码,基于此项目jwyang/faster-rcnn.pytorch(目前GitHub上star最多的faster-rcnn实现),成功测试源码数据集后,想使用自己的数据集爽一下. 本文主要介绍如何跑通源代码并“傻瓜式”训练自己的数据集~之前的此类博客都是介绍如何在原作者的caffe源码下进行数据集训练,那么本文针对目前形势一片大好的pytorh版faster-rcnn源码进行训练新的数据集,废话不多说,Lets go! faster-rcnn pyt

用交叉验证改善模型的预测表现

预测模型为何无法保持稳定? 让我们通过以下几幅图来理解这个问题: 此处我们试图找到尺寸(size)和价格(price)的关系.三个模型各自做了如下工作: 第一个模型使用了线性等式.对于训练用的数据点,此模型有很大误差.这样的模型在初期排行榜和最终排行榜都会表现不好.这是"拟合不足"("Under fitting")的一个例子.此模型不足以发掘数据背后的趋势. 第二个模型发现了价格和尺寸的正确关系,此模型误差低/概括程度高. 第三个模型对于训练数据几乎是零误差.这是因