PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基于回溯法子集树模板解决选排问题示例
Sep 07 Python
python数据结构之列表和元组的详解
Sep 23 Python
python 爬虫一键爬取 淘宝天猫宝贝页面主图颜色图和详情图的教程
May 22 Python
python3实现字符串的全排列的方法(无重复字符)
Jul 07 Python
python3通过selenium爬虫获取到dj商品的实例代码
Apr 25 Python
python nmap实现端口扫描器教程
May 28 Python
Python configparser模块常用方法解析
May 22 Python
浅谈django不使用restframework自定义接口与使用的区别
Jul 15 Python
Python字符串及文本模式方法详解
Sep 10 Python
用Python将库打包发布到pypi
Apr 13 Python
Python关于OS文件目录处理的实例分享
May 23 Python
Python加密技术之RSA加密解密的实现
Apr 08 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
PHP 单引号与双引号的区别
2009/11/24 PHP
InnerHtml和InnerText的区别分析
2009/03/13 Javascript
JavaScript使用过程中需要注意的地方和一些基本语法
2010/08/26 Javascript
jQuery中after的两种用法实例
2013/07/03 Javascript
JS函数的定义与调用方法推荐
2016/05/12 Javascript
D3.js封装文本实现自动换行和旋转平移等功能
2016/10/14 Javascript
如何学JavaScript?前辈的经验之谈
2016/12/28 Javascript
bootstrap 点击空白处popover弹出框隐藏实例
2018/01/24 Javascript
使用elementUI实现将图片上传到本地的示例
2018/09/04 Javascript
vue自定义全局共用函数详解
2018/09/18 Javascript
vue项目持久化存储数据的实现代码
2018/10/01 Javascript
详解keep-alive + vuex 让缓存的页面灵活起来
2019/04/19 Javascript
基于vue-cli 路由 实现类似tab切换效果(vue 2.0)
2019/05/08 Javascript
vue使用video.js进行视频播放功能
2019/07/18 Javascript
vue源码中的检测方法的实现
2019/09/26 Javascript
基于JS实现简单滑块拼图游戏
2019/10/12 Javascript
微信小程序实现多行文字超出部分省略号显示功能
2019/10/23 Javascript
解决vuex刷新数据消失问题
2020/11/12 Javascript
Python远程桌面协议RDPY安装使用介绍
2015/04/15 Python
利用python发送和接收邮件
2016/09/27 Python
python好玩的项目—色情图片识别代码分享
2017/11/07 Python
基于Python socket的端口扫描程序实例代码
2018/02/09 Python
Python OpenCV处理图像之滤镜和图像运算
2018/07/10 Python
Python二叉树的遍历操作示例【前序遍历,中序遍历,后序遍历,层序遍历】
2018/12/24 Python
python使用sklearn实现决策树的方法示例
2019/09/12 Python
python实现在内存中读写str和二进制数据代码
2020/04/24 Python
详解python中GPU版本的opencv常用方法介绍
2020/07/24 Python
深入了解NumPy 高级索引
2020/07/24 Python
Pyecharts 中Geo函数常用参数的用法说明
2021/02/01 Python
印尼最大的网上书店:Gramedia.com
2018/09/13 全球购物
音乐学院硕士生的自我评价分享
2013/11/01 职场文书
暑期研修感言
2014/02/17 职场文书
工作疏忽、懈怠的检讨书
2014/09/11 职场文书
小学生九一八纪念日83周年演讲稿500字
2014/09/17 职场文书
歌舞青春观后感
2015/06/10 职场文书
创业计划书之物流运送
2019/09/17 职场文书