tensorflow学习之(八)使用dropout解决overfitting(过拟合)问题

#使用dropout解决overfitting(过拟合)问题
#如果有dropout,在feed_dict的参数中一定要加入dropout的值
import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import LabelBinarizer

    #load datas 导入klearn中digits手写字体数据集
digits = load_digits()
X = digits.data         #加载从0-9的数字集
y = digits.target       #y为X所对应的标签
    #fit(y) 返回一个实例
    #fit_transform(y) 返回 和y一样的形状
y = LabelBinarizer().fit_transform(y)
    #train_test_split(train_data,train_target,test_size=0.4, random_state=0)
    # 是交叉验证中常用的函数,功能是从样本中随机的按比例选取train_data和test_data
    #参数解释:
    #train_data:所要划分的样本特征集
    #train_target:所要划分的样本结果
    #test_size:样本占比,如果是整数的话就是样本的数量
    #random_state:是随机数的种子。
    #随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。
    # 比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。
    #随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:
    #种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3)

‘‘‘
#fit_transform()、inverse_transform使用的例子
#程序
from sklearn import preprocessing
feature = [[0,1], [1,1], [0,0], [1,0]]
label= [‘yes‘, ‘no‘, ‘yes‘, ‘no‘]
lb = preprocessing.LabelBinarizer() #构建一个转换对象
Y = lb.fit_transform(label)
re_label = lb.inverse_transform(Y)#还原之前的label
print(Y)
print(re_label)
#结果
[[1]
 [0]
 [1]
 [0]]
[‘yes‘ ‘no‘ ‘yes‘ ‘no‘]
‘‘‘

# 定义一个神经层
def add_layer(inputs, in_size, out_size,layer_name, activation_function=None):
    #add one more layer and return the output of the layer
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_pro)#使用dropout机制,解决overfitting问题
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b)
    tf.summary.histogram(layer_name+‘/output‘,outputs)
    return outputs

#define placeholder for inputs to network
keep_pro = tf.placeholder(tf.float32)#dropout机制使用
xs = tf.placeholder(tf.float32, [None, 64])  # none表示无论给多少个例子都行,64=8*8
ys = tf.placeholder(tf.float32, [None, 10])  #表示10个需要识别的数字

#add output layer
l1 = add_layer(xs, 64, 50,‘l1‘,activation_function=tf.nn.tanh)
prediction = add_layer(l1, 50, 10,‘l2‘, activation_function=tf.nn.softmax)

#the error between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1])) #loss function
tf.summary.scalar(‘loss‘,cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy)

sess = tf.Session()
merged = tf.summary.merge_all()
sess.run(tf.initialize_all_variables())#tf.initialize_all_variables()以被弃用
#sess.run(tf.global_variables_initializer())

#summary writer goes in here
train_writer = tf.summary.FileWriter("../../logs/train",sess.graph)
test_writer = tf.summary.FileWriter("../../logs/test",sess.graph)

for i in range(500):
    sess.run(train_step,feed_dict={xs: X_train, ys: y_train,keep_pro:0.6})#保持0.6的概率不被drop掉
    if i%50 == 0:
        # record loss
        train_result = sess.run(merged, feed_dict={xs: X_train, ys: y_train,keep_pro:1})
        test_result = sess.run(merged, feed_dict={xs: X_test, ys: y_test,keep_pro:1})
        train_writer.add_summary(train_result, i)
        test_writer.add_summary(test_result, i)

原文地址:https://www.cnblogs.com/Harriett-Lin/p/9591448.html

时间: 2024-10-09 13:17:37

tensorflow学习之(八)使用dropout解决overfitting(过拟合)问题的相关文章

TensorFlow实战第七课(dropout解决overfitting)

Dropout 解决 overfitting overfitting也被称为过度学习,过度拟合.他是机器学习中常见的问题. 图中的黑色曲线是正常模型,绿色曲线就是overfitting模型.尽管绿色曲线很精确的区分了所有的训练数据,但是并没有描述数据的整体特征,对新测试的数据适应性比较差. 举个Regression(回归)的例子. 第三条曲线存在overfitting问题,尽管它经过了所有的训练点,但是不能很好地反映数据的趋势,预测能力严重不足.tensorflow提供了强大的dropout方法

