当前位置:网站首页>【pytorch】pytorch 自动求导、 Tensor 与 Autograd
【pytorch】pytorch 自动求导、 Tensor 与 Autograd
2022-07-31 16:15:00 【Enzo 想砸电脑】
在神经网络中,一个重要内容就是进行参数学习,而参数学习离不开求导。
现在大部分深度学习架构都有自动求导的功能,torch.autograd包 就是用来自动求导的。
torch.autograd 包为张量上 所有的操作 提供了自动求导功能
这一篇学习并记录一下 自动求导 的要点。
一、计算图
在整个向前计算过程中,PyTorch采用计算图的形式进行组织,该计算图为动态图,且在每次 前向传播时,将重新构建。其他深度学习架构,如TensorFlow、Keras一般为静态图。
- 计算图是一种有向无环图像,用图形方式来表示算子与变量之间的关系,直观高效。
- 图中 圆形表示变量,矩阵表示算子
- 表达式:z=wx+b,可写成两个表示式: y=wx,z=y+b,
- 其中x、w、b为变量,是用户创建的变量,不依赖于其他变量,故又称 为叶子节点。为计算各叶子节点的梯度,需要把对应的张量参数requires_grad属性设置为 True,这样就可自动跟踪其历史记录。(后面会细说)
- y、z 是计算得到的变量,非叶子节点,z为根节点
- mul和add是算子(或操作或函数)
这些变量及算子就构成了一个完整的计算过程 (或前向传播过程)
二、自动求导要点
为实现对Tensor自动求导,需考虑如下事项:
1)创建叶子节点(Leaf Node)的Tensor,使用requires_grad参数指定是否记录对其 的操作,以便之后利用backward()方法进行梯度求解。requires_grad参数的缺省值为 False,如果要对其求导需设置为True,然后与之有依赖关系的节点会自动变为True。
2)可利用requires_grad_()方法修改Tensor的requires_grad属性(比如一开始在训练阶段,requires_grad 值设置为了True,在测试阶段修改为 False)。可以调用.detach()或 with torch.no_grad():,将不再计算张量的梯度,跟踪张量的历史记录。这点在评估模 型、测试模型阶段中常常用到。
3)通过运算创建的Tensor(即非叶子节点),会自动被赋予grad_fn属性。该属性表 示梯度函数。叶子节点的grad_fn为None。
4)最后得到的Tensor(根节点)执行backward()函数,此时自动计算各变量的梯度。
- 每次反向传播结束,叶子结点的梯度会被清空。如果需要多次反向传播的梯度累加,需要指定backward 中的参数retain_graph=True,这样子节点的梯度是累加的。
- 非叶子节点的梯度backward调用后即被清空
5)backward()函数接收参数,该参数应和调用backward()函数的Tensor的维度相同, 或者是可broadcast的维度。如果求导的Tensor为标量(即一个数字),则backward中的参数可省略。
三、标量反向传播的计算
- 假设x、w、b都是标量,则计算结果 z 也是标量 (z=wx+b)
- 对根节点z调用backward()方法,我们无须对 backward()传入参数
* 这里先提一嘴,后面会说到的是: 如果目标张量对一个非标量调用backward(),则需要传入一个 gradient参数,该参数也是张量,而且需要与调用backward()的张量形状相同。
以下是实现自动求导的主要步骤:
import torch
# 输入张量 x
x = torch.Tensor([2])
# 初始化 权重参数w, 偏移量b,并设置 require_grad 属性为 True, 为自动求导
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
# 实现向前传播
y = torch.mul(w, x)
z = torch.add(y, b)
# 分别查看叶子节点 x, w, b 和 非叶子节点 y、z 的require_grad属性
print(x.requires_grad, w.requires_grad, b.requires_grad) # False True True
print(y.requires_grad, z.requires_grad ) # True True
# 查看各节点是否为叶子节点
print(x.is_leaf, w.is_leaf, b.is_leaf, y.is_leaf, z.is_leaf) # True True True False False
# 分别查看 叶子节点 和 非叶子节点 的 grad_fn 属性
print(x.grad_fn, w.grad_fn, b.grad_fn) # None None None
print(y.grad_fn, z.grad_fn) # <MulBackward0 object at 0x7f8ac1303910> <AddBackward0 object at 0x7f8ac1303070>
z.backward() # 梯度不会累加
# z.backward(retain_graph=True) # 如果多次使用backward,需要梯度累加,则需要修改参数retain_graph为True
# 查看叶子节点的梯度,x是叶子节点但它无须求导,故其梯度为None
print(w.grad,b.grad,x.grad) # tensor([2.]) tensor([1.]) None
#非叶子节点的梯度,执行backward之后,会自动清空
print(y.grad,z.grad) # None None
四、非标量反向传播的计算
边栏推荐
- 2022年整理LeetCode最新刷题攻略分享(附中文详细题解)
- Handling write conflicts under multi-master replication (4) - multi-master replication topology
- Dialogue with Zhuang Biaowei: The first lesson of open source
- How does automated testing create business value?
- .NET 20th Anniversary Interview - Zhang Shanyou: How .NET technology empowers and changes the world
- 更新数据表update
- 使用 Postman 工具高效管理和测试 SAP ABAP OData 服务的试读版
- Vb how to connect mysql_vb how to connect to the database collection "advice"
- Kubernetes常用命令
- Graham‘s Scan法求解凸包问题
猜你喜欢
上传图片-微信小程序(那些年的坑记录2022.4)
对话庄表伟:开源第一课
研发过程中的文档管理与工具
复杂高维医学数据挖掘与疾病风险分类研究
二分查找的细节坑
.NET 20th Anniversary Interview - Zhang Shanyou: How .NET technology empowers and changes the world
mysql black window ~ build database and build table
Kubernetes common commands
Use of radiobutton
How C programs run 01 - the composition of ordinary executable files
随机推荐
js的toString方法
复制延迟案例(3)-单调读
百度网盘网页版加速播放(有可用的网站吗)
LeetCode_733_Image rendering
Character pointer assignment [easy to understand]
BGP综合实验(建立对等体、路由反射器、联邦、路由宣告及聚合)
2020 WeChat applet decompilation tutorial (can applet decompile source code be used)
Kubernetes principle analysis and practical application manual, too complete
Emmet syntax
Use of radiobutton
【7.29】代码源 - 【排列】【石子游戏 II】【Cow and Snacks】【最小生成数】【数列】
7. Summary of common interview questions
第二届中国PWA开发者日
The new BMW 3 Series is on the market, with safety and comfort
深度学习机器学习理论及应用实战-必备知识点整理分享
[TypeScript] In-depth study of TypeScript type operations
牛客网刷题(二)
ML.NET related resources
Codeforces Round #796 (Div. 2) (A-D)
How to switch remote server in gerrit