Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python yield 小结和实例
Apr 25 Python
教你如何将 Sublime 3 打造成 Python/Django IDE开发利器
Jul 04 Python
Python读取Excel的方法实例分析
Jul 11 Python
numpy返回array中元素的index方法
Jun 27 Python
利用python-docx模块写批量生日邀请函
Aug 26 Python
Python shelve模块实现解析
Aug 28 Python
Python集成开发工具Pycharm的安装和使用详解
Mar 18 Python
详解用Pytest+Allure生成漂亮的HTML图形化测试报告
Mar 31 Python
PyQt5多线程防卡死和多窗口用法的实现
Sep 15 Python
详解python算法常用技巧与内置库
Oct 17 Python
Python 图片处理库exifread详解
Feb 25 Python
如何解决.cuda()加载用时很长的问题
May 24 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
在windows平台上构建自己的PHP实现方法(仅适用于php5.2)
2013/07/05 PHP
php json相关函数用法示例
2017/03/28 PHP
PHP实践教程之过滤、验证、转义与密码详解
2017/07/24 PHP
PHP切割汉字的常用方法实例总结
2019/04/27 PHP
Js之软键盘实现(js源码)
2007/01/30 Javascript
jQuery trigger()方法用法介绍
2015/01/13 Javascript
Javascript实现多彩雪花从天降散落效果的方法
2015/02/02 Javascript
JavaScript检测上传文件大小的方法
2015/07/22 Javascript
易操作的jQuery表单提示插件
2015/12/01 Javascript
js 定义对象数组(结合)多维数组方法
2016/07/27 Javascript
javascript鼠标滑过显示二级菜单特效
2020/11/18 Javascript
bootstrap-table实现服务器分页的示例 (spring 后台)
2017/09/01 Javascript
JS实现的简单表单验证功能完整实例
2017/10/14 Javascript
代码详解JS操作剪贴板
2018/02/11 Javascript
vue 组件销毁并重置的实现
2020/01/13 Javascript
vue项目实现减少app.js和vender.js的体积操作
2020/11/12 Javascript
Python中DJANGO简单测试实例
2015/05/11 Python
使用Python神器对付12306变态验证码
2016/01/05 Python
python妙用之编码的转换详解
2017/04/21 Python
人机交互程序 python实现人机对话
2017/11/14 Python
python实现集中式的病毒扫描功能详解
2019/07/09 Python
python实现将字符串中的数字提取出来然后求和
2020/04/02 Python
python实现邮件循环自动发件功能
2020/09/11 Python
python 生成器需注意的小问题
2020/09/29 Python
Python try except finally资源回收的实现
2021/01/25 Python
如何用border-image实现文字气泡边框的示例代码
2020/01/21 HTML / CSS
土耳其时尚潮流在线购物网站:Trendyol
2017/10/10 全球购物
机电工程专业应届生求职信
2013/10/03 职场文书
个人实用的自我评价范文
2013/11/23 职场文书
工厂保安员岗位职责
2014/01/31 职场文书
服装创业计划书范文
2014/02/05 职场文书
红领巾广播站广播稿(3篇)
2014/09/20 职场文书
三严三实对照检查材料
2014/09/22 职场文书
员工自我评价范文
2015/03/11 职场文书
三八红旗手主要事迹材料
2015/11/04 职场文书
Windows Server 2019 配置远程控制以及管理方法
2022/04/28 Servers