吴裕雄--天生自然 pythonTensorFlow自然语言处理:Attention模型--测试

import sys
import codecs
import tensorflow as tf

# 1.参数设置。
# 读取checkpoint的路径。9000表示是训练程序在第9000步保存的checkpoint。
CHECKPOINT_PATH = "F:\\temp\\attention_ckpt-9000"

# 模型参数。必须与训练时的模型参数保持一致。
HIDDEN_SIZE = 1024                          # LSTM的隐藏层规模。
DECODER_LAYERS = 2                          # 解码器中LSTM结构的层数。
SRC_VOCAB_SIZE = 10000                      # 源语言词汇表大小。
TRG_VOCAB_SIZE = 4000                       # 目标语言词汇表大小。
SHARE_EMB_AND_SOFTMAX = True                # 在Softmax层和词向量层之间共享参数。

# 词汇表文件
SRC_VOCAB = "F:\\TensorFlowGoogle\\201806-github\\TensorFlowGoogleCode\\Chapter09\\en.vocab"
TRG_VOCAB = "F:\\TensorFlowGoogle\\201806-github\\TensorFlowGoogleCode\\Chapter09\\zh.vocab"

# 词汇表中<sos>和<eos>的ID。在解码过程中需要用<sos>作为第一步的输入,并将检查
# 是否是<eos>,因此需要知道这两个符号的ID。
SOS_ID = 1
EOS_ID = 2
# 2.定义NMT模型和解码步骤。
# 定义NMTModel类来描述模型。
class NMTModel(object):
    # 在模型的初始化函数中定义模型要用到的变量。
    def __init__(self):
        # 定义编码器和解码器所使用的LSTM结构。
        self.enc_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.enc_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)
        self.dec_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) for _ in range(DECODER_LAYERS)])

        # 为源语言和目标语言分别定义词向量。
        self.src_embedding = tf.get_variable("src_emb", [SRC_VOCAB_SIZE, HIDDEN_SIZE])
        self.trg_embedding = tf.get_variable("trg_emb", [TRG_VOCAB_SIZE, HIDDEN_SIZE])

        # 定义softmax层的变量
        if SHARE_EMB_AND_SOFTMAX:
            self.softmax_weight = tf.transpose(self.trg_embedding)
        else:
            self.softmax_weight = tf.get_variable("weight", [HIDDEN_SIZE, TRG_VOCAB_SIZE])
        self.softmax_bias = tf.get_variable("softmax_bias", [TRG_VOCAB_SIZE])

    def inference(self, src_input):
        # 虽然输入只有一个句子,但因为dynamic_rnn要求输入是batch的形式,因此这里
        # 将输入句子整理为大小为1的batch。
        src_size = tf.convert_to_tensor([len(src_input)], dtype=tf.int32)
        src_input = tf.convert_to_tensor([src_input], dtype=tf.int32)
        src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input)

        with tf.variable_scope("encoder"):
            # 使用bidirectional_dynamic_rnn构造编码器。这一步与训练时相同。
            enc_outputs, enc_state = tf.nn.bidirectional_dynamic_rnn(self.enc_cell_fw, self.enc_cell_bw, src_emb, src_size, dtype=tf.float32)
            # 将两个LSTM的输出拼接为一个张量。
            enc_outputs = tf.concat([enc_outputs[0], enc_outputs[1]], -1)   

        with tf.variable_scope("decoder"):
            # 定义解码器使用的注意力机制。
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(HIDDEN_SIZE, enc_outputs,memory_sequence_length=src_size)

            # 将解码器的循环神经网络self.dec_cell和注意力一起封装成更高层的循环神经网络。
            attention_cell = tf.contrib.seq2seq.AttentionWrapper(self.dec_cell, attention_mechanism,attention_layer_size=HIDDEN_SIZE)

        # 设置解码的最大步数。这是为了避免在极端情况出现无限循环的问题。
        MAX_DEC_LEN=100

        with tf.variable_scope("decoder/rnn/attention_wrapper"):
            # 使用一个变长的TensorArray来存储生成的句子。
            init_array = tf.TensorArray(dtype=tf.int32, size=0,dynamic_size=True, clear_after_read=False)
            # 填入第一个单词<sos>作为解码器的输入。
            init_array = init_array.write(0, SOS_ID)
            # 调用attention_cell.zero_state构建初始的循环状态。循环状态包含
            # 循环神经网络的隐藏状态,保存生成句子的TensorArray,以及记录解码
            # 步数的一个整数step。
            init_loop_var = (attention_cell.zero_state(batch_size=1, dtype=tf.float32),init_array, 0)

            # tf.while_loop的循环条件:
            # 循环直到解码器输出<eos>,或者达到最大步数为止。
            def continue_loop_condition(state, trg_ids, step):
                return tf.reduce_all(tf.logical_and(tf.not_equal(trg_ids.read(step), EOS_ID),tf.less(step, MAX_DEC_LEN-1)))

            def loop_body(state, trg_ids, step):
                # 读取最后一步输出的单词,并读取其词向量。
                trg_input = [trg_ids.read(step)]
                trg_emb = tf.nn.embedding_lookup(self.trg_embedding,trg_input)
                # 调用attention_cell向前计算一步。
                dec_outputs, next_state = attention_cell.call(state=state, inputs=trg_emb)
                # 计算每个可能的输出单词对应的logit,并选取logit值最大的单词作为
                # 这一步的而输出。
                output = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])
                logits = (tf.matmul(output, self.softmax_weight)+ self.softmax_bias)
                next_id = tf.argmax(logits, axis=1, output_type=tf.int32)
                # 将这一步输出的单词写入循环状态的trg_ids中。
                trg_ids = trg_ids.write(step+1, next_id[0])
                return next_state, trg_ids, step+1

            # 执行tf.while_loop,返回最终状态。
            state, trg_ids, step = tf.while_loop(continue_loop_condition, loop_body, init_loop_var)
            return trg_ids.stack()
