『TensorFlow』A Neural Network Class Pattern, Using a GAN as the Example

1. Importing packages:

import os
import time
import math
from glob import glob
from PIL import Image
import tensorflow as tf
import numpy as np

import ops                    # wrappers around layer functions
import utils                  # other helper functions

2. A small temporary helper function:

def conv_out_size_same(size, stride):
    # round a float up (the smallest integer not less than the value)
    return int(math.ceil(float(size) / float(stride)))
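
As a quick check of this helper, the sketch below applies it four times with stride 2, which is the same chain the generator uses further down to derive its intermediate feature-map sizes. The function `halving_chain` is a hypothetical helper introduced only for this illustration.

```python
import math

def conv_out_size_same(size, stride):
    # ceiling division: spatial size after a SAME-padded strided convolution
    return int(math.ceil(float(size) / float(stride)))

def halving_chain(size, steps=4, stride=2):
    # hypothetical helper: apply conv_out_size_same repeatedly
    sizes = [size]
    for _ in range(steps):
        sizes.append(conv_out_size_same(sizes[-1], stride))
    return sizes

print(halving_chain(64))   # [64, 32, 16, 8, 4]
print(halving_chain(108))  # [108, 54, 27, 14, 7]
```

Note the odd case: 27 / 2 = 13.5 is rounded up to 14, which is why a ceiling is used instead of integer division.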

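3. Declaring and initializing the class:

The example does not use any, but in practice class attributes are often used as well.

Class attributes & __init__: __init__ receives the arguments and builds the low-level attribute values; reading the data, or the list of data file names, is usually also done in __init__.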

class DCGAN():

    def __init__(self, sess,
                 input_height=108, input_width=108,
                 crop=True, batch_size=64, sample_num=64,
                 output_height=64, output_width=64,
                 z_dim=100, gf_dim=64,
                 df_dim=64, gfc_dim=1024,
                 dfc_dim=1024, c_dim=3,
                 dataset_name='default', input_fname_pattern='*.jpg',
                 checkpoint_dir=None, sample_dir=None):
        """
        Args:
            sess: TensorFlow session
            batch_size: The size of batch. Should be specified before training.
            z_dim: (optional) Dimension of dim for Z. [100]
            gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
            df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
            gfc_dim: (optional) Dimension of gen units for fully connected layer. [1024]
            dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
            c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
        """
        self.sess = sess
        self.batch_size = batch_size
        self.sample_num = sample_num

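        # crop controls the input/output sizes
        # if crop is True, the output size is used as the network input size
        # if crop is False, the input size goes straight into the network input layer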
        self.crop = crop
        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width

        self.z_dim = z_dim

        self.gf_dim = gf_dim
        self.df_dim = df_dim

        self.dfc_dim = dfc_dim
        self.gfc_dim = gfc_dim

        self.g_bn0 = ops.batch_norm(name='g_bn0')
        self.g_bn1 = ops.batch_norm(name='g_bn1')
        self.g_bn2 = ops.batch_norm(name='g_bn2')
        self.g_bn3 = ops.batch_norm(name='g_bn3')

        self.d_bn1 = ops.batch_norm(name='d_bn1')
        self.d_bn2 = ops.batch_norm(name='d_bn2')
        self.d_bn3 = ops.batch_norm(name='d_bn3')

        '''Load the data'''
        self.dataset_name = dataset_name
        self.input_fname_pattern = input_fname_pattern
        self.checkpoint_dir = checkpoint_dir

        self.data = glob(os.path.join('./data', self.dataset_name, self.input_fname_pattern))  # list all image file paths

        '''Read one image to determine the number of channels'''
        imreadImg = np.asarray(Image.open(self.data[0]))
        if len(imreadImg.shape) >= 3:
            self.c_dim = imreadImg.shape[-1]
        else:
            self.c_dim = 1

        self.grayscale = (self.c_dim == 1)
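
The channel check above relies on the array shape: `np.asarray(Image.open(...))` gives `(H, W, C)` for a color image and `(H, W)` for a single-channel one. A minimal sketch of the same rule, working on plain shape tuples so it runs without any image files; `infer_channels` is a hypothetical name, not part of the original class.

```python
def infer_channels(shape):
    # hypothetical helper mirroring the check in __init__:
    # 3 or more dimensions -> last axis is the channel count, otherwise grayscale
    return shape[-1] if len(shape) >= 3 else 1

print(infer_channels((108, 108, 3)))  # 3  (RGB)
print(infer_channels((108, 108, 4)))  # 4  (RGBA)
print(infer_channels((28, 28)))       # 1  (grayscale)
```

Because `c_dim` is always overwritten this way, the `c_dim=3` constructor argument has no effect in the code as shown.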

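4. Building the network structure:

Because of the GAN's particular structure, this step is split into build_model(self) as the backbone, with discriminator(self,image,reuse=False) and generator(self,z) as modules. Together they cover the whole flow from data entering the network to computing the loss functions.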

    def build_model(self):

        if self.crop:
            image_dims = [self.output_height, self.output_width, self.c_dim]
        else:
            image_dims = [self.input_height, self.input_width, self.c_dim]

        '''Data input layer'''
        # list.extend() returns None, so build the shape by concatenation instead
        self.input_layer = tf.placeholder(tf.float32, [self.batch_size] + image_dims, name='input_layer')
        inputs = self.input_layer

        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
        self.z_sum = tf.summary.histogram('z', self.z)

        '''Main computation nodes'''
        # generation
        self.G                  = self.generator(self.z)
        self.D, self.D_logits   = self.discriminator(inputs, reuse=False)
        self.sampler            = self.sampler(self.z)
        self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)

        # summaries
        self.G_sum = tf.summary.image('G', self.G)
        self.D_sum = tf.summary.histogram('D', self.D)
        self.D__sum = tf.summary.histogram('D_', self.D_)

        '''Loss functions'''
        # construction (TF 1.x requires logits/labels to be passed by keyword)
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits,labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_,labels=tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_,labels=tf.ones_like(self.D_)))
        self.d_loss = self.d_loss_real + self.d_loss_fake

        # summaries (the module is tf.summary, lowercase)
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real",self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake",self.d_loss_fake)
        self.g_loss_sum = tf.summary.scalar("g_loss",self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss",self.d_loss)

        # split the trainable variables between the two networks
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        # saver object
        self.saver = tf.train.Saver()

    def discriminator(self,image,reuse=False):
        with tf.variable_scope('discriminator', reuse=reuse):
            h0 = ops.lrelu(ops.conv2d(image,self.df_dim,name='d_h0_conv'))
            h1 = ops.lrelu(self.d_bn1(ops.conv2d(h0,self.df_dim * 2,name='d_h1_conv')))
            h2 = ops.lrelu(self.d_bn2(ops.conv2d(h1,self.df_dim * 4,name='d_h2_conv')))
            h3 = ops.lrelu(self.d_bn3(ops.conv2d(h2,self.df_dim * 8,name='d_h3_conv')))
            h4 = ops.linear(tf.reshape(h3,[self.batch_size,-1]),1,'d_h4_lin')

        return tf.nn.sigmoid(h4),h4

    def generator(self,z):
        with tf.variable_scope('generator'):
            s_h, s_w = self.output_height, self.output_width                        # size of the generated image
            s_h2,s_w2 = conv_out_size_same(s_h,2),conv_out_size_same(s_w,2)
            s_h4,s_w4 = conv_out_size_same(s_h2,2),conv_out_size_same(s_w2,2)
            s_h8,s_w8 = conv_out_size_same(s_h4,2),conv_out_size_same(s_w4,2)
            s_h16,s_w16 = conv_out_size_same(s_h8,2),conv_out_size_same(s_w8,2)

            # batch_size stays fixed; h and w double at each layer, c halves at each layer
            # (the last layer maps to c_dim)

            # linear layer
            self.z_,self.h0_w,self.h0_b = ops.linear(z,self.gf_dim * 8 * s_h16 * s_w16,'g_h0_lin',with_w=True)
            self.h0 = tf.reshape(self.z_,[-1,s_h16,s_w16,self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(self.h0))

            # transposed convolution layers
            self.h1,self.h1_w,self.h1_b = ops.deconv2d(h0,[self.batch_size,s_h8,s_w8,self.gf_dim * 4],name='g_h1',with_w=True)
            h1 = tf.nn.relu(self.g_bn1(self.h1))

            h2,self.h2_w,self.h2_b = ops.deconv2d(h1,[self.batch_size,s_h4,s_w4,self.gf_dim * 2],name='g_h2',with_w=True)
            h2 = tf.nn.relu(self.g_bn2(h2))

            h3,self.h3_w,self.h3_b = ops.deconv2d(h2,[self.batch_size,s_h2,s_w2,self.gf_dim * 1],name='g_h3',with_w=True)
            h3 = tf.nn.relu(self.g_bn3(h3))

            h4,self.h4_w,self.h4_b = ops.deconv2d(h3,[self.batch_size,s_h,s_w,self.c_dim],name='g_h4',with_w=True)

        return tf.nn.tanh(h4)
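
To make the three losses in build_model concrete, here is a minimal pure-Python sketch of what they compute. It uses the numerically stable form documented for `tf.nn.sigmoid_cross_entropy_with_logits`, `max(x, 0) - x*z + log(1 + exp(-|x|))`. The logit values are made up for illustration, and `sigmoid_xent` and `mean` are hypothetical helper names.

```python
import math

def sigmoid_xent(logit, label):
    # stable sigmoid cross-entropy: max(x, 0) - x*z + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def mean(xs):
    return sum(xs) / len(xs)

# made-up discriminator outputs, for illustration only
real_logits = [2.0, 0.5]    # plays the role of D_logits  (real images)
fake_logits = [-1.0, 0.0]   # plays the role of D_logits_ (generated images)

d_loss_real = mean([sigmoid_xent(x, 1.0) for x in real_logits])  # real should score 1
d_loss_fake = mean([sigmoid_xent(x, 0.0) for x in fake_logits])  # fake should score 0
g_loss = mean([sigmoid_xent(x, 1.0) for x in fake_logits])       # generator wants fake to score 1
d_loss = d_loss_real + d_loss_fake

print(d_loss, g_loss)
```

The discriminator and generator losses read the same fake logits but with opposite labels, which is what makes the two optimizers pull against each other.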

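5. Prediction:

In an ordinary network this is the part that predicts labels; for a GAN it is where the simulated images are generated. It takes no part in training.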

    def sampler(self,z):
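        # exactly the same structure as the generator, sharing its variables; the only difference is that
        # batch normalization runs with train=False, which affects the two parts that rely on moving averages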
        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()

            s_h,s_w = self.output_height,self.output_width
            s_h2,s_w2 = conv_out_size_same(s_h,2),conv_out_size_same(s_w,2)
            s_h4,s_w4 = conv_out_size_same(s_h2,2),conv_out_size_same(s_w2,2)
            s_h8,s_w8 = conv_out_size_same(s_h4,2),conv_out_size_same(s_w4,2)
            s_h16,s_w16 = conv_out_size_same(s_h8,2),conv_out_size_same(s_w8,2)

            h0 = tf.reshape(ops.linear(z,self.gf_dim * 8 * s_h16 * s_w16,'g_h0_lin'), [-1,s_h16,s_w16,self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(h0,train=False))

            h1 = ops.deconv2d(h0,[self.batch_size,s_h8,s_w8,self.gf_dim * 4],name='g_h1')
            h1 = tf.nn.relu(self.g_bn1(h1,train=False))

            h2 = ops.deconv2d(h1,[self.batch_size,s_h4,s_w4,self.gf_dim * 2],name='g_h2')
            h2 = tf.nn.relu(self.g_bn2(h2,train=False))

            h3 = ops.deconv2d(h2,[self.batch_size,s_h2,s_w2,self.gf_dim * 1],name='g_h3')
            h3 = tf.nn.relu(self.g_bn3(h3,train=False))

            h4 = ops.deconv2d(h3,[self.batch_size,s_h,s_w,self.c_dim],name='g_h4')

        # the listing stopped without returning anything; return the image like generator() does
        return tf.nn.tanh(h4)
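
The difference between train=True and train=False can be sketched as follows. `ops.batch_norm` is not shown in this post, so `TinyBatchNorm` below is a hypothetical one-dimensional stand-in written only to show the usual behavior: batch statistics plus a moving-average update during training, stored moving averages during sampling. The decay and epsilon values are assumptions.

```python
class TinyBatchNorm:
    """Hypothetical 1-D batch norm for illustration; not the ops.batch_norm used above."""

    def __init__(self, decay=0.9, eps=1e-5):
        self.decay = decay
        self.eps = eps
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, xs, train=True):
        if train:
            # training: normalize with the statistics of the current batch
            mean = sum(xs) / len(xs)
            var = sum((x - mean) ** 2 for x in xs) / len(xs)
            # and fold them into the moving averages for later use
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # sampling: reuse the stored moving averages, without updating them
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in xs]

bn = TinyBatchNorm()
batch = [1.0, 2.0, 3.0]
train_out = bn(batch, train=True)    # centered on the batch mean
sample_out = bn(batch, train=False)  # centered on the moving mean instead
print(train_out)
print(sample_out)
```

The two calls return different values for the same input, which is why the sampler needs its own graph branch even though it shares every variable with the generator.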

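6. Training:

The most tedious part:

  • Build the optimizers
  • Load the result of the previous training run
  • Train iteratively
    • Read batch_size samples
    • Feed them into the network for training
    • Print intermediate values to help monitoring
    • Save the model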
    def train(self,config):
        # discriminator optimizer (total loss)
        d_optim = tf.train.AdamOptimizer(config.learning_rate,beta1=config.beta1) \
            .minimize(self.d_loss,var_list=self.d_vars)
        # generator optimizer
        g_optim = tf.train.AdamOptimizer(config.learning_rate,beta1=config.beta1) \
            .minimize(self.g_loss,var_list=self.g_vars)

        tf.global_variables_initializer().run()

        # record how each value changes across iterations
        self.g_sum = tf.summary.merge([self.z_sum,self.D__sum, self.G_sum,self.d_loss_fake_sum,self.g_loss_sum])
        # the histogram of D is stored as self.D_sum (self.d_sum is the name being defined here)
        self.d_sum = tf.summary.merge([self.z_sum,self.D_sum,self.d_loss_real_sum,self.d_loss_sum])

        self.writer = tf.summary.FileWriter("./logs",self.sess.graph)

        # read sample_num images
        sample_files = self.data[0:self.sample_num]
        sample = [utils.get_image(sample_file,
                      input_height=self.input_height,
                      input_width=self.input_width,
                      resize_height=self.output_height,
                      resize_width=self.output_width,
                      crop=self.crop) for sample_file in sample_files]
        sample_inputs = np.array(sample).astype(np.float32)
        sample_z = np.random.uniform(-1,1,size=(self.sample_num,self.z_dim))

        counter = 1
        start_time = time.time()
        could_load,checkpoint_counter = self.load(self.checkpoint_dir)

        # load the model and resume training
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in range(config.epoch):
            self.data = glob(os.path.join(
                "./data",config.dataset,self.input_fname_pattern))
            batch_idxs = min(len(self.data),config.train_size) // config.batch_size
            for idx in range(0,batch_idxs):

                # read a batch of images x
                batch_files = self.data[idx * config.batch_size:(idx + 1) * config.batch_size]
                batch = [
                    utils.get_image(batch_file,
                              input_height=self.input_height,
                              input_width=self.input_width,
                              resize_height=self.output_height,
                              resize_width=self.output_width,
                              crop=self.crop) for batch_file in batch_files]
                batch_images = np.array(batch).astype(np.float32)

                # generate noise z
                batch_z = np.random.uniform(-1,1,[config.batch_size,self.z_dim]) \
                    .astype(np.float32)

                # Update D network
                _,summary_str = self.sess.run([d_optim,self.d_sum],
                                              feed_dict={self.input_layer: batch_images,self.z: batch_z})
                self.writer.add_summary(summary_str,counter)

                # Update G network
                _,summary_str = self.sess.run([g_optim,self.g_sum],
                                              feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str,counter)                # what the writer records is not a log in the usual sense but plain scalar values

                # Update G network
                # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
                _,summary_str = self.sess.run([g_optim,self.g_sum],
                                              feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str,counter)

                # evaluate the loss values
                errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                errD_real = self.d_loss_real.eval({self.input_layer: batch_images})
                errG = self.g_loss.eval({self.z: batch_z})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch,idx,batch_idxs,
                         time.time() - start_time,errD_fake + errD_real,errG))
                if np.mod(counter,100) == 1:
                    try:
                        samples,d_loss,g_loss = self.sess.run(
                            [self.sampler,self.d_loss,self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.input_layer: sample_inputs,
                            },
                        )
                        utils.save_images(samples,utils.image_manifold_size(samples.shape[0]),
                                    './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir,epoch,idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss,g_loss))
                    except:
                        print("one pic error!...")
                if np.mod(counter,500) == 2:
                    self.save(config.checkpoint_dir,counter)
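
The batch bookkeeping in the loop above (`batch_idxs` and the slice `idx * batch_size:(idx + 1) * batch_size`) can be tried in isolation. The sketch below is a stand-alone illustration with made-up file names; `iter_batches` is a hypothetical helper, not part of the class.

```python
def iter_batches(files, batch_size, train_size=float("inf")):
    # number of full batches; files left over at the end are dropped
    n_batches = int(min(len(files), train_size)) // batch_size
    for idx in range(n_batches):
        yield idx, files[idx * batch_size:(idx + 1) * batch_size]

files = ["img_%03d.jpg" % i for i in range(10)]   # made-up file names
batches = list(iter_batches(files, batch_size=4))

print(len(batches))       # 2
print(batches[1][1][0])   # img_004.jpg
```

With 10 files and a batch size of 4, only 8 files are used per epoch, so the placeholder always receives exactly batch_size images.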

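A demo of saving & loading the model

Personally I find the functionality a bit bloated, but it is still well worth borrowing from.

For example, using a decorator to disguise a method as an attribute strikes me as unnecessary, since everything here is called internally anyway...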
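
For readers unfamiliar with the decorator in question, here is a minimal sketch of `@property` on a hypothetical `RunConfig` class (the class name and its constructor are assumptions made for this example). The method is read like an attribute, with no parentheses.

```python
class RunConfig:
    # hypothetical class, used only to illustrate @property
    def __init__(self, dataset_name, batch_size, output_height, output_width):
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.output_height = output_height
        self.output_width = output_width

    @property
    def model_dir(self):
        # computed on every access from the current attribute values
        return "{}_{}_{}_{}".format(
            self.dataset_name, self.batch_size,
            self.output_height, self.output_width)

cfg = RunConfig("celebA", 64, 64, 64)
print(cfg.model_dir)   # celebA_64_64_64
```

One practical benefit is that the directory name always reflects the current attribute values, at the cost of hiding a function call behind attribute syntax.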

The standard idiom for checking a directory, on the other hand, is quite nice:

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
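
On Python 3.2 and later the same check can be written in one call with `exist_ok=True`. A minimal sketch, using a temporary directory so it leaves nothing behind in the working tree; `ensure_dir` is a hypothetical helper name.

```python
import os
import tempfile

def ensure_dir(path):
    # hypothetical helper: create the directory (and parents) if missing,
    # and do nothing if it already exists
    os.makedirs(path, exist_ok=True)
    return path

root = tempfile.mkdtemp()
target = os.path.join(root, "checkpoint", "celebA_64_64_64")

ensure_dir(target)
ensure_dir(target)            # second call does not raise
print(os.path.isdir(target))  # True
```

This also avoids the small race between the existence check and the creation when several processes start at once.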

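To run different datasets, the author put some effort into organizing the file names, so the load module is fairly complex; I have therefore added a few more comments.

    '''Model saving & loading'''

    # checkpoint_dir/datasetname_batchsize_outputheight_outputwidth/model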
    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(
            self.dataset_name,self.batch_size,
            self.output_height,self.output_width)

    def save(self,checkpoint_dir,step):
        model_name = "DCGAN.model"
        checkpoint_dir = os.path.join(checkpoint_dir,self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir,model_name),
                        global_step=step)

    def load(self,checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir,self.model_dir)                  # join the model root path with the dataset-specific directory
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)                          # checkpoint directory -> name of the latest model file
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)                  # keep only the file name; model_checkpoint_path may include the directory
            self.saver.restore(self.sess,os.path.join(checkpoint_dir,ckpt_name))      # restore the parameters
            counter = int(next(re.finditer(r"(\d+)",ckpt_name)).group(0))             # extract the training step count
            print(" [*] Success to read {}".format(ckpt_name))
            return True,counter
        else:
            print(" [*] Failed to find a checkpoint")
        return False,0
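
The step-count extraction in load can be checked on its own. The sketch below assumes the checkpoint naming produced by save above, that is `DCGAN.model-<step>`; `step_from_checkpoint` is a hypothetical helper name.

```python
import os
import re

def step_from_checkpoint(path):
    # hypothetical helper: take the first run of digits in the file name
    name = os.path.basename(path)
    match = re.search(r"(\d+)", name)
    return int(match.group(0)) if match else 0

print(step_from_checkpoint("checkpoint/celebA_64_64_64/DCGAN.model-1002"))  # 1002
```

Since the first run of digits is used, a model name that itself contains digits (for example a hypothetical `DCGAN2.model-1002`) would yield the wrong number; matching `-(\d+)$` would be a stricter alternative.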

Appendix: the calling script

import os
import pprint
import numpy as np
import tensorflow as tf

from model import DCGAN

# receiving command-line arguments takes three steps

flags = tf.app.flags

flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True to center-crop the input images, False to use them uncropped [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")

FLAGS = flags.FLAGS

# main must accept an argument, otherwise: 'TypeError: main() takes no arguments (1 given)';
# the parameter name itself is arbitrary
def main(_):
    # pprint module: prints data structures in a more readable way
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    run_config = tf.ConfigProto()
    # by default TensorFlow grabs GPU memory very greedily; switch to allocating on demand
    run_config.gpu_options.allow_growth = True
    # the line below requests a fixed fraction instead
    # run_config.gpu_options.per_process_gpu_memory_fraction=0.333

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir)

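        # __init__ as shown above never calls build_model, so build the graph here
        dcgan.build_model()

        # this block must stay inside the `with` so the session is still open
        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")

if __name__=='__main__':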
    tf.app.run()

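The prediction part is not finished yet, so it is not included here, but that does not get in the way of understanding the approach.

Worth noting is dcgan.train(FLAGS): FLAGS is passed in directly and received inside train as the parameter config, so values are read as config.<parameter name>, which is very convenient. It also helps in appreciating the convenience of script-style TF programs; see 『TensorFlow』脚本化使用方法 (Scripted Usage).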
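
The pattern of passing one object and reading `config.<name>` inside the function does not depend on TensorFlow flags. A minimal sketch using `types.SimpleNamespace` as a stand-in for FLAGS; the `train` function here is a hypothetical stub that only reads attributes, not the DCGAN method.

```python
from types import SimpleNamespace

def train(config):
    # stand-in for DCGAN.train: it only needs attribute access on config
    return "epochs=%d lr=%g batch=%d" % (
        config.epoch, config.learning_rate, config.batch_size)

# any object with these attributes works, e.g. FLAGS or an argparse namespace
config = SimpleNamespace(epoch=25, learning_rate=0.0002, batch_size=64)
print(train(config))   # epochs=25 lr=0.0002 batch=64
```

Because train depends only on attribute names, the same method can be driven from a script, a notebook, or a test without touching the command-line parsing.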

Time: 2024-08-25 01:05:17
