Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的reduce内建函数使用方法指南
Aug 31 Python
python生成随机mac地址的方法
Mar 16 Python
Python+request+unittest实现接口测试框架集成实例
Mar 16 Python
用Python写一段用户登录的程序代码
Apr 22 Python
Python中关键字global和nonlocal的区别详解
Sep 03 Python
python读取各种文件数据方法解析
Dec 29 Python
python3 中的字符串(单引号、双引号、三引号)以及字符串与数字的运算
Jul 18 Python
python爬虫 正则表达式解析
Sep 28 Python
python根据文本生成词云图代码实例
Nov 15 Python
Python使用jupyter notebook查看ipynb文件过程解析
Jun 02 Python
python利用蒙版抠图(使用PIL.Image和cv2)输出透明背景图
Aug 04 Python
pandas针对excel处理的实现
Jan 15 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
php park、unpark、ord 函数使用方法(二进制流接口应用实例)
2010/10/19 PHP
利用PHP扩展vld查看PHP opcode操作步骤
2013/03/04 PHP
PHP-FPM运行状态的实时查看及监控详解
2016/11/18 PHP
PHP实现阿里大鱼短信验证的实例代码
2017/07/10 PHP
PHP中OpenSSL加密问题整理
2017/12/14 PHP
PHP registerXPathNamespace()函数讲解
2019/02/03 PHP
拖动Html元素集合 Drag and Drop any item
2006/12/22 Javascript
JavaScript获得选中文本内容的方法
2008/12/02 Javascript
javascript 面向对象编程基础:继承
2009/08/21 Javascript
传智播客学习之JavaScript基础篇
2009/11/13 Javascript
js 居中漂浮广告
2010/03/21 Javascript
基于jsTree的无限级树JSON数据的转换代码
2010/07/27 Javascript
与jquery serializeArray()一起使用的函数,主要来方便提交表单
2011/01/31 Javascript
JS cookie中文乱码解决方法
2014/01/28 Javascript
javascript实现博客园页面右下角返回顶部按钮
2015/02/22 Javascript
jQuery进行组件开发完整实例
2015/12/15 Javascript
javascript实现2016新年版日历
2016/01/25 Javascript
javascript实现可键盘控制的抽奖系统
2016/03/10 Javascript
JavaScript导航脚本判断当前导航
2016/07/12 Javascript
BootStrap给table表格的每一行添加一个按钮事件
2017/09/07 Javascript
jquery实现搜索框功能实例详解
2018/07/23 jQuery
vue 移动端适配方案详解
2018/11/15 Javascript
vue进入页面时滚动条始终在底部代码实例
2019/03/26 Javascript
react-native聊天室|RN版聊天App仿微信实例|RN仿微信界面
2019/11/12 Javascript
Vue使用自定义指令实现拖拽行为实例分析
2020/06/06 Javascript
html中创建并调用vue组件的几种方法汇总
2020/11/17 Javascript
VUE+Element实现增删改查的示例源码
2020/11/23 Vue.js
Python交换变量
2008/09/06 Python
numpy 对矩阵中Nan的处理:采用平均值的方法
2018/10/30 Python
Python 使用PyQt5 完成选择文件或目录的对话框方法
2019/06/27 Python
Django-xadmin+rule对象级权限的实现方式
2020/03/30 Python
如何高效率的查找一个月以内的数据
2012/04/15 面试题
关于VPN
2012/06/10 面试题
2016年国培心得体会及反思
2016/01/13 职场文书
使用numpy nonzero 找出非0元素
2021/05/14 Python
Oracle用户管理及赋权
2022/04/24 Oracle