w = tf.Variable(tf.random_normal([2,3], stddev=2,mean=0, seed=1))
其中
tf.random_normal是正太分布
除了这个 还有tf.truncated_normal:去掉过大偏离点(大于2个标准差)的正态分布
tf.random_uniform:平均分布
[2,3]是生成2x3的矩阵
stddev是标准差
mean是均值
seed是随机数种子
构造其余量的方法:
#tf.zeros 全0数组 tf.zeros([3,2], int32) #生成 [[0,0], [0,0], [0,0]] #tf.ones 全1数组 tf.ones([3,2],int32) #生成 [[1,1], [1,1], [1,1]] #tf.fill 全定值数组 tf.fill([3,2],8) #生成 [[8,8], [8,8], [8,8]] #tf.constant 直接给值 tf.constant([3,2,1]) #生成[3,2,1]
在数组中[x,y]中 x即为有x个输入特征 y即为有y个输出特征
即如图 输入层有2个输入特征 而在隐藏层中有3个特征
所以数组为2x3
而最后隐藏层中 输出y 只有1个
所以隐藏层到输出层的权w即为3x1的数组
总结 即为
输入层X与2x3的权矩阵W1 相乘得到隐藏层a数据
隐藏层a数据与 3x1的权矩阵W2 相乘得到输出层y数据
代码过程:
import tensorflow as tf #定义输入和参数 x = tf.constant([[0.7, 0.5]]) w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1)) w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1)) #定义前向传播过程 a = tf.matmul(x, w1) y = tf.matmul(a, w2) #用会话计算结果 with tf.Session() as sess: #初始化所有节点变量 init_op = tf.global_variables_initializer() sess.run(init_op) print ("y is: ", sess.run(y))
得到结果:y is: [[3.0904665]]
使用placeholder添加数据:
import tensorflow as tf #定义输入和参数 #定义了一个数据类型为32位浮点,形状为1行2列的数组 x = tf.placeholder(tf.float32, shape=(1, 2)) w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1)) w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1)) #定义前向传播过程 a = tf.matmul(x, w1) y = tf.matmul(a, w2) #用会话计算结果 with tf.Session() as sess: #初始化所有节点变量 init_op = tf.global_variables_initializer() sess.run(init_op) print ("y is: ", sess.run(y, feed_dict={x: [[0.7, 0.5]]}))
添加多组数据:
import tensorflow as tf #定义输入和参数 #定义了一个数据类型为32位浮点,形状为1行2列的数组 x = tf.placeholder(tf.float32, shape=(None, 2)) w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1)) w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1)) #定义前向传播过程 a = tf.matmul(x, w1) y = tf.matmul(a, w2) #用会话计算结果 with tf.Session() as sess: #初始化所有节点变量 init_op = tf.global_variables_initializer() sess.run(init_op) print ("y is: ", sess.run(y, feed_dict={x: [[0.7, 0.5], [0.2,0.3], [0.3,0.4], [0.4,0.5]]})) print ("w1:", sess.run(w1)) print ("w2:", sess.run(w2))
得到结果:
y is: [[3.0904665]
[1.2236414]
[1.7270732]
[2.2305048]]
w1: [[-0.8113182 1.4845988 0.06532937]
[-2.4427042 0.0992484 0.5912243 ]]
w2: [[-0.8113182 ]
[ 1.4845988 ]
[ 0.06532937]]
原文地址:https://www.cnblogs.com/EatMedicine/p/9029287.html
时间: 2024-11-08 05:43:53