Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python写的Discuz7.2版faq.php注入漏洞工具
Aug 06 Python
Windows下用py2exe将Python程序打包成exe程序的教程
Apr 08 Python
在Mac OS上搭建Python的开发环境
Dec 24 Python
Python+django实现文件下载
Jan 17 Python
Python学习小技巧之列表项的排序
May 20 Python
Python实现随机选择元素功能
Sep 14 Python
浅析Python与Mongodb数据库之间的操作方法
Jul 01 Python
django使用django-apscheduler 实现定时任务的例子
Jul 20 Python
在pycharm中使用matplotlib.pyplot 绘图时报错的解决
Jun 01 Python
Python替换NumPy数组中大于某个值的所有元素实例
Jun 08 Python
Python字符串格式化常用手段及注意事项
Jun 17 Python
Python 中的单分派泛函数你真的了解吗
Jun 22 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
PHP开发文件系统实例讲解
2006/10/09 PHP
php仿discuz分页效果代码
2008/10/02 PHP
ThinkPHP实例化模型的四种方法概述
2014/08/22 PHP
Yii2 rbac权限控制操作步骤实例教程
2016/04/29 PHP
使用PHP下载CSS文件中的所有图片【几行代码即可实现】
2016/12/14 PHP
php+lottery.js实现九宫格抽奖功能
2019/07/21 PHP
javascript 框架小结 个人工作经验
2009/06/13 Javascript
JS 模态对话框和非模态对话框操作技巧汇总
2013/04/15 Javascript
document.forms用法示例介绍
2014/06/26 Javascript
js实现跨域的几种方法汇总(图片ping、JSONP和CORS)
2015/10/25 Javascript
JavaScript动态创建form表单并提交的实现方法
2015/12/10 Javascript
JavaScript实现简洁的俄罗斯方块完整实例
2016/03/01 Javascript
ES6的新特性概览
2016/03/10 Javascript
javascript实现下雪效果【实例代码】
2016/05/03 Javascript
详解React Native 屏幕适配(炒鸡简单的方法)
2018/06/11 Javascript
浅谈angular2子组件的事件传递(任意组件事件传递)
2018/09/30 Javascript
[02:03]完美世界DOTA2联赛10月30日赛事集锦
2020/10/31 DOTA
python中利用Future对象回调别的函数示例代码
2017/09/07 Python
使用python将请求的requests headers参数格式化方法
2019/01/02 Python
详解Python 定时框架 Apscheduler原理及安装过程
2019/06/14 Python
Python 读取用户指令和格式化打印实现解析
2019/09/02 Python
解决Pycharm 中遇到Unresolved reference 'sklearn'的问题
2020/07/13 Python
python实现企业微信定时发送文本消息的实例代码
2020/11/25 Python
Python的信号库Blinker用法详解
2020/12/31 Python
详解CSS3的图层阴影和文字阴影效果使用
2016/06/09 HTML / CSS
HTML5实现Notification API桌面通知功能
2016/03/02 HTML / CSS
香港礼品网站:GiftU eshop
2017/09/01 全球购物
哥伦比亚加拿大官网:Columbia Sportswear Canada
2020/09/07 全球购物
SQL Server的固定数据库角色都有哪些?对应的服务器权限有哪些?
2013/05/18 面试题
自考生自我鉴定范文
2013/10/01 职场文书
《逃家小兔》教学反思
2014/02/23 职场文书
优秀员工演讲稿
2014/05/19 职场文书
小班上学期个人总结
2015/02/12 职场文书
作弊检讨书范文
2015/05/06 职场文书
使用css样式设计一个简单的html登陆界面的实现
2021/03/30 HTML / CSS
JavaScript小技巧带你提升你的代码技能
2021/09/15 Javascript