tensorflow bilstm官方示例

  1 ‘‘‘
  2 A Bidirectional Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
  3 This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
  4 Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
  5
  6 Author: Aymeric Damien
  7 Project: https://github.com/aymericdamien/TensorFlow-Examples/
  8 ‘‘‘
  9
 10 from __future__ import print_function
 11
 12 import tensorflow as tf
 13 from tensorflow.contrib import rnn
 14 import numpy as np
 15
 16 # Import MNIST data
 17 from tensorflow.examples.tutorials.mnist import input_data
 18 mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
 19
 20 ‘‘‘
 21 To classify images using a bidirectional recurrent neural network, we consider
 22 every image row as a sequence of pixels. Because MNIST image shape is 28*28px,
 23 we will then handle 28 sequences of 28 steps for every sample.
 24 ‘‘‘
 25
 26 # Parameters
 27 learning_rate = 0.001
 28
 29 # 可以理解为,训练时总共用的样本数
 30 training_iters = 100000
 31
 32 # 每次训练的样本大小
 33 batch_size = 128
 34
 35 # 这个是用来显示的。
 36 display_step = 10
 37
 38 # Network Parameters
 39 # n_steps*n_input其实就是那张图 把每一行拆到每个time step上。
 40 n_input = 28 # MNIST data input (img shape: 28*28)
 41 n_steps = 28 # timesteps
 42
 43 # 隐藏层大小
 44 n_hidden = 128 # hidden layer num of features
 45 n_classes = 10 # MNIST total classes (0-9 digits)
 46
 47 # tf Graph input
 48 # [None, n_steps, n_input]这个None表示这一维不确定大小
 49 x = tf.placeholder("float", [None, n_steps, n_input])
 50 y = tf.placeholder("float", [None, n_classes])
 51
 52 # Define weights
 53 weights = {
 54     # Hidden layer weights => 2*n_hidden because of forward + backward cells
 55     ‘out‘: tf.Variable(tf.random_normal([2*n_hidden, n_classes]))
 56 }
 57 biases = {
 58     ‘out‘: tf.Variable(tf.random_normal([n_classes]))
 59 }
 60
 61
 62 def BiRNN(x, weights, biases):
 63
 64     # Prepare data shape to match `bidirectional_rnn` function requirements
 65     # Current data input shape: (batch_size, n_steps, n_input)
 66     # Required shape: ‘n_steps‘ tensors list of shape (batch_size, n_input)
 67
 68     # Unstack to get a list of ‘n_steps‘ tensors of shape (batch_size, n_input)
 69     # 变成了n_steps*(batch_size, n_input)
 70     x = tf.unstack(x, n_steps, 1)
 71
 72     # Define lstm cells with tensorflow
 73     # Forward direction cell
 74     lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
 75     # Backward direction cell
 76     lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
 77
 78     # Get lstm cell output
 79     try:
 80         outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
 81                                               dtype=tf.float32)
 82     except Exception: # Old TensorFlow version only returns outputs not states
 83         outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
 84                                         dtype=tf.float32)
 85
 86     # Linear activation, using rnn inner loop last output
 87     return tf.matmul(outputs[-1], weights[‘out‘]) + biases[‘out‘]
 88
 89 pred = BiRNN(x, weights, biases)
 90
 91 # Define loss and optimizer
 92 # softmax_cross_entropy_with_logits:Measures the probability error in discrete classification tasks in which the classes are mutually exclusive
 93 # return a 1-D Tensor of length batch_size of the same type as logits with the softmax cross entropy loss.
 94 # reduce_mean就是对所有数值(这里没有指定哪一维)求均值。
 95 cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
 96 optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 97
 98 # Evaluate model
 99 correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
