『Pytorch』静动态图构建对比

对比TensorFlow和Pytorch的动静态图构建上的差异

静态图框架设计好了不能够修改,且定义静态图时需要使用新的特殊语法,这也意味着图设定时无法使用if、while、for-loop等结构,而是需要特殊的由框架专门设计的语法,在构建图时,我们需要考虑到所有的情况(即各个if分支图结构必须全部在图中,即使不一定会在每一次运行时使用到),使得静态图异常庞大占用过多显存。

以动态图没有这个顾虑,它兼容python的各种逻辑控制语法,最终创建的图取决于每次运行时的条件分支选择,下面我们对比一下TensorFlow和Pytorch的if条件分支构建图的实现:

# Author : Hellcat
# Time   : 2018/2/9

def tf_graph_if():
    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(3, 4))
    z = tf.placeholder(tf.float32, shape=None)
    w1 = tf.placeholder(tf.float32, shape=(4, 5))
    w2 = tf.placeholder(tf.float32, shape=(4, 5))

    def f1():
        return tf.matmul(x, w1)

    def f2():
        return tf.matmul(x, w2)

    y = tf.cond(tf.less(z, 0), f1, f2)

    with tf.Session() as sess:
        y_out = sess.run(y, feed_dict={
            x: np.random.randn(3, 4),
            z: 10,
            w1: np.random.randn(4, 5),
            w2: np.random.randn(4, 5)})
    return y_out

def t_graph_if():
    import torch as t
    from torch.autograd import Variable

    x = Variable(t.randn(3, 4))
    w1 = Variable(t.randn(4, 5))
    w2 = Variable(t.randn(4, 5))

    z = 10
    if z > 0:
        y = x.mm(w1)
    else:
        y = x.mm(w2)

    return y

if __name__ == "__main__":
    print(tf_graph_if())
    print(t_graph_if())

计算输出如下:

[[ 4.0871315   0.90317607 -4.65211582  0.71610922 -2.70281982]
 [ 3.67874336 -0.58160967 -3.43737102  1.9781189  -2.18779659]
 [ 2.6638422  -0.81783319 -0.30386463 -0.61386991 -3.80232286]]
Variable containing:
-0.2474  0.1269  0.0830  3.4642  0.2255
 0.7555 -0.8057 -2.8159  3.7416  0.6230
 0.9010 -0.9469 -2.5086 -0.8848  0.2499
[torch.FloatTensor of size 3x5]

个人感觉上面的对比不太完美,如果使用TensorFlow的变量来对比,上面函数应该改写如下,

# Author : Hellcat
# Time   : 2018/2/9

def tf_graph_if():
    import tensorflow as tf

    x = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[3, 4]))
    z = tf.constant(dtype=tf.float32, value=10)
    w1 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))
    w2 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))

    def f1():
        return tf.matmul(x, w1)

    def f2():
        return tf.matmul(x, w2)

    y = tf.cond(tf.less(z, 0), f1, f2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        y_out = sess.run(y)
    return y_out

输出没什么变化,

[[ 1.89582038  1.12734962  0.59730953  0.99833554  0.86517167]
 [ 1.2659111   0.77320379  0.63649696  0.5804953   0.82271856]
 [ 1.92151642  1.64715886  1.19869363  1.31581473  1.5636673 ]]

可以看到,TensorFlow的if条件分支使用函数tf.cond(tf.less(z, 0), f1, f2)来实现,这和Pytorch直接使用if的逻辑很不同,而且,动态图不需要feed,直接运行便可。简单对比,可以看到Pytorch的逻辑更为简洁,让人很感兴趣。

原文地址:https://www.cnblogs.com/hellcat/p/8436955.html

时间: 2024-10-01 20:59:40

『Pytorch』静动态图构建对比的相关文章

『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import torch.nn as nn import torch.nn.functional as F class LeNet(nn.Module): def __init__(self): super(LeNet,self).__init__() self.conv1 = nn.Conv2d(3, 6, 5)

『PyTorch』第十二弹_nn.Module和nn.functional

大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Parameter nn.functional中的函数更像是纯函数,由def function(input)定义. 由于两者性能差异不大,所以具体使用取决于个人喜好.对于激活函数和池化层,由于没有可学习参数,一般使用nn.functional完成,其他的有学习参数的部分则使用类.但是Droupout由于在训

『PyTorch』第五弹_深入理解autograd_下:Variable梯度探究

查看非叶节点梯度的两种方法 在反向传播过程中非叶子节点的导数计算完之后即被清空.若想查看这些变量的梯度,有两种方法: 使用autograd.grad函数 使用hook autograd.grad和hook方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用.推荐使用hook方法,但是在实际使用中应尽量避免修改grad的值. 求z对y的导数 x = V(t.ones(3)) w = V(t.rand(3),requires_grad=True) y = w.mul(x) z

『PyTorch』第十弹_循环神经网络

『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 对于torch中的RNN相关类,有原始和原始Cell之分,其中RNN和RNNCell层的区别在于前者一次能够处理整个序列,而后者一次只处理序列中一个时间点的数据,前者封装更完备更易于使用,后者更具灵活性.实际上RNN层的一种后端实现方式就是调用RNNCell来实现的. 一.nn.RNN import torch as t from torch import nn from torch.autograd import Variab

『PyTorch』第十一弹_torch.optim优化器

一.简化前馈网络LeNet import torch as t class LeNet(t.nn.Module): def __init__(self): super(LeNet, self).__init__() self.features = t.nn.Sequential( t.nn.Conv2d(3, 6, 5), t.nn.ReLU(), t.nn.MaxPool2d(2, 2), t.nn.Conv2d(6, 16, 5), t.nn.ReLU(), t.nn.MaxPool2d(2

『ORACLE』 PLSQL动态游标的使用(11g)

#静态游标指的是程序执行的时候不需要再去解析sql语言,对于sql语句的解析在编译的时候就可以完成的. 动态游标由于含有参数,对于sql语句的解析必须要等到参数确定的时候才能完成. 从这个角度来说,静态游标的效率也比动态游标更高一些. #游标的相关概念: 定义: 游标它是一个服务器端的存储区,这个区域提供给用户使用,在这个区域里 存储的是用户通过一个查询语句得到的结果集,用户通过控制这个游标区域当中 的指针 来提取游标中的数据,然后来进行操作. 实质: 是用户在远程客户端上对服务器内存区域的操作

『PyTorch』第一弹_Linux系统下的安装记录

官网首页(http://pytorch.org/)是有安装教程的,但是点击之后没有反应,原因不明,所以不得不自己寻找一个安装方法. 安装参考如下: http://blog.csdn.net/amds123/article/details/69396953 由于我的机器使用Anaconda2.7内部嵌套了Anaconda3.6,而我更倾向于使用3.6版本(个人感觉使用3.x是大势所趋,且3.x的确比2.7方便不少),而我的cuda版本是8,所以我根据自己的情况记录一下安装流程: # 激活环境 so

『Pytorch』torch基本操作

Tensor基础操作 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x """ -2.4365e-20 -1.4335e-03 -2.4290e+25 -1.0283e-13 -2.8296e-07 -2.0769e+22 -1.3816e-33 -6.4672e-32 1.4497e-32 1.6020e-19 6.2625e+22 4.7428e+30 4.0095e-08 1.1943e-32

『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor

Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arange(0,6) print(a.storage()) b = a.view(2,3) print(b.storage()) print(id(a.storage())==id(b.storage())) a[1] = 10 print(b) 上面代码,我们通过.storage()可以查询到Tensor