PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理中文编码和判断编码示例
Feb 26 Python
python根据日期返回星期几的方法
Jul 06 Python
pytorch cnn 识别手写的字实现自建图片数据
May 20 Python
使用python实现http及ftp服务进行数据传输的方法
Oct 26 Python
python pands实现execl转csv 并修改csv指定列的方法
Dec 12 Python
浅谈python requests 的put, post 请求参数的问题
Jan 02 Python
python顺序执行多个py文件的方法
Jun 29 Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 Python
Django之form组件自动校验数据实现
Jan 14 Python
python词云库wordCloud使用方法详解(解决中文乱码)
Feb 17 Python
Python利用Faiss库实现ANN近邻搜索的方法详解
Aug 03 Python
Python使用openpyxl模块处理Excel文件
Jun 05 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
php面向对象全攻略 (十) final static const关键字的使用
2009/09/30 PHP
解析二进制流接口应用实例 pack、unpack、ord 函数使用方法
2013/06/18 PHP
php广告加载类用法实例
2014/09/23 PHP
必须收藏的php实用代码片段
2016/02/02 PHP
PHP正则表达式入门教程(推荐)
2016/05/18 PHP
Yii实现的多级联动下拉菜单
2016/07/13 PHP
PHP二进制与字符串之间的相互转换教程
2016/10/14 PHP
curl 出现错误的调试方法(必看)
2017/02/13 PHP
PHP+MYSQL实现读写分离简单实战
2017/03/13 PHP
php取出数组单个值的方法
2018/03/12 PHP
轻轻松松学习JavaScript
2007/02/25 Javascript
JQuery.Ajax之错误调试帮助信息介绍
2013/07/04 Javascript
js字母大小写转换实现方法总结
2013/11/13 Javascript
JavaScript匿名函数用法分析
2015/02/13 Javascript
基于jQuery通过jQuery.form.js插件使用ajax提交form表单
2015/08/17 Javascript
利用JS实现一个同Excel表现的智能填充算法
2018/08/13 Javascript
JS实现把一个页面层数据传递到另一个页面的两种方式
2018/08/13 Javascript
node.js调用C++函数的方法示例
2018/09/21 Javascript
微信小程序+云开发实现欢迎登录注册
2019/05/24 Javascript
微信小程序swiper实现文字纵向轮播提示效果
2020/01/21 Javascript
python中requests爬去网页内容出现乱码问题解决方法介绍
2017/10/25 Python
Python中defaultdict与lambda表达式用法实例小结
2018/04/09 Python
Django如何自定义分页
2018/09/25 Python
解决Python中list里的中文输出到html模板里的问题
2018/12/17 Python
python 机器学习之支持向量机非线性回归SVR模型
2019/06/26 Python
Python中的引用和拷贝实例解析
2019/11/14 Python
使用python客户端访问impala的操作方式
2020/03/28 Python
Python就将所有的英文单词首字母变成大写
2021/02/12 Python
加拿大廉价机票预订网站:CheapOair.ca
2018/03/04 全球购物
大一学生职业生涯规划
2014/03/11 职场文书
党员学习党的群众路线思想汇报(5篇)
2014/09/10 职场文书
英文商务邀请函范文
2015/01/31 职场文书
2015年体检中心工作总结
2015/05/27 职场文书
Go语言实现Base64、Base58编码与解码
2021/07/26 Golang
python3中apply函数和lambda函数的使用详解
2022/02/28 Python
SQL Server携程核心系统无感迁移到MySQL实战
2022/06/01 SQL Server