# 3.翻译一个测试句子。
def main():
    # 定义训练用的循环神经网络模型。
    with tf.variable_scope("nmt_model", reuse=None):
        model = NMTModel()

    # 定义个测试句子。
    test_en_text = "This is a test . <eos>"
    print(test_en_text)

    # 根据英文词汇表,将测试句子转为单词ID。
    with codecs.open(SRC_VOCAB, "r", "utf-8") as f_vocab:
        src_vocab = [w.strip() for w in f_vocab.readlines()]
        src_id_dict = dict((src_vocab[x], x) for x in range(len(src_vocab)))
    test_en_ids = [(src_id_dict[token] if token in src_id_dict else src_id_dict[‘<unk>‘])for token in test_en_text.split()]
    print(test_en_ids)

    # 建立解码所需的计算图。
    output_op = model.inference(test_en_ids)
    sess = tf.Session()
    saver = tf.train.Saver()
    saver.restore(sess, CHECKPOINT_PATH)

    # 读取翻译结果。
    output_ids = sess.run(output_op)
    print(output_ids)

    # 根据中文词汇表,将翻译结果转换为中文文字。
    with codecs.open(TRG_VOCAB, "r", "utf-8") as f_vocab:
        trg_vocab = [w.strip() for w in f_vocab.readlines()]
    output_text = ‘‘.join([trg_vocab[x] for x in output_ids])

    # 输出翻译结果。
    print(output_text.encode(‘utf8‘).decode(sys.stdout.encoding))
    sess.close()

if __name__ == "__main__":
    main()

原文地址:https://www.cnblogs.com/tszr/p/12069844.html

时间: 2024-10-10 16:44:02

吴裕雄--天生自然 pythonTensorFlow自然语言处理:Attention模型--测试的相关文章

吴裕雄--天生自然 pythonTensorFlow自然语言处理:Seq2Seq模型--训练

