一、封装新的PyTorch函数
继承Function类
forward:输入Variable->中间计算Tensor->输出Variable
backward:均使用Variable
线性映射
from torch.autograd import Function class MultiplyAdd(Function): # <----- 类需要继承Function类 @staticmethod # <-----forward和backward都是静态方法 def forward(ctx, w, x, b): # <-----ctx作为内部参数在前向反向传播中协调 print(‘type in forward‘,type(x)) ctx.save_for_backward(w,x) # <-----ctx保存参数 output = w * x + b return output # <-----forward输入参数和backward输出参数必须一一对应 @staticmethod # <-----forward和backward都是静态方法 def backward(ctx, grad_output): # <-----ctx作为内部参数在前向反向传播中协调 w,x = ctx.saved_variables # <-----ctx读取参数 print(‘type in backward‘,type(x)) grad_w = grad_output * x grad_x = grad_output * w grad_b = grad_output * 1 return grad_w, grad_x, grad_b # <-----backward输入参数和forward输出参数必须一一对应
调用方法一
类名.apply(参数)
输出变量.backward()
import torch as t from torch.autograd import Variable as V x = V(t.ones(1)) w = V(t.rand(1), requires_grad = True) b = V(t.rand(1), requires_grad = True) print(‘开始前向传播‘) z=MultiplyAdd.apply(w, x, b) # <-----forward print(‘开始反向传播‘) z.backward() # 等效 # <-----backward # x不需要求导,中间过程还是会计算它的导数,但随后被清空 print(x.grad, w.grad, b.grad)
开始前向传播 type in forward <class ‘torch.FloatTensor‘> 开始反向传播 type in backward <class ‘torch.autograd.variable.Variable‘>(None, Variable containing: 1 [torch.FloatTensor of size 1], Variable containing: 1 [torch.FloatTensor of size 1])
调用方法二
类名.apply(参数)
输出变量.grad_fn.apply()
x = V(t.ones(1)) w = V(t.rand(1), requires_grad = True) b = V(t.rand(1), requires_grad = True) print(‘开始前向传播‘) z=MultiplyAdd.apply(w,x,b) # <-----forward print(‘开始反向传播‘) # 调用MultiplyAdd.backward # 会自动输出grad_w, grad_x, grad_b z.grad_fn.apply(V(t.ones(1))) # <-----backward,在计算中间输出,buffer并未清空,所以x的梯度不是None
开始前向传播 type in forward <class ‘torch.FloatTensor‘> 开始反向传播 type in backward <class ‘torch.autograd.variable.Variable‘>(Variable containing: 1 [torch.FloatTensor of size 1], Variable containing: 0.7655 [torch.FloatTensor of size 1], Variable containing: 1 [torch.FloatTensor of size 1])
之所以forward函数的输入是tensor,而backward函数的输入是variable,是为了实现高阶求导。backward函数的输入输出虽然是variable,但在实际使用时autograd.Function会将输入variable提取为tensor,并将计算结果的tensor封装成variable返回。在backward函数中,之所以也要对variable进行操作,是为了能够计算梯度的梯度(backward of backward)。下面举例说明,有关torch.autograd.grad的更详细使用请参照文档。
二、高阶导数
grad_x =t.autograd.grad(y, x, create_graph=True)
grad_grad_x = t.autograd.grad(grad_x[0],x)
x = V(t.Tensor([5]), requires_grad=True) y = x ** 2 grad_x = t.autograd.grad(y, x, create_graph=True) print(grad_x) # dy/dx = 2 * x grad_grad_x = t.autograd.grad(grad_x[0],x) print(grad_grad_x) # 二阶导数 d(2x)/dx = 2
(Variable containing: 10 [torch.FloatTensor of size 1],)(Variable containing: 2 [torch.FloatTensor of size 1],)
三、梯度检查
t.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)
此外在实现了自己的Function之后,还可以使用gradcheck
函数来检测实现是否正确。gradcheck
通过数值逼近来计算梯度,可能具有一定的误差,通过控制eps
的大小可以控制容忍的误差。
class Sigmoid(Function): @staticmethod def forward(ctx, x): output = 1 / (1 + t.exp(-x)) ctx.save_for_backward(output) return output @staticmethod def backward(ctx, grad_output): output, = ctx.saved_variables grad_x = output * (1 - output) * grad_output return grad_x # 采用数值逼近方式检验计算梯度的公式对不对 test_input = V(t.randn(3,4), requires_grad=True) t.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3)
True
测试效率,
def f_sigmoid(x): y = Sigmoid.apply(x) y.backward(t.ones(x.size())) def f_naive(x): y = 1/(1 + t.exp(-x)) y.backward(t.ones(x.size())) def f_th(x): y = t.sigmoid(x) y.backward(t.ones(x.size())) x=V(t.randn(100, 100), requires_grad=True) %timeit -n 100 f_sigmoid(x) %timeit -n 100 f_naive(x) %timeit -n 100 f_th(x)
实际测试结果,
245 μs ± 70.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 211 μs ± 23.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 219 μs ± 36.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
书中说的结果,
100 loops, best of 3: 320 μs per loop 100 loops, best of 3: 588 μs per loop 100 loops, best of 3: 271 μs per loop
很奇怪,我的结果竟然是:简单堆砌<官方封装<自己封装……不过还是引用一下书中的结论吧:
显然
f_sigmoid
要比单纯利用autograd
加减和乘方操作实现的函数快不少,因为f_sigmoid的backward优化了反向传播的过程。另外可以看出系统实现的buildin接口(t.sigmoid)更快。
原文地址:https://www.cnblogs.com/hellcat/p/8453615.html