PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python函数式编程指南(三):迭代器详解
Jun 24 Python
利用python获取当前日期前后N天或N月日期的方法示例
Jul 30 Python
python执行使用shell命令方法分享
Nov 08 Python
python操作excel的方法(xlsxwriter包的使用)
Jun 11 Python
pygame实现俄罗斯方块游戏(基础篇2)
Oct 29 Python
Python +Selenium解决图片验证码登录或注册问题(推荐)
Feb 09 Python
使用Python打造一款间谍程序的流程分析
Feb 21 Python
查看jupyter notebook每个单元格运行时间实例
Apr 22 Python
如何安装并在pycharm使用selenium的方法
Apr 30 Python
详解Python中Pyyaml模块的使用
Oct 08 Python
python 可视化库PyG2Plot的使用
Jan 21 Python
python 网络编程要点总结
Jun 18 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
十大催泪虐心动漫,你能坚持看到第几部?
2020/03/04 日漫
PHP的可变变量名的使用方法分享
2012/02/05 PHP
PHP中的print_r 与 var_dump 输出数组
2016/06/13 PHP
php getcwd与dirname(__FILE__)区别详解
2016/09/24 PHP
JS弹出层单纯的绝对定位居中示例代码
2014/02/18 Javascript
深入讲解AngularJS中的自定义指令的使用
2015/06/18 Javascript
灵活使用数组制作图片切换js实现
2016/07/28 Javascript
Bootstrap Table服务器分页与在线编辑应用总结
2016/08/08 Javascript
jQuery分页插件jquery.pagination.js使用方法解析
2017/02/09 Javascript
Vue开发中整合axios的文件整理
2017/04/29 Javascript
妙用缓存调用链实现JS方法的重载
2018/04/30 Javascript
详解webpack自定义loader初探
2018/08/29 Javascript
vue+iview 实现可编辑表格的示例代码
2018/10/31 Javascript
ECharts地图绘制和钻取简易接口详解
2019/07/12 Javascript
解决layui表格的表头不滚动的问题
2019/09/04 Javascript
JavaScript判断数组类型的方法
2019/10/23 Javascript
Python+MongoDB自增键值的简单实现
2016/11/04 Python
python递归打印某个目录的内容(实例讲解)
2017/08/30 Python
对numpy Array [: ,] 的取值方法详解
2018/07/02 Python
Python3.5模块的定义、导入、优化操作图文详解
2019/04/27 Python
彻底搞懂 python 中文乱码问题(深入分析)
2020/02/28 Python
如何卸载python插件
2020/07/08 Python
纯CSS实现菜单、导航栏的3D翻转动画效果
2014/04/23 HTML / CSS
Html5元素及基本语法详解
2016/08/02 HTML / CSS
日本热销NO.1胶原蛋白冻:Aishitoto爱希特多
2019/06/20 全球购物
中专毕业生求职简历的自我评价
2013/10/21 职场文书
大学系主任推荐信范文
2013/12/24 职场文书
烹调加工管理制度
2014/02/04 职场文书
现场施工员岗位职责
2014/03/10 职场文书
2014年大学生四年规划书范文
2014/04/03 职场文书
员工安全责任书范本
2014/07/24 职场文书
2014年心理健康教育工作总结
2014/12/06 职场文书
社区党风廉政建设调研报告
2015/01/01 职场文书
幼儿园教师管理制度
2015/08/05 职场文书
Python基础之字符串格式化详解
2021/04/21 Python
windows下快速安装nginx并配置开机自启动的方法
2021/05/11 Servers