PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中if __name__ == '__main__'作用解析
Jun 29 Python
Python读写unicode文件的方法
Jul 10 Python
Python网络编程中urllib2模块的用法总结
Jul 12 Python
Python生成随机密码的方法
Jun 16 Python
python利用lxml读写xml格式的文件
Aug 10 Python
Django分页查询并返回jsons数据(中文乱码解决方法)
Aug 02 Python
在python中使用with打开多个文件的方法
Jan 07 Python
python3 pygame实现接小球游戏
May 14 Python
python tkinter图形界面代码统计工具
Sep 18 Python
如何在Django中使用聚合的实现示例
Mar 23 Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 Python
Django def clean()函数对表单中的数据进行验证操作
Jul 09 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
分享一下贝贝成长进度的php代码
2012/09/14 PHP
YII模块实现绑定二级域名的方法
2014/07/09 PHP
php中preg_replace_callback函数简单用法示例
2016/07/21 PHP
用js实现的检测浏览器和系统的函数
2009/04/09 Javascript
JS backgroundImage控制
2009/05/19 Javascript
自己的js工具 Cookie 封装
2009/08/21 Javascript
JavaScript Event学习第九章 鼠标事件
2010/02/08 Javascript
类似GMAIL的Ajax信息反馈显示
2010/02/16 Javascript
非阻塞动态加载javascript广告实现代码
2010/11/17 Javascript
初学js插入节点appendChild insertBefore使用方法
2011/07/04 Javascript
jquery 简单应用示例总结
2013/08/09 Javascript
extjs表格文本启用选择复制功能具体实现
2013/10/11 Javascript
javascript制作网页图片上实现下雨效果
2015/02/26 Javascript
jQuery检测滚动条是否到达底部
2015/12/15 Javascript
jquery+css3实现会动的小圆圈效果
2016/01/27 Javascript
AngularJS使用ng-options指令实现下拉框
2016/08/23 Javascript
js实现拖拽功能
2017/03/01 Javascript
Javascript中类式继承和原型式继承的实现方法和区别之处
2017/04/25 Javascript
深入理解NodeJS 多进程和集群
2018/10/17 NodeJs
微信小程序位置授权处理方法
2019/06/13 Javascript
Python学生成绩管理系统简洁版
2020/04/05 Python
Python实现简易版的Web服务器(推荐)
2018/01/29 Python
Python序列循环移位的3种方法推荐
2018/04/09 Python
python贪婪匹配以及多行匹配的实例讲解
2018/04/19 Python
Python虚拟环境库virtualenvwrapper安装及使用
2020/06/17 Python
Django实现简单的分页功能
2021/02/22 Python
Spartoo西班牙官网:法国时尚购物网站
2018/03/27 全球购物
Watch Station官方网站:世界一流的手表和智能手表
2020/01/05 全球购物
Trench London官方网站:高级风衣和意大利皮夹克
2020/07/11 全球购物
Kingsoft金山公司C/C++笔试题
2016/05/10 面试题
医学生自我评价
2014/01/27 职场文书
迎国庆演讲稿
2014/09/05 职场文书
项目经理岗位职责
2015/01/31 职场文书
浅谈JS的二进制家族
2021/05/09 Javascript
使用CSS定位HTML元素的实现方法
2022/07/07 HTML / CSS
MySQL索引失效场景及解决方案
2022/07/23 MySQL