PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python时间戳与时间字符串互相转换实例代码
Nov 28 Python
Python中常用操作字符串的函数与方法总结
Feb 04 Python
wxpython中自定义事件的实现与使用方法分析
Jul 21 Python
TensorFlow入门使用 tf.train.Saver()保存模型
Apr 24 Python
Python读取txt某几列绘图的方法
Oct 14 Python
对Python强大的可变参数传递机制详解
Jun 13 Python
Django在admin后台集成TinyMCE富文本编辑器的例子
Aug 09 Python
浅谈PyTorch的可重复性问题(如何使实验结果可复现)
Feb 20 Python
Pycharm中配置远程Docker运行环境的教程图解
Jun 11 Python
python ssh 执行shell命令的示例
Sep 29 Python
使用pandas或numpy处理数据中的空值(np.isnan()/pd.isnull())
May 14 Python
python实现商品进销存管理系统
May 30 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
精美漂亮的php分页类代码
2013/04/02 PHP
PHP自定义多进制的方法
2016/11/03 PHP
jQuery Ajax之load()方法
2009/10/12 Javascript
jquery中ajax学习笔记3
2011/10/16 Javascript
使用Mootools动态添加Css样式表代码,兼容各浏览器
2011/12/12 Javascript
javascript学习笔记(十九) 节点的操作实现代码
2012/06/20 Javascript
javascript利用控件对windows的操作实现原理与应用
2012/12/23 Javascript
JQuery实现点击div以外的位置隐藏该div窗口
2013/09/13 Javascript
导入extjs、jquery 文件时$使用冲突问题解决方法
2014/01/14 Javascript
fckeditor粘贴Word时弹出窗口取消的方法
2014/10/30 Javascript
javascript中基本类型和引用类型的区别分析
2015/05/12 Javascript
js实现网页抽奖实例
2015/08/05 Javascript
基于javascript实现tab选项卡切换特效调试笔记
2016/03/30 Javascript
基于jquery实现二级联动效果
2017/03/30 jQuery
JavaScript模拟文件拖选框样式v1.0的实例
2017/08/04 Javascript
微信小程序实现天气预报功能
2018/07/18 Javascript
微信小程序实现图片翻转效果的实例代码
2019/09/20 Javascript
JavaScript实现滑动门效果
2020/01/18 Javascript
jQuery 动态粒子效果示例代码
2020/07/07 jQuery
详解JavaScript作用域、作用域链和闭包的用法
2020/09/03 Javascript
[03:57]2016完美“圣”典风云人物:rOtk专访
2016/12/09 DOTA
python新手经常遇到的17个错误分析
2014/07/30 Python
Python 通过pip安装Django详细介绍
2017/04/28 Python
Python程序暂停的正常处理方法
2019/11/07 Python
python matplotlib模块基本图形绘制方法小结【直线,曲线,直方图,饼图等】
2020/04/26 Python
Python sublime安装及配置过程详解
2020/06/29 Python
HTML5离线缓存Manifest是什么
2016/03/09 HTML / CSS
HTML5 body设置全屏背景图片的示例代码
2020/12/08 HTML / CSS
J.Crew官网:美国知名休闲服装品牌
2017/05/19 全球购物
学前班教师的自我鉴定
2013/12/05 职场文书
信息专业大学生自我评价分享
2014/01/17 职场文书
会计自我鉴定
2014/02/04 职场文书
优秀少先队员主要事迹材料
2014/05/28 职场文书
爱心捐书活动总结
2014/07/05 职场文书
村主任当选感言
2015/08/01 职场文书
python munch库的使用解析
2021/05/25 Python