PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用自定义user-agent抓取网页的方法
Apr 15 Python
python将字符串转换成数组的方法
Apr 29 Python
简单谈谈Python中的闭包
Nov 30 Python
详解Python异常处理中的Finally else的功能
Dec 29 Python
tensorflow中next_batch的具体使用
Feb 02 Python
Python使用爬虫爬取静态网页图片的方法详解
Jun 05 Python
对Python2与Python3中__bool__方法的差异详解
Nov 01 Python
使用Python制作一个打字训练小工具
Oct 01 Python
使用python实现对元素的长截图功能
Nov 14 Python
python 字典访问的三种方法小结
Dec 05 Python
Python如何实现机器人聊天
Sep 10 Python
Jupyter Notebook添加代码自动补全功能的实现
Jan 07 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
php中simplexml_load_file函数用法实例
2014/11/12 PHP
php常用表单验证类用法实例
2015/06/18 PHP
微信小程序 消息推送php服务器验证实例详解
2017/03/30 PHP
event.srcElement+表格应用
2006/08/29 Javascript
jQuery each()方法的使用方法
2010/03/18 Javascript
瀑布流布局并自动加载实现代码
2013/03/12 Javascript
js获得鼠标的坐标值的方法
2013/03/13 Javascript
js如何获取file控件的完整路径具体实现代码
2013/05/15 Javascript
JS远程获取网页源代码实例
2013/09/05 Javascript
JS 获取滚动条高度示例代码
2013/10/24 Javascript
JS去除iframe滚动条的方法
2015/04/01 Javascript
jQuery Validation Engine验证控件调用外部函数验证的方法
2017/01/18 Javascript
javascript过滤数组重复元素的实现方法
2017/05/03 Javascript
JS创建Tag标签的方法详解
2017/06/09 Javascript
jQuery+CSS实现的table表格行列转置功能示例
2018/01/08 jQuery
Parcel 打包示例(React HelloWorld)
2018/01/16 Javascript
JS实现带阴历的日历功能详解
2019/01/24 Javascript
JS实现换肤功能的方法实例详解
2019/01/30 Javascript
基于Vue 实现一个中规中矩loading组件
2019/04/03 Javascript
vue获取data数据改变前后的值方法
2019/11/07 Javascript
环形加载进度条封装(Vue插件版和原生js版)
2019/12/04 Javascript
Vue select 绑定动态变量的实例讲解
2020/10/22 Javascript
python在每个字符后添加空格的实例
2018/05/07 Python
浅谈python的dataframe与series的创建方法
2018/11/12 Python
python 生成器和迭代器的原理解析
2019/10/12 Python
python 用 xlwings 库 生成图表的操作方法
2019/12/22 Python
PyCharm 2020.2 安装详细教程
2020/09/25 Python
移动端html5 meta标签的神奇功效
2016/01/06 HTML / CSS
九年级家长会邀请函
2014/01/15 职场文书
旅游个人求职信范文
2014/01/30 职场文书
党的生日演讲稿
2014/09/10 职场文书
离婚起诉书范文2015
2015/05/19 职场文书
导游词之澳门玫瑰圣母堂
2019/12/03 职场文书
python实现语音常用度量方法的代码详解
2021/05/25 Python
解决 redis 无法远程连接
2022/05/15 Redis
MySQL 自动填充 create_time 和 update_time
2022/05/20 MySQL