用TensorFlow搭建网络训练、验证并测试

原文连接  https://blog.csdn.net/yutingzhaomeng/article/details/81708261

本文总结tensorflow使用的相关方法,包括:

0、定义网络输入

1、如何利用tensorflow在已有网络入resnet基础上搭建自己的网络结构

2、如何添加自己的网络层

3、如何导入已有模块入resnet全连接层之前部分的参数

4、定义网络损失

5、定义优化算子以及衰减优化算子

6、预测网络输出

7、保存网络模型

8、自定义生成训练batch

9、训练网络

10、利用tensorboard可视化训练过程

0、定义网络输入

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3], name=‘inputs‘)
labels = tf.placeholder(tf.int32, [None], name=‘lables‘)
is_training = tf.placeholder(tf.bool, name=‘is_training‘)
    这里inputs表示输入数据,labels表示对应的label,is_training主要用于区分如drop和batchnorm层的训练测试阶段。

1、如何利用tensorflow在已有网络入resnet基础上搭建自己的网络结构

with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
if config.TRAIN.net_layer == ‘50‘:
logits, endpoints = nets.resnet_v1.resnet_v1_50(inputs, num_classes=None, is_training=is_training)
if config.TRAIN.net_layer == ‘101‘:
logits, endpoints = nets.resnet_v1.resnet_v1_101(inputs, num_classes=None, is_training=is_training)
if config.TRAIN.net_layer == ‘152‘:
logits, endpoints = nets.resnet_v1.resnet_v1_152(inputs, num_classes=None, is_training=is_training)
    以resnet为例,logits表示bottleneck特征,num_classes设置为None表示取bottleneck特征。

2、如何添加自己的网络层

with tf.variable_scope(‘Logits‘):
logits = tf.squeeze(logits, axis=[1,2])
logits = slim.dropout(logits, keep_prob=0.5, scope=‘scope‘)
logits = slim.fully_connected(logits, num_outputs=config.DATASET.num_classes, activation_fn=None, scope=‘fc‘)
    这里有一个scope,后面我们会发现,主要用来区别resnet已有参数,squeeze用于将1*1*512的特征拉伸为向量,我们添加dropout层和全连接层。

3、如何导入已有模块入resnet全连接层之前部分的参数

checkpoint_exclude_scopes = ‘Logits‘
exclusions = None
if checkpoint_exclude_scopes:
exclusions = [scope.strip() for scope in checkpoint_exclude_scopes.split(‘,‘)]
variables_to_restore = []
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
if not excluded:
variables_to_restore.append(var)
logits scope下的变量我们不考虑,其他参数restore恢复。

4、定义网络损失

loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
5、定义优化算子以及衰减优化算子

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.0001)
train_step = optimizer.minimize(loss)
batch = config.TRAIN.batch_size
sample_size = len(os.listdir(config.DATASET.image_root))
global_step = tf.Variable(0)
learning_rate = tf.train.exponential_decay(1e-4, global_step,
decay_steps=4 * sample_size / batch, decay_rate=0.98,
staircase=True)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    上面的表示正常定义优化算子,下面的表示衰减优化算子。其中,batch表示每个batch样本数,sample_size即样本数,global_step用于获取当前iteration,sample_size / batch即每个epoch包含的iteration数目,计算衰减时,每一个decay_steps降低一次学习率。learning_rate_current = learning_rate_start * dacay_rate ** (global_step / decay_steps)。

6、预测网络输出

logits = tf.nn.softmax(logits, name=‘logits‘)
classes = tf.argmax(logits, axis=1, name=‘classes‘)
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.cast(classes, tf.int32), labels), tf.float32))
7、保存网络模型

init = tf.global_variables_initializer()
saver_restore = tf.train.Saver(var_list=variables_to_restore)
saver = tf.train.Saver(tf.global_variables())
8、自定义生成训练batch

images, truths, valid_imgs, valid_trus = get_batch()
def get_label(xml_path):
tree = ET.parse(xml_path)
objs = tree.findall(‘object‘)

objs = [obj for obj in objs if ‘b‘ in obj.find(‘name‘).text] # select all pointer pannels
if not len(objs) == 1:
return [[], []]
obj = objs[0] # suppose there is only one pannel, otherwise use center selection
label = str(float(obj.find(‘name‘).text.split(‘b‘)[-1]))
return [label]

def get_list():
image_list = []
label_list = []
for file in os.listdir(config.DATASET.image_root):
image_label = get_label(os.path.join(config.DATASET.label_root,file.split(‘.jpg‘)[0]+‘.xml‘))
if len(image_label) > 1:
continue
else:
image_label = image_label[0]
if image_label in config.DATASET.range_dict.keys():
label_list.append(config.DATASET.range_dict[image_label])
else:
label_list.append(len(config.DATASET.range_dict))
image_list.append(os.path.join(config.DATASET.image_root,file))
valid_num = int(len(image_list)*config.DATASET.valid_ratio)
train_list = image_list[valid_num:]
valid_list = image_list[:valid_num]
train_label = label_list[valid_num:]
valid_label = label_list[:valid_num]
return train_list, train_label, valid_list, valid_label