import tensorflow as tf # 1.参数设置. # 假设输入数据已经用9.2.1小节中的方法转换成了单词编号的格式. SRC_TRAIN_DATA = "F:\\TensorFlowGoogle\\201806-github\\TensorFlowGoogleCode\\Chapter09\\train.en" # 源语言输入文件. TRG_TRAIN_DATA = "F:\\TensorFlowGoogle\\201806-github\\TensorF

吴裕雄--天生自然 pythonTensorFlow自然语言处理:文本数据预处理--生成词汇表

import codecs import collections from operator import itemgetter # 1. 设置参数. MODE = "PTB" # 将MODE设置为"PTB", "TRANSLATE_EN", "TRANSLATE_ZH"之一. if MODE == "PTB": # PTB数据处理 RAW_DATA = "F:\\TensorFlowGoogle

吴裕雄--天生自然 pythonTensorFlow自然语言处理:交叉熵损失函数

import tensorflow as tf # 1. sparse_softmax_cross_entropy_with_logits样例. # 假设词汇表的大小为3, 语料包含两个单词"2 0" word_labels = tf.constant([2, 0]) # 假设模型对两个单词预测时,产生的logit分别是[2.0, -1.0, 3.0]和[1.0, 0.0, -0.5] predict_logits = tf.constant([[2.0, -1.0, 3.0], [1

吴裕雄--天生自然 pythonTensorFlow自然语言处理:PTB 语言模型

import numpy as np import tensorflow as tf # 1.设置参数. TRAIN_DATA = "F:\TensorFlowGoogle\\201806-github\\TensorFlowGoogleCode\\Chapter09\\ptb.train" # 训练数据路径. EVAL_DATA = "F:\TensorFlowGoogle\\201806-github\\TensorFlowGoogleCode\\Chapter09\\p

吴裕雄--天生自然 pythonTensorFlow图形数据处理:读取MNIST手写图片数据写入的TFRecord文件

import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 读取文件. filename_queue = tf.train.string_input_producer(["F:\\output.tfrecords"]) reader = tf.TFRecordReader() _,serialized_example = reader.re

吴裕雄--天生自然 pythonTensorFlow图形数据处理:数据集高层操作

import tempfile import tensorflow as tf # 1. 列举输入文件. # 输入数据生成的训练和测试数据. train_files = tf.train.match_filenames_once("F:\\output.tfrecords") test_files = tf.train.match_filenames_once("F:\\output_test.tfrecords") # 定义解析TFRecord文件的parser方

吴裕雄--天生自然 高等数学学习:区间

区间.领域 自然数集——N 整数集——Z 有理数集——Q 实数集——R 建立数轴后,实数与数轴上的点一一对应. 建立某一实数集A与数轴上某一区间对应. 区间:设有数a,b,a<b,则称实数集{x|a<x<b}为一个开区间.记为:(a,b),即: (a,b)={x|a<x<b} a称为区间(a,b)的左端点. b称为区间(a,b)的右端点.   原文地址:https://www.cnblogs.com/tszr/p/11153079.html

吴裕雄--天生自然Linux操作系统:Linux 磁盘管理

Linux磁盘管理好坏直接关系到整个系统的性能问题. Linux磁盘管理常用三个命令为df.du和fdisk. df:列出文件系统的整体磁盘使用量 du:检查磁盘空间使用量 fdisk:用于磁盘分区 df df命令参数功能:检查文件系统的磁盘空间占用情况.可以利用该命令来获取硬盘被占用了多少空间,目前还剩下多少空间等信息. 语法: df [-ahikHTm] [目录或文件名] 选项与参数: -a :列出所有的文件系统,包括系统特有的 /proc 等文件系统: -k :以 KBytes 的容量显示

吴裕雄--天生自然 PHP开发学习:类型比较

<?php if(42 == "42") { echo '1.值相等'; } echo PHP_EOL; // 换行符 if(42 === "42") { echo '2.类型相等'; } else { echo '3.不相等'; } ?> 待完善... 原文地址:https://www.cnblogs.com/tszr/p/10936908.html