PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Socket编程入门教程
Jul 11 Python
详解Python中的Cookie模块使用
Jul 06 Python
Python的Flask框架应用调用Redis队列数据的方法
Jun 06 Python
Python实现将数据库一键导出为Excel表格的实例
Dec 30 Python
ubuntu系统下 python链接mysql数据库的方法
Jan 09 Python
python设置环境变量的原因和方法
Jun 24 Python
Centos7 下安装最新的python3.8
Oct 28 Python
将python安装信息加入注册表的示例
Nov 20 Python
解决Numpy中sum函数求和结果维度的问题
Dec 06 Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 Python
Jupyter notebook如何修改平台字体
May 13 Python
用Python可视化新冠疫情数据
Jan 18 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
德劲1103二次变频版的打磨
2021/03/02 无线电
解析阿里云ubuntu12.04环境下配置Apache+PHP+PHPmyadmin+MYsql
2013/06/26 PHP
PHP7.1方括号数组符号多值复制及指定键值赋值用法分析
2016/09/26 PHP
php下载文件超时时间的设置方法
2016/10/06 PHP
PHP实现上传图片到 zimg 服务器
2016/10/19 PHP
greybox——不开新窗口看新的网页
2007/02/20 Javascript
Jquery 选中表格一列并对表格排序实现原理
2012/12/15 Javascript
jquery DIV撑大让滚动条滚到最底部代码
2013/06/06 Javascript
jquery快捷动态绑定键盘事件的操作函数代码
2013/10/17 Javascript
javascript抽象工厂模式详细说明
2014/12/16 Javascript
分享有关jQuery中animate、slide、fade等动画的连续触发、滞后反复执行的bug
2016/01/10 Javascript
Vue父子模版传值及组件传值的三种方法
2017/11/27 Javascript
webpack3里使用uglifyjs压缩js时打包报错的解决
2018/12/13 Javascript
nodejs读取图片返回给浏览器显示
2019/07/25 NodeJs
Vue CLI项目 axios模块前后端交互的使用(类似ajax提交)
2019/09/01 Javascript
在vue中阻止浏览器后退的实例
2019/11/06 Javascript
JS数组属性去重并校验重复数据
2020/01/10 Javascript
[55:45]DOTA2上海特级锦标赛D组败者赛 Liquid VS COL第一局
2016/02/28 DOTA
wxPython框架类和面板类的使用实例
2014/09/28 Python
Django与JS交互的示例代码
2017/08/23 Python
详解TensorFlow在windows上安装与简单示例
2018/03/05 Python
python实现对变位词的判断方法
2020/04/05 Python
Python collections模块的使用方法
2020/10/09 Python
多个版本的python共存时使用pip的正确做法
2020/10/26 Python
python中实现词云图的示例
2020/12/19 Python
CSS3样式linear-gradient的使用实例
2017/01/16 HTML / CSS
HTML5 source标签:媒介元素定义媒介资源
2018/01/29 HTML / CSS
英国儿童鞋和靴子:Start-Rite
2019/05/06 全球购物
研发工程师的岗位职责
2013/11/18 职场文书
最受欢迎的自我评价
2013/12/22 职场文书
表彰先进集体通报
2014/01/12 职场文书
生产操作工岗位职责
2014/09/16 职场文书
企业三严三实学习心得体会
2014/10/13 职场文书
不尊敬老师的检讨书
2014/12/21 职场文书
利用html+css实现菜单栏缓慢下拉效果的示例代码
2021/03/30 HTML / CSS
用position:sticky完美解决小程序吸顶问题的实现方法
2021/04/24 HTML / CSS