Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之编写类之四再论继承
Oct 11 Python
仅用50行代码实现一个Python编写的计算器的教程
Apr 17 Python
浅谈python中截取字符函数strip,lstrip,rstrip
Jul 17 Python
解读Python编程中的命名空间与作用域
Oct 16 Python
Python字典及字典基本操作方法详解
Jan 30 Python
Python格式化输出字符串方法小结【%与format】
Oct 29 Python
python中数组和矩阵乘法及使用总结(推荐)
May 18 Python
python socket通信编程实现文件上传代码实例
Dec 14 Python
Python操作注册表详细步骤介绍
Feb 05 Python
keras导入weights方式
Jun 12 Python
python实现简单的井字棋
May 26 Python
Python标准库pathlib操作目录和文件
Nov 20 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
逐步提升php框架的性能
2008/01/10 PHP
php树型类实例
2014/12/05 PHP
详解PHP函数 strip_tags 处理字符串缺陷bug
2017/06/11 PHP
PHP连接SQL Server的方法分析【基于thinkPHP5.1框架】
2019/05/06 PHP
PHP7 其他修改
2021/03/09 PHP
静态图片的十一种滤镜效果--不支持Ie7及非IE浏览器。
2007/03/06 Javascript
跨域表单提交状态的变相判断代码
2009/11/12 Javascript
HTML颜色选择器实现代码
2010/11/23 Javascript
Javascript实现重力弹跳拖拽运动效果示例
2013/06/28 Javascript
jquery右下角弹出提示框示例代码
2013/10/08 Javascript
Jquery左右滑动插件之实现超级炫酷动画效果附源码下载
2015/12/02 Javascript
js判断文本框输入的内容是否为数字
2015/12/23 Javascript
javascript中的作用域和闭包详解
2016/01/13 Javascript
分享jQuery网页元素拖拽插件
2020/12/01 Javascript
jquery.uploadifive插件怎么解决上传限制图片或文件大小问题
2017/05/08 jQuery
nodejs+websocket实时聊天系统改进版
2017/05/18 NodeJs
深入理解Vue 的条件渲染和列表渲染
2017/09/01 Javascript
js断点调试经验分享
2017/12/08 Javascript
探秘vue-rx 2.0(推荐)
2018/09/21 Javascript
vue基础之模板和过滤器用法实例分析
2019/03/12 Javascript
微信小程序实现左侧滑栏过程解析
2019/08/26 Javascript
如何实现一个简易版的vuex持久化工具
2019/09/11 Javascript
vue实现跳转接口push 转场动画示例
2019/11/01 Javascript
ant-design表单处理和常用方法及自定义验证操作
2020/10/27 Javascript
JavaScript实现网页动态生成表格
2020/11/25 Javascript
ES6字符串的扩展实例
2020/12/21 Javascript
JavaScript实现滚动加载更多
2020/12/27 Javascript
深入源码解析Python中的对象与类型
2015/12/11 Python
Python 实现一个颜色色值转换的小工具
2016/12/06 Python
python3+PyQt5实现自定义分数滑块部件
2018/04/24 Python
pyx文件 生成pyd 文件用于 cython调用的实现
2021/03/04 Python
NICKIS.com荷兰:设计师儿童时装
2020/01/08 全球购物
药店主任岗位责任制
2014/02/10 职场文书
个人委托书范本
2014/09/13 职场文书
企业党支部工作总结2015
2015/05/21 职场文书
MySQL 原理与优化之原数据锁的应用
2022/08/14 MySQL