PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python设计模式之单例模式实例
Apr 26 Python
在Python的web框架中编写创建日志的程序的教程
Apr 30 Python
python图片验证码生成代码
Jul 02 Python
python 数据的清理行为实例详解
Jul 12 Python
python中的随机函数小结
Jan 27 Python
python 将字符串转换成字典dict的各种方式总结
Mar 23 Python
对python读取zip压缩文件里面的csv数据实例详解
Feb 08 Python
PYTHON EVAL的用法及注意事项解析
Sep 06 Python
Python操作SQLite/MySQL/LMDB数据库的方法
Nov 07 Python
Python跑循环时内存泄露的解决方法
Jan 13 Python
Python itertools.product方法代码实例
Mar 27 Python
Python绘画好看的星空图
Mar 17 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
ThinkPHP php 框架学习笔记
2009/10/30 PHP
基于PHP常用字符串的总结(待续)
2013/06/07 PHP
Symfony2 session用法实例分析
2016/02/04 PHP
jquery text,radio,checkbox,select操作实现代码
2009/07/09 Javascript
jQuery 过滤not()与filter()实例代码
2012/05/10 Javascript
javascript解决IE6下hover问题的方法
2015/07/28 Javascript
JavaScript获取浏览器信息的方法
2015/11/20 Javascript
JS实现颜色动态淡化效果
2017/03/06 Javascript
微信小程序 合法域名校验出错详解及解决办法
2017/03/09 Javascript
JavaScript使用ZeroClipboard操作剪切板
2017/05/10 Javascript
详解nodejs http请求相关总结
2019/03/31 NodeJs
搭建一个nodejs脚手架的方法步骤
2019/06/28 NodeJs
layui实现数据表格自定义数据项
2019/10/26 Javascript
JS实现瀑布流效果
2020/03/07 Javascript
vue实现桌面向网页拖动文件的示例代码(可显示图片/音频/视频)
2021/03/01 Vue.js
[46:58]完美世界DOTA2联赛PWL S3 Forest vs LBZS 第一场 12.17
2020/12/19 DOTA
python使用WMI检测windows系统信息、硬盘信息、网卡信息的方法
2015/05/15 Python
Python字符串匹配算法KMP实例
2015/07/18 Python
使用Python如何测试InnoDB与MyISAM的读写性能
2018/09/18 Python
python画图的函数用法以及技巧
2019/06/28 Python
用Cython加速Python到“起飞”(推荐)
2019/08/01 Python
Python Gitlab Api 使用方法
2019/08/28 Python
Python 读取WAV音频文件 画频谱的实例
2020/03/14 Python
tensorflow实现将ckpt转pb文件的方法
2020/04/22 Python
python gui开发——制作抖音无水印视频下载工具(附源码)
2021/02/07 Python
使用数据结构给女朋友写个Html5走迷宫游戏
2019/11/26 HTML / CSS
Joules官网:女士、男士和儿童服装和鞋类
2018/10/23 全球购物
以下的初始化有什么区别
2013/12/16 面试题
SQL Server 2000数据库的文件有哪些,分别进行描述。
2015/11/09 面试题
计算机专业毕业生推荐信
2013/11/25 职场文书
专升本学生毕业自我鉴定
2014/10/04 职场文书
乡镇党的群众路线教育实践活动总结报告
2014/10/30 职场文书
支行行长竞聘报告
2014/11/06 职场文书
首次购房证明
2015/06/19 职场文书
2015年党风廉政建设个人总结
2015/08/18 职场文书
小米11和iphone12哪个值得买?小米11对比iphone12评测
2021/04/21 数码科技