PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python装饰器在Django框架下去除冗余代码的教程
Apr 16 Python
windows10系统中安装python3.x+scrapy教程
Nov 08 Python
Python 实现中值滤波、均值滤波的方法
Jan 09 Python
Django  ORM 练习题及答案
Jul 19 Python
python3使用GUI统计代码量
Sep 18 Python
Django实现auth模块下的登录注册与注销功能
Oct 10 Python
django框架单表操作之增删改实例分析
Dec 16 Python
python实现从ftp上下载文件的实例方法
Jul 19 Python
浅谈Python爬虫原理与数据抓取
Jul 21 Python
Python matplotlib图例放在外侧保存时显示不完整问题解决
Jul 28 Python
python缩进长度是否统一
Aug 02 Python
详解python的变量缓存机制
Jan 24 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
php遍历目录viewDir函数
2009/12/15 PHP
WAMP环境中扩展oracle函数库(oci)
2015/06/26 PHP
PHP调用接口API封装的例子
2019/10/11 PHP
TP5框架实现签到功能的方法分析
2020/04/05 PHP
jquery.cvtooltip.js 基于jquery的气泡提示插件
2010/11/19 Javascript
超酷的网页音乐播放器DewPlayer使用方法
2010/12/18 Javascript
Jquery插件写法笔记整理
2012/09/06 Javascript
jquery实现excel导出的方法
2013/04/04 Javascript
分享28款免费实用的 JQuery 图片和内容滑块插件
2014/12/15 Javascript
JavaScript AOP编程实例
2015/06/16 Javascript
jQuery实现的图文高亮滚动切换特效实例
2015/08/10 Javascript
jQuery实现页面点击后退弹出提示框的方法
2016/08/24 Javascript
JavaScript面试题大全(推荐)
2016/09/22 Javascript
微信小程序开发之入门实例教程篇
2017/03/07 Javascript
JavaScript实现简单的四则运算计算器完整实例
2017/04/28 Javascript
详解webpack-dev-server使用方法
2018/09/14 Javascript
node.js连接mysql与基本用法示例
2019/01/05 Javascript
使用 js 简单的实现 bind、call 、aplly代码实例
2019/09/07 Javascript
JavaScript Dom 绑定事件操作实例详解
2019/10/02 Javascript
微信小程序加载机制及运行机制图解
2019/11/27 Javascript
详解React 条件渲染
2020/07/08 Javascript
基于JS实现操作成功之后自动跳转页面
2020/09/25 Javascript
python变量不能以数字打头详解
2016/07/06 Python
python3中的md5加密实例
2018/05/29 Python
一文了解Python并发编程的工程实现方法
2019/05/31 Python
Python中@property的理解和使用示例
2019/06/11 Python
python cv2截取不规则区域图片实例
2019/12/21 Python
英国折扣高尔夫商店:Discount Golf Store
2019/11/19 全球购物
现代化办公人员工作的自我评价
2013/10/16 职场文书
小学红领巾中秋节广播稿
2014/01/13 职场文书
商场促销活动方案
2014/02/08 职场文书
培训讲师岗位职责
2014/04/13 职场文书
房地产推广策划方案
2014/05/19 职场文书
2015公务员试用期工作总结
2014/12/12 职场文书
中标通知书范本
2015/04/17 职场文书
公文格式,规则明细(新手收藏)
2019/07/23 职场文书