Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3访问并下载网页内容的方法
Jul 28 Python
Python使用time模块实现指定时间触发器示例
May 18 Python
Python实现PS图像调整黑白效果示例
Jan 25 Python
对python中Json与object转化的方法详解
Dec 31 Python
Django框架组成结构、基本概念与文件功能分析
Jul 30 Python
django自带serializers序列化返回指定字段的方法
Aug 21 Python
python多线程扫描端口(线程池)
Sep 04 Python
python动态视频下载器的实现方法
Sep 16 Python
opencv3/C++ 平面对象识别&amp;透视变换方式
Dec 11 Python
解决torch.autograd.backward中的参数问题
Jan 07 Python
Python列表切片常用操作实例解析
Mar 10 Python
python实现三壶谜题的示例详解
Nov 02 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
php 一元分词算法
2009/11/30 PHP
ThinkPHP 404页面的设置方法
2015/01/14 PHP
JQuery 入门实例1
2009/06/25 Javascript
jQuery 判断页面元素是否存在的代码
2009/08/14 Javascript
JavaScript中常见陷阱小结
2010/04/27 Javascript
JavaScript中也使用$美元符号来代替document.getElementById
2010/06/19 Javascript
JQERY limittext 插件0.2版(长内容限制显示)
2010/08/27 Javascript
javascript使用eval或者new Function进行语法检查
2010/10/16 Javascript
基于jquery的从一个页面跳转到另一个页面的指定位置的实现代码(带平滑移动的效果)
2011/05/24 Javascript
JavaScript中Number.MAX_VALUE属性的使用方法
2015/06/04 Javascript
jquery+json实现动态商品内容展示的方法
2016/01/14 Javascript
dul无法加载bootstrap实现unload table/user恢复
2016/09/29 Javascript
jquery 标签 隔若干行加空白或者加虚线的方法
2016/12/07 Javascript
js仿QQ邮箱收件人选择与搜索功能
2017/02/10 Javascript
jQueryMobile之窗体长内容的缺陷与解决方法实例分析
2017/09/20 jQuery
ExtJs整合Echarts的示例代码
2018/02/27 Javascript
vant IndexBar实现的城市列表的示例代码
2019/11/20 Javascript
Vue组件通信$attrs、$listeners实现原理解析
2020/09/03 Javascript
vue中解决chrome浏览器自动播放音频和MP3语音打包到线上的实现方法
2020/10/09 Javascript
Python实现一个简单的MySQL类
2015/01/07 Python
python中as用法实例分析
2015/04/30 Python
Python编程实现粒子群算法(PSO)详解
2017/11/13 Python
PyTorch读取Cifar数据集并显示图片的实例讲解
2018/07/27 Python
python-pyinstaller、打包后获取路径的实例
2019/06/10 Python
python二维码操作:对QRCode和MyQR入门详解
2019/06/24 Python
Python使用pyautocad+openpyxl处理cad文件示例
2019/07/11 Python
Python Django简单实现session登录注销过程详解
2019/08/06 Python
使用Python三角函数公式计算三角形的夹角案例
2020/04/15 Python
html5定位获取当前位置并在百度地图上显示
2014/08/22 HTML / CSS
越南母婴用品购物网站:Kids Plaza
2020/04/09 全球购物
创建卫生先进单位实施方案
2014/03/10 职场文书
英文请假条
2014/04/11 职场文书
信息工作经验交流材料
2014/05/28 职场文书
2014年禁毒工作总结
2014/11/24 职场文书
心灵捕手观后感
2015/06/02 职场文书
php双向队列实例讲解
2021/11/17 PHP