PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现socket端口重定向示例
Feb 10 Python
python3抓取中文网页的方法
Jul 28 Python
微信跳一跳辅助python代码实现
Jan 05 Python
zookeeper python接口实例详解
Jan 18 Python
Python机器学习k-近邻算法(K Nearest Neighbor)实例详解
Jun 25 Python
Flask Web开发入门之文件上传(八)
Aug 17 Python
Python实现繁?转为简体的方法示例
Dec 18 Python
Python3 单行多行万能正则匹配方法
Jan 07 Python
python 日期排序的实例代码
Jul 11 Python
解决Numpy中sum函数求和结果维度的问题
Dec 06 Python
pycharm 的Structure界面设置操作
Feb 05 Python
Python基础之字符串格式化详解
Apr 21 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
15个小时----从修改程序到自己些程序
2006/10/09 PHP
发布一个迷你php+AJAX聊天程序[聊天室]提供下载
2007/07/21 PHP
Laravel5.* 打印出执行的sql语句的方法
2017/07/24 PHP
PHP设计模式(三)建造者模式Builder实例详解【创建型】
2020/05/02 PHP
关于使用runtimeStyle属性问题讨论文章
2007/03/08 Javascript
Javascript 错误处理的几种方法
2009/06/13 Javascript
JS实现关键字搜索时的相关下拉字段效果
2014/08/05 Javascript
深入剖析JavaScript中的函数currying柯里化
2016/04/29 Javascript
详解JavaScript节流函数中的Throttle
2016/07/16 Javascript
Bootstrap基本组件学习笔记之列表组(11)
2016/12/07 Javascript
JavaScript登录记住密码操作(超简单代码)
2017/03/22 Javascript
jQuery树插件zTree使用方法详解
2017/05/02 jQuery
基于js的变量提升和函数提升(详解)
2017/09/17 Javascript
vue.js 微信支付前端代码分享
2018/02/10 Javascript
Vue中jsx不完全应用指南小结
2019/11/01 Javascript
node.js开发辅助工具nodemon安装与配置详解
2020/02/06 Javascript
vue-cli3使用mock数据的方法分析
2020/03/16 Javascript
Vue实现Layui的集成方法步骤
2020/04/10 Javascript
[01:24]2014DOTA2 TI第二日 YYF表示这届谁赢都有可能
2014/07/11 DOTA
使用pygame模块编写贪吃蛇的实例讲解
2018/02/05 Python
使用Python机器学习降低静态日志噪声
2018/09/29 Python
Python开启线程,在函数中开线程的实例
2019/02/22 Python
Python实现word2Vec model过程解析
2019/12/16 Python
pymysql 插入数据 转义处理方式
2020/03/02 Python
简单的HTML5初步入门教程
2015/09/29 HTML / CSS
HTML5 video 上传预览图片视频如何设置、预览视频某秒的海报帧
2018/08/28 HTML / CSS
美国最便宜的旅游网站:CheapTickets
2017/07/09 全球购物
结婚典礼证婚词
2014/01/11 职场文书
网吧七夕活动策划方案
2014/08/31 职场文书
教师自我剖析材料范文
2014/09/30 职场文书
课堂打架检讨书200字
2014/11/21 职场文书
小学生光盘行动倡议书
2015/04/28 职场文书
教学副校长工作总结
2015/08/13 职场文书
2016年村党支部公开承诺书
2016/03/24 职场文书
2016年全国爱眼日宣传教育活动总结
2016/04/05 职场文书
用Python爬取某乎手机APP数据
2021/06/15 Python