Neural Network Optimization (II) - A Boilerplate for Building Neural Networks

To improve code reusability, we set up a modular boilerplate for building neural networks.

1 Forward Propagation

Forward propagation designs and builds the complete network structure from the input (the parameter x) to the output (the return value y, the prediction or classification result). The forward pass is conventionally placed in forward.py.

Forward propagation requires defining three functions (in practice, the first function is the skeleton, and the second and third handle parameter initialization):

def forward(x, regularizer):
    w =    # weights, created with get_weight()
    b =    # biases, created with get_bias()
    y =    # network output, e.g. tf.matmul(x, w) + b
    return y

Function responsibilities:

  • Defines the forward propagation process; the return value is y
  • Completes the network design, implementing the data path from input to output
  • regularizer is the regularization weight
def get_weight(shape, regularizer):
    w = tf.Variable()
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

Function responsibilities:

  • Initializes w
  • Adds each w's L2 regularization loss to the total-loss collection losses (a quick numeric sketch follows below)
  • Returns w
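
For intuition, here is a minimal NumPy sketch of the penalty that tf.contrib.layers.l2_regularizer adds to the losses collection; it should match scale * tf.nn.l2_loss(w), i.e. scale * sum(w²) / 2. The weight values here are made up for illustration:

import numpy as np

# Hypothetical weight matrix, purely for illustration
w = np.array([[0.5, -1.0],
              [2.0,  0.1]])
regularizer = 0.01  # the scale passed to l2_regularizer

# tf.contrib.layers.l2_regularizer(scale)(w) evaluates to
# scale * tf.nn.l2_loss(w) = scale * sum(w**2) / 2
l2_penalty = regularizer * np.sum(w ** 2) / 2.0
print(l2_penalty)  # 0.01 * 5.26 / 2 = 0.0263
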
def get_bias(shape):
    b = tf.Variable()
    return b

Function responsibilities:

  • Initializes b
  • shape is simply the number of bias terms in the layer

2 Backpropagation

Backpropagation is the training process of the neural network; it optimizes the network parameters and is conventionally placed in backward.py.

def backward():
    x = tf.placeholder()
    y_ = tf.placeholder()
    y = forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)
    loss =    # defined below: base loss (MSE or cross-entropy) plus regularization

Function responsibilities:

  • backward() describes the backpropagation (training) process
  • placeholder reserves space for the inputs x and labels y_
  • Calling forward.forward() rebuilds the forward network structure to compute y
  • Defines global_step, the training-step counter (excluded from training)
  • Defines the loss function loss, as follows
# Option 1: mean squared error (MSE)
loss_mse = tf.reduce_mean(tf.square(y - y_))
loss = loss_mse + tf.add_n(tf.get_collection('losses'))

# Option 2: cross-entropy
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cem = tf.reduce_mean(ce)
loss = cem + tf.add_n(tf.get_collection('losses'))

What this code does:

  • Adds regularization to the loss function
  • First defines the base loss, i.e. how the gap between y and y_ is measured
  • Option 1 uses mean squared error; option 2 uses cross-entropy
  • The regularized loss is loss = loss_mse (or cem) + tf.add_n(tf.get_collection('losses'))
  • The losses collection is populated when each w is initialized:

    w = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))


learning_rate = tf.train.exponential_decay(
    RATE_BASE,
    global_step,
    TOTAL_SAMPLES / BATCH_SIZE,  # decay_steps: the number of batches per epoch
    RATE_DECAY,
    staircase=True
)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

Function responsibilities:

  • Computes the learning rate dynamically via exponential decay (a plain-Python sketch follows below)
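
A plain-Python sketch of what exponential_decay computes with staircase=True; the constants are illustrative, not from the text:

# Plain-Python sketch of tf.train.exponential_decay with staircase=True;
# RATE_BASE, RATE_DECAY and the sample count below are illustrative values.
RATE_BASE = 0.1
RATE_DECAY = 0.99
BATCH_SIZE = 30
TOTAL_SAMPLES = 300
decay_steps = TOTAL_SAMPLES // BATCH_SIZE  # batches per epoch = 10

def decayed_lr(global_step):
    # staircase=True floors the exponent, so the rate drops once per epoch
    return RATE_BASE * RATE_DECAY ** (global_step // decay_steps)

for step in (0, 10, 100):
    print(step, decayed_lr(step))
# 0 0.1    10 0.099    100 0.09043820750088044
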
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
    train_op = tf.no_op(name='train')

Function responsibilities:

  • Applies an exponential moving average (shadow values) to all trainable parameters (a sketch follows below)
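
A plain-Python sketch of the shadow-value update; when global_step is passed in, TF uses min(MOVING_AVERAGE_DECAY, (1 + step) / (10 + step)) as the effective decay, which lets the shadow track the parameter quickly early in training. The values are made up:

MOVING_AVERAGE_DECAY = 0.99

def ema_update(shadow, param, step):
    # Effective decay: min(MOVING_AVERAGE_DECAY, (1 + step) / (10 + step))
    decay = min(MOVING_AVERAGE_DECAY, (1.0 + step) / (10.0 + step))
    # shadow = decay * shadow + (1 - decay) * parameter
    return decay * shadow + (1.0 - decay) * param

shadow = 0.0  # the shadow starts at the parameter's initial value
for step, w in enumerate([1.0, 1.0, 1.0]):  # a parameter that jumps to 1.0
    shadow = ema_update(shadow, w, step)
    print(step, round(shadow, 4))
# 0 0.9    1 0.9818    2 0.9955 -- the shadow quickly follows the parameter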

Creating the session

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    for i in range(STEPS):
        sess.run(train_step, feed_dict={x:  , y_: })  # feed one batch of inputs and labels
        if i % report_interval == 0:  # log progress every report_interval steps
            print()

The with block initializes all variables and repeatedly runs the training op, optimizing the trainable parameters; the sketch below shows how the batch slices for feed_dict are typically computed.
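
A self-contained sketch of the batch slicing used to fill the feed_dict blanks above, assuming a 300-sample dataset with BATCH_SIZE = 30 (this matches the slicing in backward.py in section 3.3):

BATCH_SIZE = 30
TOTAL_SAMPLES = 300  # assumed dataset size, as in section 3

for i in range(5):  # the first few training steps
    start = (i * BATCH_SIZE) % TOTAL_SAMPLES  # wraps back to 0 each epoch
    end = start + BATCH_SIZE
    # feed X[start:end] and Y_[start:end] into the feed_dict above
    print(i, start, end)
# 0 0 30    1 30 60    2 60 90    3 90 120    4 120 150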

Checking whether the file is run as the main module

if __name__ == '__main__':
    backward()

This guard checks whether the Python file being run is the main module. If it is, backward() is executed.

Regularization - avoids overfitting and improves generalization

Exponentially decaying learning rate - speeds up optimization

3 Code Example

The code is split across three files:

  • generateds.py - dataset generation
  • forward.py - forward propagation
  • backward.py - backpropagation

3.1 Dataset generation: generateds.py

 1 # coding:utf-8
 2 # Import modules and generate the simulated dataset
 3 import numpy as np
 4 import matplotlib.pyplot as plt
 5 seed = 2
 6 def generateds():
 7     # Generate random numbers based on seed
 8     rdm = np.random.RandomState(seed)
 9     # Return a 300-row, 2-column matrix: 300 coordinate pairs (x0, x1) as the input dataset
10     X = rdm.randn(300,2)
11     # For each row of X, assign the label 1 if the sum of squares of the two coordinates is less than 2, else 0
12     # These labels are the dataset's "correct answers"
13     Y_ = [int(x0*x0 + x1*x1 < 2) for (x0,x1) in X]
14     # Map each label to a color: 1 -> 'red', otherwise 'blue', so the visualization is easy to read
15     Y_c = [['red' if y else 'blue'] for y in Y_]
16     # Reshape X and Y_: -1 means the row count is inferred from the column count; X has 2 columns, Y_ has 1
17     X = np.vstack(X).reshape(-1,2)
18     Y_ = np.vstack(Y_).reshape(-1,1)
19
20     # print(X)
21     # print(Y_)
22     # print(Y_c)
23     # # plt.scatter plots each row's (x0, x1), colored by the corresponding entry of Y_c (c is short for color)
24     # plt.scatter(X[:,0], X[:,1], c=np.squeeze(Y_c))
25     # plt.show()
26     return X, Y_, Y_c
27 # generateds()

Uncomment lines 20, 21, 22, 24, 25, and 27 to print the dataset this file produces and visualize it.

Running it then prints output like:

[[-4.16757847e-01 -5.62668272e-02]
 [-2.13619610e+00  1.64027081e+00]
 [-1.79343559e+00 -8.41747366e-01]
 [ 5.02881417e-01 -1.24528809e+00]
 [-1.05795222e+00 -9.09007615e-01]
 [ 5.51454045e-01  2.29220801e+00]
 [ 4.15393930e-02 -1.11792545e+00]
 [ 5.39058321e-01 -5.96159700e-01]
 [-1.91304965e-02  1.17500122e+00]
 [-7.47870949e-01  9.02525097e-03]
 [-8.78107893e-01 -1.56434170e-01]
 [ 2.56570452e-01 -9.88779049e-01]
 [-3.38821966e-01 -2.36184031e-01]
 [-6.37655012e-01 -1.18761229e+00]
 [-1.42121723e+00 -1.53495196e-01]
 [-2.69056960e-01  2.23136679e+00]
 [-2.43476758e+00  1.12726505e-01]
 [ 3.70444537e-01  1.35963386e+00]
 [ 5.01857207e-01 -8.44213704e-01]
 [ 9.76147160e-06  5.42352572e-01]
 [-3.13508197e-01  7.71011738e-01]
 [-1.86809065e+00  1.73118467e+00]
 [ 1.46767801e+00 -3.35677339e-01]
 [ 6.11340780e-01  4.79705919e-02]
 [-8.29135289e-01  8.77102184e-02]
 [ 1.00036589e+00 -3.81092518e-01]
 [-3.75669423e-01 -7.44707629e-02]
 [ 4.33496330e-01  1.27837923e+00]
 [-6.34679305e-01  5.08396243e-01]
 [ 2.16116006e-01 -1.85861239e+00]
 [-4.19316482e-01 -1.32328898e-01]
 [-3.95702397e-02  3.26003433e-01]
 [-2.04032305e+00  4.62555231e-02]
 [-6.77675577e-01 -1.43943903e+00]
 [ 5.24296430e-01  7.35279576e-01]
 [-6.53250268e-01  8.42456282e-01]
 [-3.81516482e-01  6.64890091e-02]
 [-1.09873895e+00  1.58448706e+00]
 [-2.65944946e+00 -9.14526229e-02]
 [ 6.95119605e-01 -2.03346655e+00]
 [-1.89469265e-01 -7.72186654e-02]
 [ 8.24703005e-01  1.24821292e+00]
 [-4.03892269e-01 -1.38451867e+00]
 [ 1.36723542e+00  1.21788563e+00]
 [-4.62005348e-01  3.50888494e-01]
 [ 3.81866234e-01  5.66275441e-01]
 [ 2.04207979e-01  1.40669624e+00]
 [-1.73795950e+00  1.04082395e+00]
 [ 3.80471970e-01 -2.17135269e-01]
 [ 1.17353150e+00 -2.34360319e+00]
 [ 1.16152149e+00  3.86078048e-01]
 [-1.13313327e+00  4.33092555e-01]
 [-3.04086439e-01  2.58529487e+00]
 [ 1.83533272e+00  4.40689872e-01]
 [-7.19253841e-01 -5.83414595e-01]
 [-3.25049628e-01 -5.60234506e-01]
 [-9.02246068e-01 -5.90972275e-01]
 [-2.76179492e-01 -5.16883894e-01]
 [-6.98589950e-01 -9.28891925e-01]
 [ 2.55043824e+00 -1.47317325e+00]
 [-1.02141473e+00  4.32395701e-01]
 [-3.23580070e-01  4.23824708e-01]
 [ 7.99179995e-01  1.26261366e+00]
 [ 7.51964849e-01 -9.93760983e-01]
 [ 1.10914328e+00 -1.76491773e+00]
 [-1.14421297e-01 -4.98174194e-01]
 [-1.06079904e+00  5.91666521e-01]
 [-1.83256574e-01  1.01985473e+00]
 [-1.48246548e+00  8.46311892e-01]
 [ 4.97940148e-01  1.26504175e-01]
 [-1.41881055e+00 -2.51774118e-01]
 [-1.54667461e+00 -2.08265194e+00]
 [ 3.27974540e+00  9.70861320e-01]
 [ 1.79259285e+00 -4.29013319e-01]
 [ 6.96197980e-01  6.97416272e-01]
 [ 6.01515814e-01  3.65949071e-03]
 [-2.28247558e-01 -2.06961226e+00]
 [ 6.10144086e-01  4.23496900e-01]
 [ 1.11788673e+00 -2.74242089e-01]
 [ 1.74181219e+00 -4.47500876e-01]
 [-1.25542722e+00  9.38163671e-01]
 [-4.68346260e-01 -1.25472031e+00]
 [ 1.24823646e-01  7.56502143e-01]
 [ 2.41439629e-01  4.97425649e-01]
 [ 4.10869262e+00  8.21120877e-01]
 [ 1.53176032e+00 -1.98584577e+00]
 [ 3.65053516e-01  7.74082033e-01]
 [-3.64479092e-01 -8.75979478e-01]
 [ 3.96520159e-01 -3.14617436e-01]
 [-5.93755583e-01  1.14950057e+00]
 [ 1.33556617e+00  3.02629336e-01]
 [-4.54227855e-01  5.14370717e-01]
 [ 8.29458431e-01  6.30621967e-01]
 [-1.45336435e+00 -3.38017777e-01]
 [ 3.59133332e-01  6.22220414e-01]
 [ 9.60781945e-01  7.58370347e-01]
 [-1.13431848e+00 -7.07420888e-01]
 [-1.22142917e+00  1.80447664e+00]
 [ 1.80409807e-01  5.53164274e-01]
 [ 1.03302907e+00 -3.29002435e-01]
 [-1.15100294e+00 -4.26522471e-01]
 [-1.48147191e-01  1.50143692e+00]
 [ 8.69598198e-01 -1.08709057e+00]
 [ 6.64221413e-01  7.34884668e-01]
 [-1.06136574e+00 -1.08516824e-01]
 [-1.85040397e+00  3.30488064e-01]
 [-3.15693210e-01 -1.35000210e+00]
 [-6.98170998e-01  2.39951198e-01]
 [-5.52949440e-01  2.99526813e-01]
 [ 5.52663696e-01 -8.40443012e-01]
 [-3.12270670e-01  2.14467809e+00]
 [ 1.21105582e-01 -8.46828752e-01]
 [ 6.04624490e-02 -1.33858888e+00]
 [ 1.13274608e+00  3.70304843e-01]
 [ 1.08580640e+00  9.02179395e-01]
 [ 3.90296450e-01  9.75509412e-01]
 [ 1.91573647e-01 -6.62209012e-01]
 [-1.02351498e+00 -4.48174823e-01]
 [-2.50545813e+00  1.82599446e+00]
 [-1.71406741e+00 -7.66395640e-02]
 [-1.31756727e+00 -2.02559359e+00]
 [-8.22453750e-02 -3.04666585e-01]
 [-1.59724130e-01  5.48946560e-01]
 [-6.18375485e-01  3.78794466e-01]
 [ 5.13251444e-01 -3.34844125e-01]
 [-2.83519516e-01  5.38424263e-01]
 [ 5.72509465e-02  1.59088487e-01]
 [-2.37440268e+00  5.85199353e-02]
 [ 3.76545911e-01 -1.35479764e-01]
 [ 3.35908395e-01  1.90437591e+00]
 [ 8.53644334e-02  6.65334278e-01]
 [-8.49995503e-01 -8.52341797e-01]
 [-4.79985112e-01 -1.01964910e+00]
 [-7.60113841e-03 -9.33830661e-01]
 [-1.74996844e-01 -1.43714343e+00]
 [-1.65220029e+00 -6.75661789e-01]
 [-1.06706712e+00 -6.52931145e-01]
 [-6.12094750e-01 -3.51262461e-01]
 [ 1.04547799e+00  1.36901602e+00]
 [ 7.25353259e-01 -3.59474459e-01]
 [ 1.49695179e+00 -1.53111111e+00]
 [-2.02336394e+00  2.67972576e-01]
 [-2.20644541e-03 -1.39291883e-01]
 [ 3.25654693e-02 -1.64056022e+00]
 [-1.15669917e+00  1.23403468e+00]
 [ 1.02818490e+00 -7.21879726e-01]
 [ 1.93315697e+00 -1.07079633e+00]
 [-5.71381608e-01  2.92432067e-01]
 [-1.19499989e+00 -4.87930544e-01]
 [-1.73071165e-01 -3.95346401e-01]
 [ 8.70840765e-01  5.92806797e-01]
 [-1.09929731e+00 -6.81530644e-01]
 [ 1.80066685e-01 -6.69310440e-02]
 [-7.87749540e-01  4.24753672e-01]
 [ 8.19885117e-01 -6.31118683e-01]
 [ 7.89059649e-01 -1.62167380e+00]
 [-1.61049926e+00  4.99939764e-01]
 [-8.34515207e-01 -9.96959687e-01]
 [-2.63388077e-01 -6.77360492e-01]
 [ 3.27067038e-01 -1.45535944e+00]
 [-3.71519124e-01  3.16096597e+00]
 [ 1.09951013e-01 -1.91352322e+00]
 [ 5.99820429e-01  5.49384465e-01]
 [ 1.38378103e+00  1.48349243e-01]
 [-6.53541444e-01  1.40883398e+00]
 [ 7.12061227e-01 -1.80071604e+00]
 [ 7.47598942e-01 -2.32897001e-01]
 [ 1.11064528e+00 -3.73338813e-01]
 [ 7.86146070e-01  1.94168696e-01]
 [ 5.86204098e-01 -2.03872918e-02]
 [-4.14408598e-01  6.73134124e-02]
 [ 6.31798924e-01  4.17592731e-01]
 [ 1.61517627e+00  4.25606211e-01]
 [ 6.35363758e-01  2.10222927e+00]
 [ 6.61264168e-02  5.35558351e-01]
 [-6.03140792e-01  4.19576292e-02]
 [ 1.64191464e+00  3.11697707e-01]
 [ 1.45116990e+00 -1.06492788e+00]
 [-1.40084545e+00  3.07525527e-01]
 [-1.36963867e+00  2.67033724e+00]
 [ 1.24845030e+00 -1.24572655e+00]
 [-1.67168774e-01 -5.76610930e-01]
 [ 4.16021749e-01 -5.78472626e-02]
 [ 9.31887358e-01  1.46833213e+00]
 [-2.21320943e-01 -1.17315562e+00]
 [ 5.62669078e-01 -1.64515057e-01]
 [ 1.14485538e+00 -1.52117687e-01]
 [ 8.29789046e-01  3.36065952e-01]
 [-1.89044051e-01 -4.49328601e-01]
 [ 7.13524448e-01  2.52973487e+00]
 [ 8.37615794e-01 -1.31682403e-01]
 [ 7.07592866e-01  1.14053878e-01]
 [-1.28089518e+00  3.09846277e-01]
 [ 1.54829069e+00 -3.15828043e-01]
 [-1.12590378e+00  4.88496666e-01]
 [ 1.83094666e+00  9.40175993e-01]
 [ 1.01871705e+00  2.30237829e+00]
 [ 1.62109298e+00  7.12683273e-01]
 [-2.08703629e-01  1.37617991e-01]
 [-1.03352168e-01  8.48350567e-01]
 [-8.83125561e-01  1.54538683e+00]
 [ 1.45840073e-01 -4.00106056e-01]
 [ 8.15206041e-01 -2.07492237e+00]
 [-8.34437391e-01 -6.57718447e-01]
 [ 8.20564332e-01 -4.89157001e-01]
 [ 1.42496703e+00 -4.46857897e-01]
 [ 5.21109431e-01 -7.08194380e-01]
 [ 1.15553059e+00 -2.54530459e-01]
 [ 5.18924924e-01 -4.92994911e-01]
 [-1.08654815e+00 -2.30917497e-01]
 [ 1.09801004e+00 -1.01787805e+00]
 [-1.52939136e+00 -3.07987737e-01]
 [ 7.80754356e-01 -1.05583964e+00]
 [-5.43883381e-01  1.84301739e-01]
 [-3.30675843e-01  2.87208202e-01]
 [ 1.18952814e+00  2.12015479e-02]
 [-6.54096803e-02  7.66115904e-01]
 [-6.16350846e-02 -9.52897152e-01]
 [-1.01446306e+00 -1.11526396e+00]
 [ 1.91260068e+00 -4.52632031e-02]
 [ 5.76909718e-01  7.17805695e-01]
 [-9.38998998e-01  6.28775807e-01]
 [-5.64493432e-01 -2.08780746e+00]
 [-2.15050132e-01 -1.07502856e+00]
 [-3.37972149e-01  3.43212732e-01]
 [ 2.28253964e+00 -4.95778848e-01]
 [-1.63962832e-01  3.71622161e-01]
 [ 1.86521520e-01 -1.58429224e-01]
 [-1.08292956e+00 -9.56625520e-01]
 [-1.83376735e-01 -1.15980690e+00]
 [-6.57768362e-01 -1.25144841e+00]
 [ 1.12448286e+00 -1.49783981e+00]
 [ 1.90201722e+00 -5.80383038e-01]
 [-1.05491567e+00 -1.18275720e+00]
 [ 7.79480054e-01  1.02659795e+00]
 [-8.48666001e-01  3.31539648e-01]
 [-1.49591353e-01 -2.42440600e-01]
 [ 1.51197175e-01  7.65069481e-01]
 [-1.91663052e+00 -2.22734129e+00]
 [ 2.06689897e-01 -7.08763560e-02]
 [ 6.84759969e-01 -1.70753905e+00]
 [-9.86569665e-01  1.54353634e+00]
 [-1.31027053e+00  3.63433972e-01]
 [-7.94872445e-01 -4.05286267e-01]
 [-1.37775793e+00  1.18604868e+00]
 [-1.90382114e+00 -1.19814038e+00]
 [-9.10065643e-01  1.17645419e+00]
 [ 2.99210670e-01  6.79267178e-01]
 [-1.76606800e-02  2.36040923e-01]
 [ 4.94035871e-01  1.54627765e+00]
 [ 2.46857508e-01 -1.46877580e+00]
 [ 1.14709994e+00  9.55569845e-02]
 [-1.10743873e+00 -1.76286141e-01]
 [-9.82755667e-01  2.08668273e+00]
 [-3.44623671e-01 -2.00207923e+00]
 [ 3.03234433e-01 -8.29874845e-01]
 [ 1.28876941e+00  1.34925462e-01]
 [-1.77860064e+00 -5.00791490e-01]
 [-1.08816157e+00 -7.57855553e-01]
 [-6.43744900e-01 -2.00878453e+00]
 [ 1.96262894e-01 -8.75896370e-01]
 [-8.93609209e-01  7.51902355e-01]
 [ 1.89693224e+00 -6.29079151e-01]
 [ 1.81208553e+00 -2.05626574e+00]
 [ 5.62704887e-01 -5.82070757e-01]
 [-7.40029749e-02 -9.86496364e-01]
 [-5.94722499e-01 -3.14811843e-01]
 [-3.46940532e-01  4.11443516e-01]
 [ 2.32639090e+00 -6.34053128e-01]
 [-1.54409962e-01 -1.74928880e+00]
 [-2.51957930e+00  1.39116243e+00]
 [-1.32934644e+00 -7.45596414e-01]
 [ 2.12608498e-02  9.10917515e-01]
 [ 3.15276082e-01  1.86620821e+00]
 [-1.82497623e-01 -1.82826634e+00]
 [ 1.38955717e-01  1.19450165e-01]
 [-8.18899200e-01 -3.32639265e-01]
 [-5.86387955e-01  1.73451634e+00]
 [-6.12751558e-01 -1.39344202e+00]
 [ 2.79433757e-01 -1.82223127e+00]
 [ 4.27017458e-01  4.06987749e-01]
 [-8.44308241e-01 -5.59820113e-01]
 [-6.00520405e-01  1.61487324e+00]
 [ 3.94953220e-01 -1.20381347e+00]
 [-1.24747243e+00 -7.75462496e-02]
 [-1.33397514e-02 -7.68323250e-01]
 [ 2.91234010e-01 -1.97330948e-01]
 [ 1.07682965e+00  4.37410232e-01]
 [-9.31978663e-02  1.35631416e-01]
 [-8.82708822e-01  8.84744194e-01]
 [ 3.83204463e-01 -4.16994149e-01]
 [ 1.17796550e-01 -5.36685309e-01]
 [ 2.48718458e+00 -4.51361054e-01]
 [ 5.18836127e-01  3.64448005e-01]
 [-7.98348729e-01  5.65779713e-03]
 [-3.20934708e-01  2.49513550e-01]
 [ 2.56308392e-01  7.67625083e-01]
 [ 7.83020087e-01 -4.07063047e-01]
 [-5.24891667e-01 -5.89808683e-01]
 [-8.62531086e-01 -1.74287290e+00]]
[[1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [0]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]]
[['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue']]

These are the contents of the dataset: X (the 300 coordinate pairs), then Y_ (the labels), then Y_c (the colors).

3.2 Forward propagation: forward.py

# coding:utf-8
# Import modules
import tensorflow as tf

# Define the network's inputs, parameters and outputs, and define the forward propagation process
def get_weight(shape, regularizer):
    w = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape):
    b = tf.Variable(tf.constant(0.01, shape=shape))
    return b

def forward(x, regularizer):
    w1 = get_weight([2, 11], regularizer)
    b1 = get_bias([11])
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

    w2 = get_weight([11, 1], regularizer)
    b2 = get_bias([1])
    y = tf.matmul(y1, w2) + b2 

    return y
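
As a quick standalone check (a sketch, not part of the original post), the forward graph can be run on random inputs; the shapes follow the 2-input, 11-hidden-node, 1-output structure above:

# Standalone smoke test for forward.py (a sketch; assumes the file above)
import numpy as np
import tensorflow as tf
import forward

x = tf.placeholder(tf.float32, shape=(None, 2))
y = forward.forward(x, 0.01)  # the regularizer weight 0.01 is illustrative

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.random.randn(5, 2)})
    print(out.shape)  # (5, 1): one raw output per input row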

3.3 Backpropagation: backward.py

# coding:utf-8
# Import modules and the simulated-dataset generator
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import generateds
import forward

STEPS = 40000
BATCH_SIZE = 30
LEARNING_RATE_BASE = 0.001
LEARNING_RATE_DECAY = 0.999
REGULARIZER = 0.01

def backward():
    x = tf.placeholder(tf.float32, shape=(None, 2))
    y_ = tf.placeholder(tf.float32, shape=(None, 1))

    X, Y_, Y_c = generateds.generateds()

    y = forward.forward(x, REGULARIZER)

    global_step = tf.Variable(0,trainable=False)

    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        300/BATCH_SIZE,  # decay_steps: batches per epoch (300 samples / batch size)
        LEARNING_RATE_DECAY,
        staircase=True)

    # Define the loss function: MSE plus L2 regularization
    loss_mse = tf.reduce_mean(tf.square(y-y_))
    loss_total = loss_mse + tf.add_n(tf.get_collection('losses'))

    # Define the training step (the loss already includes regularization)
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss_total)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        for i in range(STEPS):
            start = (i*BATCH_SIZE) % 300
            end = start + BATCH_SIZE
            sess.run(train_step, feed_dict={x: X[start:end], y_:Y_[start:end]})
            if i % 2000 == 0:
                loss_v = sess.run(loss_total, feed_dict={x:X,y_:Y_})
                print("After %d steps, loss is: %f" %(i, loss_v))

        xx, yy = np.mgrid[-3:3:.01, -3:3:.01]
        grid = np.c_[xx.ravel(), yy.ravel()]
        probs = sess.run(y, feed_dict={x:grid})
        probs = probs.reshape(xx.shape)

    plt.scatter(X[:,0], X[:,1], c=np.squeeze(Y_c))
    plt.contour(xx, yy, probs, levels=[.5])
    plt.show()

if __name__ == '__main__':
    backward()

Running backward.py prints the loss every 2000 steps and then plots the data points together with the learned decision boundary (the contour where the network output equals 0.5).

Original article: https://www.cnblogs.com/gengyi/p/9905226.html

1 滑动平均概述 滑动平均(也称为 影子值 ):记录了每一个参数一段时间内过往值的平均,增加了模型的泛化性. 滑动平均通常针对所有参数进行优化:W 和 b, 简单地理解,滑动平均像是给参数加了一个影子,参数变化,影子缓慢追随. 滑动平均的表示公式为 影子 = 衰减率 * 影子 + ( 1 - 衰减率 ) * 参数 或 滑动平均值 = 衰减率 * 滑动平均值 + ( 1 - 衰减率 )* 参数 备注 影子初值 = 参数初值 衰减率 = min{ MOVING_AVERAGE_DECAY, (1+轮