java基础知识回顾之java Thread类学习(八)--java多线程通信等待唤醒机制经典应用(生产者消费者)

 *java多线程--等待唤醒机制:经典的体现"生产者和消费者模型 *对于此模型,应该明确以下几点: *1.生产者仅仅在仓库未满的时候生产,仓库满了则停止生产. *2.消费者仅仅在有产品的时候才能消费,仓空则等待. *3.当消费者发现仓储没有产品可消费的时候,会唤醒等待生产者生产. *4.生产者在生产出可以消费的产品的时候,应该通知等待的消费者去消费. 下面先介绍个简单的生产者消费者例子:本例只适用于两个线程,一个线程生产,一个线程负责消费. 生产一个资源,就得消费一个资源. 代码如下: pub

MyBatis学习总结(八)——Mybatis3.x与Spring4.x整合(转载)

孤傲苍狼 只为成功找方法,不为失败找借口! MyBatis学习总结(八)--Mybatis3.x与Spring4.x整合 一.搭建开发环境 1.1.使用Maven创建Web项目 执行如下命令: mvn archetype:create -DgroupId=me.gacl -DartifactId=spring4-mybatis3 -DarchetypeArtifactId=maven-archetype-webapp -DinteractiveMode=false 如下图所示: 创建好的项目如下

马哥学习笔记八——LAMP编译安装之PHP及xcache

1.解决依赖关系: 请配置好yum源(可以是本地系统光盘)后执行如下命令: # yum -y groupinstall "X Software Development" 如果想让编译的php支持mcrypt扩展,此处还需要下载如下两个rpm包并安装之: libmcrypt-2.5.7-5.el5.i386.rpm libmcrypt-devel-2.5.7-5.el5.i386.rpm 2.编译安装php-5.4.13 首先下载源码包至本地目录. # tar xf php-5.4.13

从零开始学习jQuery (八) 插播:jQuery实施方案

原文:从零开始学习jQuery (八) 插播:jQuery实施方案 本系列文章导航 从零开始学习jQuery (一) 开天辟地入门篇 从零开始学习jQuery (二) 万能的选择器 从零开始学习jQuery (三) 管理jQuery包装集 从零开始学习jQuery (四) 使用jQuery操作元素的属性与样式 从零开始学习jQuery (五) 事件与事件对象 从零开始学习jQuery (六) jQuery中的Ajax 从零开始学习jQuery (七) jQuery动画-让页面动起来! 从零开始学

从零学习Fluter(八):Flutter的四种运行模式--Debug、Release、Profile和test以及命名规范

从零学习Fluter(八):Flutter的四种运行模式--Debug.Release.Profile和test以及命名规范 好几天没有跟新我的这个系列文章,一是因为这两天我又在之前的基础上,重新认识flutter,觉得flutter这个东西越来越有意思.并且水很深 今天简单分享一下开发学习中的小知识点 Flutter有四种运行模式:Debug.Release.Profile和test,这四种模式在build的时候是完全独立的 Debug ??Debug模式可以在真机和模拟器上同时运行:会打开所

linux学习第八周总结

linux学习第八周总结 本周学习了两个服务,DNS和ansible 由于这些服务很复杂,我也只能是到达刚了解或者是刚刚入门的程度,所以只说一些简单基本的东西,简单总结. 一.DNS服务 1.简介 域名系统(英文:DomainNameSystem,缩写:DNS)是互联网的一项服务.它作为将域名和IP地址相互映射的一个分布式数据库,能够使人更方便地访问互联网.DNS使用TCP和UDP端口53.当前,对于每一级域名长度的限制是63个字符,域名总长度则不能超过253个字符. 记录类型 主条目:域名服务

Tensorflow学习笔记2:About Session, Graph, Operation and Tensor

简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节点之间则是由张量(Tensor)作为边来连接在一起的.所以Tensorflow的计算过程就是一个Tensor流图.Tensorflow的图则是必须在一个Session中来计算.这篇笔记来大致介绍一下Session.Graph.Operation和Tensor. Session Session提供了O

javascript基础学习(八)

javascript之日期对象 学习要点: 日期对象 将日期对象转换为字符串 将日期对象中的日期和时间转换为字符串 日期对象中的日期 日期对象中的时间 设置日期对象中的日期 设置日期对象中的时间 与毫秒相关的方法 一.日期对象 在javascript中并没有日期型的数据类型,但是提供了一个日期对象可以操作日期和时间. 日期对象的创建: new Date(); 二.将日期对象转换为字符串 将日期对象转换为字符串可以使用以下4种方法: date.toString();//将日期对象转换为字符串时,采