PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现代码行数统计示例分享
Feb 10 Python
python使用SMTP发送qq或sina邮件
Oct 21 Python
python 实现数组list 添加、修改、删除的方法
Apr 04 Python
CentOS7下python3.7.0安装教程
Jul 30 Python
python ddt数据驱动最简实例代码
Feb 22 Python
基于Python的ModbusTCP客户端实现详解
Jul 13 Python
使用python telnetlib批量备份交换机配置的方法
Jul 25 Python
Python企业编码生成系统之主程序模块设计详解
Jul 26 Python
python tkinter实现彩球碰撞屏保
Jul 30 Python
pytorch梯度剪裁方式
Feb 04 Python
python实现简单飞行棋
Feb 06 Python
一文带你掌握Pyecharts地理数据可视化的方法
Feb 06 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
mysql From_unixtime及UNIX_TIMESTAMP及DATE_FORMAT日期函数
2010/03/21 PHP
ThinkPHP3.1查询语言详解
2014/06/19 PHP
PHP的几个常用加密函数
2016/02/03 PHP
Yii框架实现多数据库配置和操作的方法
2017/05/25 PHP
使用jquery为table动态添加行的实现代码
2011/03/30 Javascript
IE事件对象(The Internet Explorer Event Object)
2012/06/27 Javascript
JavaScript自执行闭包的小例子
2013/06/29 Javascript
js取模(求余数)隔行变色
2014/05/15 Javascript
原生js实现移动端瀑布流式代码示例
2015/12/18 Javascript
JavaScript动态生成二维码图片
2016/04/20 Javascript
JS+Canvas 实现下雨下雪效果
2016/05/18 Javascript
javascript另类方法实现htmlencode()与htmldecode()函数实例分析
2016/11/17 Javascript
微信小程序 定位到当前城市实现实例代码
2017/02/23 Javascript
Js自动截取字符串长度,添加省略号(……)的实现方法
2017/03/06 Javascript
基于JavaScript实现新增内容滚动播放效果附完整代码
2017/08/24 Javascript
利用JavaScript缓存远程窃取Wi-Fi密码的思路详解
2018/11/05 Javascript
微信小程序学习笔记之跳转页面、传递参数获得数据操作图文详解
2019/03/28 Javascript
javascript设计模式 ? 单例模式原理与应用实例分析
2020/04/09 Javascript
Python+django实现文件下载
2016/01/17 Python
python矩阵/字典实现最短路径算法
2019/01/17 Python
python中的协程深入理解
2019/06/10 Python
Python实现Mysql数据统计及numpy统计函数
2019/07/15 Python
python实现按关键字筛选日志文件
2019/12/24 Python
jupyter 使用Pillow包显示图像时inline显示方式
2020/04/24 Python
Python建造者模式案例运行原理解析
2020/06/29 Python
护理学毕业生自荐信
2013/10/02 职场文书
县优秀教师事迹材料
2014/01/31 职场文书
社区服务标语
2014/07/01 职场文书
骨干教师事迹材料
2014/12/17 职场文书
2015年大学生社会实践评语
2015/03/26 职场文书
财务部岗位职责范本
2015/04/14 职场文书
预备党员考察表党小组意见
2015/06/01 职场文书
2016廉洁教育心得体会
2016/01/20 职场文书
2016大学生毕业实习心得体会
2016/01/23 职场文书
护士业务学习心得体会
2016/01/25 职场文书
2019年恭贺升学祝福语集锦
2019/08/15 职场文书