def process_batch(input_quene):

label = input_quene[1]
image = tf.read_file(input_quene[0])
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize_image_with_crop_or_pad(image, config.DATASET.width, config.DATASET.height)
image = tf.image.per_image_standardization(image)

image_batch, label_batch = tf.train.batch([image, label], batch_size=config.TRAIN.batch_size,
capacity=config.TRAIN.capacity, num_threads=config.TRAIN.num_threads)
label_batch = tf.reshape(label_batch, [config.TRAIN.batch_size])
image_batch = tf.cast(image_batch, tf.float32)

return image_batch, label_batch

def get_batch():
train_image_list, train_label_list, valid_image_list, valid_label_list = get_list()

input_quene = tf.train.slice_input_producer([train_image_list, train_label_list])
trian_image_batch, trian_label_batch = process_batch(input_quene)

valid_quene = tf.train.slice_input_producer([valid_image_list, valid_label_list])
valid_image_batch, valid_label_batch = process_batch(valid_quene)

return trian_image_batch, trian_label_batch, valid_image_batch, valid_label_batch
9、训练网络

with tf.Session(config=tfConfig) as sess:

sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# =============================Import Pretrained Parameter=========================== #
saver_restore.restore(sess, config.TRAIN.model_path)

# ================================TensorBoard Related================================ #
tf.summary.image(‘inputs‘,inputs)
tf.summary.scalar(‘loss‘,loss)
tf.summary.scalar(‘accuracy‘,accuracy)
tf.summary.scalar(‘learning rate‘, learning_rate)
merged_summary_op = tf.summary.merge_all()
if os.path.exists(os.path.join(config.TRAIN.log_path, ‘train‘)):
shutil.rmtree(os.path.join(config.TRAIN.log_path, ‘train‘))
if os.path.exists(os.path.join(config.TRAIN.log_path, ‘valid‘)):
shutil.rmtree(os.path.join(config.TRAIN.log_path, ‘valid‘))
train_writer = tf.summary.FileWriter(os.path.join(config.TRAIN.log_path, ‘train‘), sess.graph)
valid_writer = tf.summary.FileWriter(os.path.join(config.TRAIN.log_path, ‘valid‘))

for i in range(config.TRAIN.num_iterations):
images_, truths_ = sess.run([images, truths])
valid_imgs_, valid_trus_ = sess.run([valid_imgs, valid_trus])

summary_str, _, loss_, acc_ = sess.run([merged_summary_op, train_step, loss, accuracy], \
feed_dict={inputs: images_, labels: truths_, is_training: True})
valid_str, vloss, vacc = sess.run([merged_summary_op, loss, accuracy], \
feed_dict={inputs: valid_imgs_, labels: valid_trus_, is_training: False})

print(‘Step: {}, Loss: {:.4f}, Accuracy: {:.4f}, Valid Loss: {:.4f}, Valid Accuracy: {:.4f}‘.format(i+1, loss_, acc_, vloss, vacc))

# if (i+1) % 1000 == 0:
# saver.save(sess, config.TRAIN.save_path)
# print(‘save mode to {}‘.format(config.TRAIN.save_path))

# summary_str = sess.run(merged_summary_op)
train_writer.add_summary(summary_str, i)
valid_writer.add_summary(valid_str, i)

coord.request_stop()
coord.join(threads)
10、利用tensorboard可视化训练过程

tf.summary.image(‘inputs‘,inputs)
tf.summary.scalar(‘loss‘,loss)
tf.summary.scalar(‘accuracy‘,accuracy)
tf.summary.scalar(‘learning rate‘, learning_rate)
merged_summary_op = tf.summary.merge_all()
if os.path.exists(os.path.join(config.TRAIN.log_path, ‘train‘)):
shutil.rmtree(os.path.join(config.TRAIN.log_path, ‘train‘))
if os.path.exists(os.path.join(config.TRAIN.log_path, ‘valid‘)):
shutil.rmtree(os.path.join(config.TRAIN.log_path, ‘valid‘))
train_writer = tf.summary.FileWriter(os.path.join(config.TRAIN.log_path, ‘train‘), sess.graph)
valid_writer = tf.summary.FileWriter(os.path.join(config.TRAIN.log_path, ‘valid‘))

原文地址:https://www.cnblogs.com/happytaiyang/p/11618659.html

时间: 2024-08-30 06:45:45

用TensorFlow搭建网络训练、验证并测试的相关文章

使用CNN(convolutional neural nets)检测脸部关键点教程(二):浅层网络训练和测试

