PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多进程编程技术实例分析
Sep 16 Python
Python中处理字符串之isalpha()方法的使用
May 18 Python
python得到电脑的开机时间方法
Oct 15 Python
Python读取excel指定列生成指定sql脚本的方法
Nov 28 Python
Python使用Selenium爬取淘宝异步加载的数据方法
Dec 17 Python
Python控制键盘鼠标pynput的详细用法
Jan 28 Python
如何使用python把ppt转换成pdf
Jun 29 Python
python调用其他文件函数或类的示例
Jul 16 Python
python数据爬下来保存的位置
Feb 17 Python
Python按照list dict key进行排序过程解析
Apr 04 Python
python中绕过反爬虫的方法总结
Nov 25 Python
pytest fixtures装饰器的使用和如何控制用例的执行顺序
Jan 28 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
PHP 写文本日志实现代码
2010/05/18 PHP
php后台多用户权限组思路与实现程序代码分享
2012/02/13 PHP
php无限极分类实现的两种解决方法
2013/04/28 PHP
ThinkPHP3.1新特性之内容解析输出详解
2014/06/19 PHP
如何解决phpmyadmin导入数据库文件最大限制2048KB
2015/10/09 PHP
PHP新特性详解之命名空间、性状与生成器
2017/07/18 PHP
Laravel使用支付宝进行支付的示例代码
2017/08/16 PHP
laravel 解决多库下的DB::transaction()事务失效问题
2019/10/21 PHP
Laravel Eloquent分表方法并使用模型关联的实现
2019/11/25 PHP
在多个页面使用同一个HTML片段《续》
2011/03/04 Javascript
Jquery中Ajax 缓存带来的影响的解决方法
2011/05/19 Javascript
JS获取月的最后一天与JS得到一个月份最大天数的实例代码
2013/12/16 Javascript
jQuery控制网页打印指定区域的方法
2015/04/07 Javascript
JavaScript统计网站访问次数的实现代码
2015/11/18 Javascript
JavaScript函数内部属性和函数方法实例详解
2016/03/17 Javascript
vue.js 获取当前自定义属性值
2017/06/01 Javascript
微信小程序 循环及嵌套循环的使用总结
2017/09/26 Javascript
详解如何用VUE写一个多用模态框组件模版
2018/09/27 Javascript
Vue.js 中的实用工具方法【推荐】
2019/07/04 Javascript
python算法学习之桶排序算法实例(分块排序)
2013/12/18 Python
python从入门到精通(DAY 2)
2015/12/20 Python
Python二叉搜索树与双向链表转换实现方法
2016/04/29 Python
Python实现简单的获取图片爬虫功能示例
2017/07/12 Python
对pandas进行数据预处理的实例讲解
2018/04/20 Python
python批量将excel内容进行翻译写入功能
2019/10/10 Python
使用ITK-SNAP进行抠图操作并保存mask的实例
2020/07/01 Python
全球酒店预订网站:Hotels.com
2016/08/10 全球购物
init进程的作用
2012/04/12 面试题
英语三分钟演讲稿
2014/08/19 职场文书
2014公司党员自我评价范文
2014/09/11 职场文书
2014年党员整改措施范文
2014/09/21 职场文书
人民调解协议书范本
2014/10/11 职场文书
2015年小学语文工作总结
2015/05/25 职场文书
九年级历史教学反思
2016/02/19 职场文书
2016年社区文体活动总结
2016/04/06 职场文书
Python实现生活常识解答机器人
2021/06/28 Python