对pytorch中的梯度更新方法详解


Posted in Python onAugust 20, 2019

背景

使用pytorch时,有一个yolov3的bug,我认为涉及到学习率的调整。收集到tencent yolov3和mxnet开源的yolov3,两个优化器中的学习率设置不一样,而且使用GPU数目和batch的更新也不太一样。据此,我简单的了解了下pytorch的权重梯度的更新策略,看看能否一窥究竟。

对代码说明

共三个实验,分布写在代码中的(一)(二)(三)三个地方。运行实验时注释掉其他两个

实验及其结果

实验(三):

不使用zero_grad()时,grad累加在一起,官网是使用accumulate 来表述的,所以不太清楚是取的和还是均值(这两种最有可能)。

不使用zero_grad()时,是直接叠加add的方式累加的。

tensor([[[ 1., 1.],……torch.Size([2, 2, 2])
0 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
tensor([[[ 2., 2.],…… torch.Size([2, 2, 2])
1 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
2 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

实验(二):

单卡上不同的batchsize对梯度是怎么作用的。 mini-batch SGD中的batch是加快训练,同时保持一定的噪声。但设置不同的batchsize的权重的梯度是怎么计算的呢。

设置运行实验(二),可以看到结果如下:所以单卡batchsize计算梯度是取均值的

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

实验(一):

多gpu情况下,梯度怎么合并在一起的。

在《training imagenet in 1 hours》中提到grad是allreduce的,是累加的形式。但是当设置g=2,实验一运行时,结果也是取均值的,类同于实验(二)

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

实验代码

import torch
import torch.nn as nn
from torch.autograd import Variable


class model(nn.Module):
 def __init__(self, w):
  super(model, self).__init__()
  self.w = w

 def forward(self, xx):
  b, c, _, _ = xx.shape
  # extra = xx.device.index + 1 ## 实验(一)
  y = xx.reshape(b, -1).mm(self.w.cuda(xx.device).reshape(-1, 2) * extra)
  return y.reshape(len(xx), -1)


g = 1
x = Variable(torch.ones(2, 1, 2, 2))
# x[1] += 1 ## 实验(二)
w = Variable(torch.ones(2, 2, 2) * 2, requires_grad=True)
# optim = torch.optim.SGD({'params': x},
lr = 0.01
momentum = 0.9
M = model(w)

M = torch.nn.DataParallel(M, device_ids=range(g))

for i in range(3):
 b = len(x)
 z = M(x)
 zz = z.sum(1)
 l = (zz - Variable(torch.ones(b).cuda())).mean()
 # zz.backward(Variable(torch.ones(b).cuda()))
 l.backward()
 print(w.grad, w.grad.shape)
 # w.grad.zero_() ## 实验(三)
 print(i, b, '* * ' * 20)

以上这篇对pytorch中的梯度更新方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现SVN的目录周期性备份实例
Jul 17 Python
Python中的模块导入和读取键盘输入的方法
Oct 16 Python
python中zip()方法应用实例分析
Apr 16 Python
Python中set与frozenset方法和区别详解
May 23 Python
Python编程实现二叉树及七种遍历方法详解
Jun 02 Python
Python实现列表删除重复元素的三种常用方法分析
Nov 24 Python
Python minidom模块用法示例【DOM写入和解析XML】
Mar 25 Python
使用PyInstaller将Pygame库编写的小游戏程序打包为exe文件及出现问题解决方法
Sep 06 Python
python实现五子棋游戏(pygame版)
Jan 19 Python
在django中使用post方法时,需要增加csrftoken的例子
Mar 13 Python
Python利用pip安装tar.gz格式的离线资源包
Sep 14 Python
如何利用python 读取配置文件
Jan 06 Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
Pytorch抽取网络层的Feature Map(Vgg)实例
Aug 20 #Python
You might like
PHPLog php 程序调试追踪工具
2009/09/09 PHP
php学习之 数组声明
2011/06/09 PHP
Erlang的运算符(比较运算符,数值运算符,移位运算符,逻辑运算符)
2012/07/23 PHP
详解PHP匿名函数与注意事项
2016/03/29 PHP
用HTML/JS/PHP方式实现页面延时跳转的简单实例
2016/07/18 PHP
JQuery扩展插件Validate 5添加自定义验证方法
2011/09/05 Javascript
在表单提交前进行验证的几种方式整理
2013/07/31 Javascript
没有document.getElementByName方法
2013/08/19 Javascript
使用js实现雪花飘落效果
2013/08/26 Javascript
用C/C++来实现 Node.js 的模块(二)
2014/09/24 Javascript
jQuery中removeClass()方法用法实例
2015/01/05 Javascript
js中利用cookie实现记住密码功能
2020/08/20 Javascript
javascript实现复选框全选或反选
2017/02/04 Javascript
JavaScript重复元素处理方法分析【统计个数、计算、去重复等】
2017/12/14 Javascript
React Native 图片查看组件的方法
2018/03/01 Javascript
浅谈webpack 自动刷新与解析
2018/04/09 Javascript
基于vue-element组件实现音乐播放器功能
2018/05/06 Javascript
JS实现获取当前所在周的周六、周日示例分析
2019/05/11 Javascript
实用Javascript调试技巧分享(小结)
2019/06/18 Javascript
swiper4实现移动端导航切换
2020/10/16 Javascript
如何搜索查找并解决Django相关的问题
2014/06/30 Python
python结合selenium获取XX省交通违章数据的实现思路及代码
2016/06/26 Python
Python闭包之返回函数的函数用法示例
2018/01/27 Python
Python3.6连接Oracle数据库的方法详解
2018/05/18 Python
python 普通克里金(Kriging)法的实现
2019/12/19 Python
Python sep参数使用方法详解
2020/02/12 Python
python str字符串转uuid实例
2020/03/03 Python
python3 使用traceback定位异常实例
2020/03/09 Python
pycharm下配置pyqt5的教程(anaconda虚拟环境下+tensorflow)
2020/03/25 Python
使用opencv中匹配点对的坐标提取方式
2020/06/04 Python
意大利包包和行李箱销售网站:Bagaglio.it
2021/03/02 全球购物
会计专业应届生求职信
2013/11/24 职场文书
保证书范文大全
2014/04/28 职场文书
施工安全标语
2014/06/07 职场文书
植物园观后感
2015/06/11 职场文书
2015初中团委工作总结
2015/07/28 职场文书