Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python根据出生日期获得年龄的方法
Mar 31 Python
教大家玩转Python字符串处理的七种技巧
Mar 31 Python
python 3利用Dlib 19.7实现摄像头人脸检测特征点标定
Feb 26 Python
完美解决Python matplotlib绘图时汉字显示不正常的问题
Jan 29 Python
Python 字符串处理特殊空格\xc2\xa0\t\n Non-breaking space
Feb 23 Python
浅析python表达式4+0.5值的数据类型
Feb 26 Python
python str字符串转uuid实例
Mar 03 Python
Django+RestFramework API接口及接口文档并返回json数据操作
Jul 12 Python
python中threading和queue库实现多线程编程
Feb 06 Python
python 批量将中文名转换为拼音
Feb 07 Python
Python实现GIF动图以及视频卡通化详解
Dec 06 Python
Matplotlib绘制条形图的方法你知道吗
Mar 21 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
PHP中对各种加密算法、Hash算法的速度测试对比代码
2014/07/08 PHP
配置php.ini实现PHP文件上传功能
2014/11/27 PHP
php面向对象中static静态属性和静态方法的调用
2015/02/08 PHP
CodeIgniter基于Email类发邮件的方法
2016/03/29 PHP
php面向对象基础详解【星际争霸游戏案例】
2020/01/23 PHP
动态表格Table类的实现
2009/08/26 Javascript
javascript 定义初始化数组函数
2009/09/07 Javascript
js动态添加onload、onresize、onscroll事件(另类方法)
2012/12/26 Javascript
Js判断CSS文件加载完毕的具体实现
2014/01/17 Javascript
js写出遮罩层登陆框和对联广告并自动跟随滚动条滚动
2014/04/29 Javascript
jQuery学习笔记之toArray()
2014/06/09 Javascript
jQuery实现企业网站横幅焦点图切换功能实例
2015/04/30 Javascript
JavaScript中的getMilliseconds()方法使用详解
2015/06/10 Javascript
AngularJS之依赖注入模拟实现
2016/08/19 Javascript
真正好用的js验证上传文件大小的简单方法
2016/10/27 Javascript
用file标签实现多图文件上传预览
2017/02/14 Javascript
javascript数组去重常用方法实例分析
2017/04/11 Javascript
Angular 4依赖注入学习教程之简介(一)
2017/06/04 Javascript
JS实现动态添加外部js、css到head标签的方法
2019/06/05 Javascript
在NodeJs中使用node-schedule增加定时器任务的方法
2020/06/08 NodeJs
python学习之第三方包安装方法(两种方法)
2015/07/30 Python
解决python2.7用pip安装包时出现错误的问题
2017/01/23 Python
关于Python如何避免循环导入问题详解
2017/09/14 Python
python 剪切移动文件的实现代码
2018/08/02 Python
通过python的matplotlib包将Tensorflow数据进行可视化的方法
2019/01/09 Python
结合OpenCV与TensorFlow进行人脸识别的实现
2019/10/10 Python
利用keras加载训练好的.H5文件,并实现预测图片
2020/01/24 Python
时尚的CSS3进度条效果
2012/02/22 HTML / CSS
巴西化妆品商店:Lojas Rede
2019/07/26 全球购物
俄罗斯电子产品、计算机和家用电器购物网站:OLDI
2019/10/27 全球购物
Muziker英国:中欧最大的音乐家商店
2020/02/05 全球购物
校企合作协议书
2014/04/16 职场文书
计算机软件专业求职信
2014/06/10 职场文书
博士生导师推荐信
2014/07/08 职场文书
运动会广播稿200字
2015/08/19 职场文书
html粘性页脚的具体使用
2022/01/18 HTML / CSS