Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之Hello World!
Aug 29 Python
python xlsxwriter库生成图表的应用示例
Mar 16 Python
DES加密解密算法之python实现版(图文并茂)
Dec 06 Python
对python的unittest架构公共参数token提取方法详解
Dec 17 Python
python UDP(udp)协议发送和接收的实例
Jul 22 Python
在vscode中配置python环境过程解析
Sep 28 Python
Python+OpenCV实现旋转文本校正方式
Jan 09 Python
python实现猜单词游戏
May 22 Python
使用keras时input_shape的维度表示问题说明
Jun 29 Python
Python日志打印里logging.getLogger源码分析详解
Jan 17 Python
Python3 + Appium + 安卓模拟器实现APP自动化测试并生成测试报告
Jan 27 Python
Python IO文件管理的具体使用
Mar 20 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
PHP 飞信好友免费短信API接口开源版
2010/07/22 PHP
PHP+JS+rsa数据加密传输实现代码
2011/03/23 PHP
谷歌音乐搜索栏的提示功能php修正代码
2011/05/09 PHP
详解php的socket通信
2015/08/11 PHP
Laravel框架实现的记录SQL日志功能示例
2018/06/19 PHP
用JQuery 实现AJAX加载XML并解析的脚本
2009/07/25 Javascript
Jquery实现带动画效果的经典二级导航菜单
2013/03/22 Javascript
Javascript中call的两种用法实例
2013/12/13 Javascript
jQuery 滑动方法slideDown向下滑动元素
2014/01/16 Javascript
JavaScript中使用ActiveXObject操作本地文件夹的方法
2014/03/28 Javascript
javascript实现dom动态创建省市纵向列表菜单的方法
2015/05/14 Javascript
使用jquery动态加载Js文件和Css文件
2015/10/24 Javascript
JavaScript仿支付宝6位数字密码输入框
2016/12/29 Javascript
javascript正则表达式模糊匹配IP地址功能示例
2017/01/06 Javascript
JS中利用localStorage防止页面动态添加数据刷新后数据丢失
2017/03/10 Javascript
新手如何快速理解js异步编程
2019/06/24 Javascript
vue使用codemirror的两种用法
2019/08/27 Javascript
[00:48]完美“圣”典2016风云人物:xiao8宣传片
2016/11/30 DOTA
[01:11:37]完美世界DOTA2联赛PWL S2 SZ vs FTD.C 第一场 11.19
2020/11/19 DOTA
Python简单获取二维数组行列数的方法示例
2018/12/21 Python
Python数据结构之栈、队列及二叉树定义与用法浅析
2018/12/27 Python
简单了解python变量的作用域
2019/07/30 Python
python实现批量处理将图片粘贴到另一张图片上并保存
2019/12/12 Python
python 实现让字典的value 成为列表
2019/12/16 Python
Python使用GitPython操作Git版本库的方法
2020/02/29 Python
Python logging模块进行封装实现原理解析
2020/08/07 Python
挪威户外活动服装和装备购物网站:Bergfreunde挪威
2016/10/20 全球购物
护士节演讲稿开场白
2014/08/25 职场文书
保证金退回承诺函格式
2015/01/21 职场文书
2015年科普工作总结
2015/07/23 职场文书
小学记事作文之200字
2019/08/06 职场文书
《学会生存》读后感3篇
2019/12/09 职场文书
浅谈golang package中init方法的多处定义及运行顺序问题
2021/05/06 Golang
Mysql数据库表中为什么有索引却没有提高查询速度
2022/02/24 MySQL
python字符串的一些常见实用操作
2022/04/06 Python
Win11 22H2 2022怎么更新? 获得Win1122H22022版本升级技巧
2022/09/23 数码科技