第三部分 第一个模型:一个隐层结构的传统神经网络 这一部分让我们从代码开始: # add to kfkd.py from lasagne import layers from lasagne.updates import nesterov_momentum from nolearn.lasagne import NeuralNet net1 = NeuralNet( layers=[ # three layers: one hidden layer ('input', layers.InputL

Caffe-python interface 学习|网络训练、部署、测试

继续python接口的学习.剩下还有solver.deploy文件的生成和模型的测试. 网络训练 solver文件生成 其实我觉得用python生成solver并不如直接写个配置文件,它不像net配置一样有很多重复的东西. 对于一下的solver配置文件: base_lr: 0.001 display: 782 gamma: 0.1 lr_policy: "step" max_iter: 78200 #训练样本迭代次数=max_iter/782(训练完一次全部样本的迭代数) momen

关于训练集,验证集,测试集的划分

首先需要说明的是:训练集(training set).验证集(validation set)和测试集(test set)本质上并无区别,都是把一个数据集分成三个部分而已,都是(feature, label)造型.尤其是训练集与验证集,更无本质区别.测试集可能会有一些区别,比如在一些权威计算机视觉比赛中,测试集的标签是private的,也就是参赛者看不到测试集的标签,可以把预测的标签交给大赛组委会,他们根据你提交的预测标签来评估参赛者模式识别系统的好坏,以防作弊. 通常,在训练有监督的机器学习模型

使用Tensorflow搭建回归预测模型之二:数据准备与预处理

前言: 在前一篇中,已经搭建好了Tensorflow环境,本文将介绍如何准备数据与预处理数据. 正文: 在机器学习中,数据是非常关键的一个环节,在模型训练前对数据进行准备也预处理是非常必要的. 一.数据准备: 一般分为三个步骤:数据导入,数据清洗,数据划分. 1.数据导入: 数据存放在原始格式多种多样,具体取决于用于导入数据的机制和数据的来源.比如:有*.csv,*.txt,*xls,*.json等. 2.数据清洗: 数据清洗主要发现并纠正数据中的错误,包含检查数据的一致性,数据的无效值,以及缺

(转)一文学会用 Tensorflow 搭建神经网络

一文学会用 Tensorflow 搭建神经网络 本文转自:http://www.jianshu.com/p/e112012a4b2d 字数2259 阅读3168 评论8 喜欢11 cs224d-Day 6: 快速入门 Tensorflow 本文是学习这个视频课程系列的笔记,课程链接是 youtube 上的,讲的很好,浅显易懂,入门首选, 而且在github有代码,想看视频的也可以去他的优酷里的频道找. Tensorflow 官网 神经网络是一种数学模型,是存在于计算机的神经系统,由大量的神经元相

基于tensorflow搭建一个神经网络

一,tensorflow的简介 Tensorflow是一个采用数据流图,用于数值计算的 开源软件库.节点在图中表示数字操作,图中的线 则表示在节点间相互联系的多维数据数组,即张量 它灵活的架构让你可以在多种平台上展开计算,例 如台式计算机中的一个或多个CPU(或GPU), 服务器,移动设备等等.Tensorflow最初由Google 大脑小组的研究员和工程师们开发出来,用于机器 学习和深度神经网络方面的研究,但这个系统的通 用性使其也可广泛用于其他计算领域. 二,tensorflow的架构 Te

【TensorFlow/简单网络】MNIST数据集-softmax、全连接神经网络,卷积神经网络模型

初学tensorflow,参考了以下几篇博客: soft模型 tensorflow构建全连接神经网络 tensorflow构建卷积神经网络 tensorflow构建卷积神经网络 tensorflow构建CNN[待学习] 全连接+各种优化[待学习] BN层[待学习] 先解释以下MNIST数据集,训练数据集有55,000 条,即X为55,000 * 784的矩阵,那么Y为55,000 * 10的矩阵,每个图片是28像素*28像素,带有标签,Y为该图片的真实数字,即标签,每个图片10个数字,1所在位置

模式识别之卷及网络---卷及网络 训练太慢

摘要:CIFAR-10竞赛之后,卷积网络之父Yann LeCun接受相关采访.他认为:卷积网络需要大数据和高性能计算机的支持:深层卷积网络的训练时间不是问题,运行时间才是关键.Yann LeCun还分享了他正在做的一些最新研究. Kaggle近期举办了一场 关于CIFAR-10数据集的竞赛,该数据集包含有6万个32*32的彩色图像,共分为10种类型,由 Alex Krizhevsky, Vinod Nair和 Geoffrey Hinton收集而来. 很多竞赛选手使用了卷积网络来完成这场竞赛,其

网络上可供测试的Web Service

腾讯QQ在线状态 WEB 服务Endpoint: http://www.webxml.com.cn/webservices/qqOnlineWebService.asmx Disco: http://www.webxml.com.cn/webservices/qqOnlineWebService.asmx?discoWSDL: http://www.webxml.com.cn/webservices/qqOnlineWebService.asmx?wsdl通过输入QQ号码(String)检测QQ