100 accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
101
102 # Initializing the variables
103 init = tf.global_variables_initializer()
104
105 # Launch the graph
106 with tf.Session() as sess:
107     sess.run(init)
108     step = 1
109     # Keep training until reach max iterations
110     while step * batch_size < training_iters:
111         batch_x, batch_y = mnist.train.next_batch(batch_size)
112         # Reshape data to get 28 seq of 28 elements
113         batch_x = batch_x.reshape((batch_size, n_steps, n_input))
114         # Run optimization op (backprop)
115         sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
116         if step % display_step == 0:
117             # Calculate batch accuracy
118             acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
119             # Calculate batch loss
120             loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
121             print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + 122                   "{:.6f}".format(loss) + ", Training Accuracy= " + 123                   "{:.5f}".format(acc))
124         step += 1
125     print("Optimization Finished!")
126
127     # Calculate accuracy for 128 mnist test images
128     test_len = 128
129     test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
130     test_label = mnist.test.labels[:test_len]
131     print("Testing Accuracy:", 132         sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

官方关于bilstm的例子写的很清楚了。因为是第一次看,还是要查许多东西。尤其是数据处理方面。

数据的处理(https://segmentfault.com/a/1190000008793389)

拼接

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
tf.stack([t1, t2], 0)  ==> [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
tf.stack([t1, t2], 1)  ==> [[[1, 2, 3], [7, 8, 9]], [[4, 5, 6], [10, 11, 12]]]
tf.stack([t1, t2], 2)  ==> [[[1, 7], [2, 8], [3, 9]], [[4, 10], [5, 11], [6, 12]]]

从shape的角度看:

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0)  # [2,3] + [2,3] ==> [4, 3]
tf.concat([t1, t2], 1)  # [2,3] + [2,3] ==> [2, 6]
tf.stack([t1, t2], 0)   # [2,3] + [2,3] ==> [2*,2,3]
tf.stack([t1, t2], 1)   # [2,3] + [2,3] ==> [2,2*,3]
tf.stack([t1, t2], 2)   # [2,3] + [2,3] ==> [2,3,2*]

抽取:

input = [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
                                            [4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
                                           [[5, 5, 5]]]

tf.gather(input, [0, 2]) ==> [[[1, 1, 1], [2, 2, 2]],
                              [[5, 5, 5], [6, 6, 6]]]
时间: 2024-08-08 22:09:53

tensorflow bilstm官方示例的相关文章

水晶报表官方示例

原文:水晶报表官方示例 使用 C# 和 C++.NET 开发的 .NET 应用程序实例列表---------------------------------- 概述 本文档列出了 Crystal Decisions 技术支持网站上所有可用的,使用 C# 和 C++.NET 开发的 .NET 应用程序实例列表.本文档还给出了每一个程序的描述和下载链接.随着新程序加入我们的支持站点,本文档将不断更新.---------------------------------- 目录 VISUAL C# .N

html5游戏引擎phaser官方示例学习

首发:个人博客,更新&纠错&回复 phaser官方示例学习进行中,把官方示例调整为简明的目录结构,学习过程中加了点中文注释,代码在这里. 目前把官方的完整游戏示例看了一大半, breakout是敲砖块,gemmatch是钻石消除,invaders是小蜜蜂,matching是配对,simon是记忆游戏,sliding是拼图,starstruck类似超级马里奥,tanks是坦克游戏. 游戏场面上看,敲砖块.小蜜蜂是竖版,超级马里奥是横版,坦克游戏是俯瞰,钻石.配对.记忆.拼图这四个都是棋盘.

DELPHI XE5 跨平台 Form ShowModal 官方示例

Calling ShowModal as an Anonymous Method on All Platforms procedure THeaderFooterForm.btnPickClick(Sender: TObject); var dlg: TForm1; begin dlg := TForm1.Create(nil); // select current value, if available in the list dlg.ListBox1.ItemIndex := dlg.Lis

ngRx 官方示例分析 - 3. reducers

上一篇:ngRx 官方示例分析 - 2. Action 管理 这里我们讨论 reducer. 如果你注意的化,会看到再不同的 Action 定义文件中,导出的 String Literal Type 名称都是 Actions ,在导入的时候,同时导入同名的类型就是问题了.这里首先使用了 import as 语法进行重命名. import * as book from '../actions/book'; import * as collection from '../actions/collec

微信小程序「官方示例代码」剖析【下】:运行机制

在上一篇<微信小程序「官方示例代码」浅析[上]>中,我们只是简单的罗列了一下代码,这一篇,让我们来玩点刺激的--就是看看IDE的代码,了解它是怎么运行的. 还好微信的开发团队在软件工程的实践还有待提高,我们才有机会可以深入了解他们的代码--真想建议他们看看Growth的第二部分,构建系统. 解压应用 首先你需要有下面的工具啦 Mac电脑 微信web开发者工具.app WebStorm / 其他编程器 或 IDE,最好可以支持重命名 首先,我们需要右键微信web开发者工具.app,然后显示包的内

DotNetBar for Windows Forms 12.7.0.10_冰河之刃重打包版原创发布-带官方示例程序版

关于 DotNetBar for Windows Forms 12.7.0.10_冰河之刃重打包版 --------------------11.8.0.8_冰河之刃重打包版---------------------------------------------------------基于 官方原版的安装包 + http://www.cnblogs.com/tracky 提供的补丁DLL制作而成.安装之后,直接就可以用了.省心省事.不必再单独的打一次补丁包了.本安装包和补丁包一样都删除了官方自

DotNetBar for Windows Forms 12.5.0.2_冰河之刃重打包版原创发布-带官方示例程序版

关于 DotNetBar for Windows Forms 12.5.0.2_冰河之刃重打包版 --------------------11.8.0.8_冰河之刃重打包版--------------------------------------------------------- 基于 官方原版的安装包 + http://www.cnblogs.com/tracky 提供的补丁DLL制作而成. 安装之后,直接就可以用了. 省心省事.不必再单独的打一次补丁包了. 本安装包和补丁包一样都删除了

将百度坐标转换的javascript api官方示例改写成传统的回调函数形式

改写前: 百度地图中坐标转换的JavaScript API示例官方示例如下: var points = [new BMap.Point(116.3786889372559,39.90762965106183), new BMap.Point(116.38632786853032,39.90795884517671), new BMap.Point(116.39534009082035,39.907432133833574), new BMap.Point(116.40624058825688,3

DotNetBar for Windows Forms 12.2.0.7_冰河之刃重打包版原创发布-带官方示例程序版

关于 DotNetBar for Windows Forms 12.2.0.7_冰河之刃重打包版 --------------------11.8.0.8_冰河之刃重打包版---------------------------------------------------------基于 官方原版的安装包 + http://www.cnblogs.com/tracky 提供的补丁DLL制作而成.安装之后,直接就可以用了.省心省事.不必再单独的打一次补丁包了.本安装包和补丁包一样都删除了官方自带