对pytorch中的梯度更新方法详解


Posted in Python onAugust 20, 2019

背景

使用pytorch时,有一个yolov3的bug,我认为涉及到学习率的调整。收集到tencent yolov3和mxnet开源的yolov3,两个优化器中的学习率设置不一样,而且使用GPU数目和batch的更新也不太一样。据此,我简单的了解了下pytorch的权重梯度的更新策略,看看能否一窥究竟。

对代码说明

共三个实验,分布写在代码中的(一)(二)(三)三个地方。运行实验时注释掉其他两个

实验及其结果

实验(三):

不使用zero_grad()时,grad累加在一起,官网是使用accumulate 来表述的,所以不太清楚是取的和还是均值(这两种最有可能)。

不使用zero_grad()时,是直接叠加add的方式累加的。

tensor([[[ 1., 1.],……torch.Size([2, 2, 2])
0 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
tensor([[[ 2., 2.],…… torch.Size([2, 2, 2])
1 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
2 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

实验(二):

单卡上不同的batchsize对梯度是怎么作用的。 mini-batch SGD中的batch是加快训练,同时保持一定的噪声。但设置不同的batchsize的权重的梯度是怎么计算的呢。

设置运行实验(二),可以看到结果如下:所以单卡batchsize计算梯度是取均值的

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

实验(一):

多gpu情况下,梯度怎么合并在一起的。

在《training imagenet in 1 hours》中提到grad是allreduce的,是累加的形式。但是当设置g=2,实验一运行时,结果也是取均值的,类同于实验(二)

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

实验代码

import torch
import torch.nn as nn
from torch.autograd import Variable


class model(nn.Module):
 def __init__(self, w):
  super(model, self).__init__()
  self.w = w

 def forward(self, xx):
  b, c, _, _ = xx.shape
  # extra = xx.device.index + 1 ## 实验(一)
  y = xx.reshape(b, -1).mm(self.w.cuda(xx.device).reshape(-1, 2) * extra)
  return y.reshape(len(xx), -1)


g = 1
x = Variable(torch.ones(2, 1, 2, 2))
# x[1] += 1 ## 实验(二)
w = Variable(torch.ones(2, 2, 2) * 2, requires_grad=True)
# optim = torch.optim.SGD({'params': x},
lr = 0.01
momentum = 0.9
M = model(w)

M = torch.nn.DataParallel(M, device_ids=range(g))

for i in range(3):
 b = len(x)
 z = M(x)
 zz = z.sum(1)
 l = (zz - Variable(torch.ones(b).cuda())).mean()
 # zz.backward(Variable(torch.ones(b).cuda()))
 l.backward()
 print(w.grad, w.grad.shape)
 # w.grad.zero_() ## 实验(三)
 print(i, b, '* * ' * 20)

以上这篇对pytorch中的梯度更新方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 测试实现方法
Dec 24 Python
详解Python中with语句的用法
Apr 15 Python
python构建自定义回调函数详解
Jun 20 Python
Python实现控制台中的进度条功能代码
Dec 22 Python
对pandas中apply函数的用法详解
Apr 10 Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 Python
通过shell+python实现企业微信预警
Mar 07 Python
python实现屏保程序(适用于背单词)
Jul 30 Python
python+OpenCV实现车牌号码识别
Nov 08 Python
Python如何把Spark数据写入ElasticSearch
Apr 18 Python
Selenium浏览器自动化如何上传文件
Apr 06 Python
python turtle绘图
May 04 Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
Pytorch抽取网络层的Feature Map(Vgg)实例
Aug 20 #Python
You might like
php中使用ExcelFileParser处理excel获得数据(可作批量导入到数据库使用)
2010/08/21 PHP
JS 有名函数表达式全面解析
2010/03/19 Javascript
使用js Math.random()函数生成n到m间的随机数字
2014/10/09 Javascript
2014 HTML5/CSS3热门动画特效TOP10
2014/12/07 Javascript
jQuery trigger()方法用法介绍
2015/01/13 Javascript
JS实现的N多简单无缝滚动代码(包含图文效果)
2015/11/06 Javascript
javascript 中的 delete及delete运算符
2015/11/15 Javascript
每天一篇javascript学习小结(RegExp对象)
2015/11/17 Javascript
JQuery Ajax WebService传递参数的简单实例
2016/11/02 Javascript
详解Vue 普通对象数据更新与 file 对象数据更新
2017/04/26 Javascript
Vue.js用法详解
2017/11/13 Javascript
解决vue 界面在苹果手机上滑动点击事件等卡顿问题
2018/11/27 Javascript
mpvue性能优化实战技巧(小结)
2019/04/17 Javascript
JavaScript基础之this和箭头函数详析
2019/09/05 Javascript
Vue实现剪贴板复制功能
2019/12/31 Javascript
vue+element实现图片上传及裁剪功能
2020/06/29 Javascript
vue打开子组件弹窗都刷新功能的实现
2020/09/21 Javascript
[02:00]DAC2018主宣传片——龙征四海,剑问东方
2018/03/20 DOTA
Python isinstance判断对象类型
2008/09/06 Python
python分割文件的常用方法
2014/11/01 Python
Python通过Django实现用户注册和邮箱验证功能代码
2017/12/11 Python
用matplotlib画等高线图详解
2017/12/14 Python
django中使用事务及接入支付宝支付功能
2019/09/15 Python
Python函数的默认参数设计示例详解
2019/12/01 Python
Python爬虫程序架构和运行流程原理解析
2020/03/09 Python
Python flask路由间传递变量实例详解
2020/06/03 Python
PIP和conda 更换国内安装源的方法步骤
2020/09/21 Python
css3 transform 3d 使用css3创建动态3d立方体(html5实践)
2013/01/06 HTML / CSS
英国花园家具中心:Garden Furniture Centre
2017/08/24 全球购物
来自圣地亚哥的实惠太阳镜:Knockaround
2018/08/27 全球购物
武侯祠导游词
2015/02/04 职场文书
办公室年度工作总结2015
2015/05/21 职场文书
PostgreSQL将数据加载到buffer cache中操作方法
2021/04/16 PostgreSQL
虚拟机linux端mysql数据库无法远程访问的解决办法
2021/05/26 MySQL
html中显示特殊符号(附带特殊字符对应表)
2021/06/21 HTML / CSS
python周期任务调度工具Schedule使用详解
2021/11/23 Python