Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现爬取知乎神回复简单爬虫代码分享
Jan 04 Python
Python文档生成工具pydoc使用介绍
Jun 02 Python
Python使用smtplib模块发送电子邮件的流程详解
Jun 27 Python
Django rest framework实现分页的示例
May 24 Python
Django框架中间件(Middleware)用法实例分析
May 24 Python
树莓派安装OpenCV3完整过程的实现
Oct 10 Python
Python实现RGB与HSI颜色空间的互换方式
Nov 27 Python
python selenium实现发送带附件的邮件代码实例
Dec 10 Python
Python进程Multiprocessing模块原理解析
Feb 28 Python
python如何爬取动态网站
Sep 09 Python
OpenCV+Python3.5 简易手势识别的实现
Dec 21 Python
pandas提升计算效率的一些方法汇总
May 30 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
基于PHP编程注意事项的小结
2013/04/27 PHP
php抽奖小程序的实现代码
2013/06/18 PHP
php实现保存submit内容之后禁止刷新
2014/03/19 PHP
php随机取mysql记录方法小结
2014/12/27 PHP
php+ajax实现仿百度查询下拉内容功能示例
2017/10/20 PHP
Thinkphp5.0框架的Db操作实例分析【连接、增删改查、链式操作等】
2019/10/11 PHP
解决Laravel无法使用COOKIE和SESSION的问题
2019/10/16 PHP
用javascript实现分割提取页面所需内容
2007/05/09 Javascript
jQuery中与toggleClass等价的程序段 以及未来学习的方向
2010/03/18 Javascript
关于火狐(firefox)及ie下event获取的两种方法
2012/12/27 Javascript
jquery实现翻动fadeIn显示的方法
2015/03/05 Javascript
详解jquery easyui之datagrid使用参考
2016/12/05 Javascript
AngularJS页面带参跳转及参数解析操作示例
2017/06/28 Javascript
JavaScript闭包和回调详解
2017/08/09 Javascript
Javascript快速实现浏览器系统通知
2017/08/26 Javascript
解决vue项目使用font-awesome,build后路径的问题
2018/09/01 Javascript
对layer弹出框中icon数字参数的说明介绍
2019/09/04 Javascript
Vue过渡效果之CSS过渡详解(结合transition,animation,animate.css)
2020/02/05 Javascript
状态机的概念和在Python下使用状态机的教程
2015/04/11 Python
Python自定义简单图轴简单实例
2018/01/08 Python
解决PyCharm同目录下导入模块会报错的问题
2018/10/13 Python
Python numpy中矩阵的基本用法汇总
2019/02/12 Python
pyqt5之将textBrowser的内容写入txt文档的方法
2019/06/21 Python
对python中assert、isinstance的用法详解
2019/11/27 Python
详解如何在pyqt中通过OpenCV实现对窗口的透视变换
2020/09/20 Python
3种方式实现瀑布流布局小结
2019/09/05 HTML / CSS
碧欧泉美国官网:Biotherm美国
2016/08/31 全球购物
SQL Server的固定数据库角色都有哪些?对应的服务器权限有哪些?
2013/05/18 面试题
医院护士的求职信范文
2013/12/26 职场文书
校园达人秀策划书
2014/01/12 职场文书
2014年应届大学生毕业自我鉴定
2014/01/31 职场文书
学校督导评估方案
2014/06/10 职场文书
企业法人代表授权委托书
2014/10/02 职场文书
三八妇女节标语
2014/10/09 职场文书
司法局2014法制宣传日活动总结
2014/11/01 职场文书
义诊活动总结
2015/02/04 职场文书