PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取指定网页上所有超链接的方法
Apr 04 Python
利用Python中的mock库对Python代码进行模拟测试
Apr 16 Python
python+Django+apache的配置方法详解
Jun 01 Python
详解Python中最难理解的点-装饰器
Apr 03 Python
python实现Decorator模式实例代码
Feb 09 Python
python和shell监控linux服务器的详细代码
Jun 22 Python
Python使用folium excel绘制point
Jan 03 Python
python和c语言的主要区别总结
Jul 07 Python
Python 操作 ElasticSearch的完整代码
Aug 04 Python
利用python实现平稳时间序列的建模方式
Jun 03 Python
Python连接mysql方法及常用参数
Sep 01 Python
OpenCV3.3+Python3.6实现图片高斯模糊
May 18 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
php 计算两个时间戳相隔的时间的函数(小时)
2009/12/18 PHP
关于尾递归的使用详解
2013/05/02 PHP
php判断文件上传图片格式的实例详解
2017/09/30 PHP
PHP设计模式之装饰器模式实例详解
2018/02/07 PHP
XAMPP升级PHP版本实现步骤解析
2020/09/04 PHP
IE 条件注释详解总结(附实例代码)
2009/08/29 Javascript
再谈ie和firefox下的document.all属性
2009/10/21 Javascript
js绑定事件this指向发生改变的问题解决方法
2013/04/23 Javascript
js,jquery滚动/跳转页面到指定位置的实现思路
2014/06/03 Javascript
JS数组(Array)处理函数整理
2014/12/07 Javascript
jquery图片轮播特效代码分享
2020/04/20 Javascript
JavaScript利用Date实现简单的倒计时实例
2017/01/12 Javascript
JavaScript 基础表单验证示例(纯Js实现)
2017/07/20 Javascript
Vue2.0利用vue-resource上传文件到七牛的实例代码
2017/07/28 Javascript
Echarts动态加载多条折线图的实现代码
2019/05/24 Javascript
vue使用screenfull插件实现全屏功能
2020/09/17 Javascript
python k-近邻算法实例分享
2014/06/11 Python
python实现用于测试网站访问速率的方法
2015/05/26 Python
Python列出一个文件夹及其子目录的所有文件
2016/06/30 Python
解决出现Incorrect integer value: '' for column 'id' at row 1的问题
2017/10/29 Python
python实现读取excel写入mysql的小工具详解
2017/11/20 Python
python 将字符串转换成字典dict的各种方式总结
2018/03/23 Python
Python浅复制中对象生存周期实例分析
2018/04/02 Python
python如何求解两数的最大公约数
2018/09/27 Python
Python学习笔记之lambda表达式用法详解
2019/08/08 Python
解决安装pyqt5之后无法打开spyder的问题
2019/12/13 Python
python数据爬下来保存的位置
2020/02/17 Python
Django使用rest_framework写出API
2020/05/21 Python
Python-split()函数实例用法讲解
2020/12/18 Python
利用HTML5+CSS3实现3D转换效果实例详解
2017/05/02 HTML / CSS
C#怎么让一个窗口居中显示?
2015/10/20 面试题
公司成立感言
2014/01/11 职场文书
法学专业自我鉴定
2014/02/05 职场文书
廉洁自律准则学习心得体会
2016/01/13 职场文书
CSS 实现多彩、智能的阴影效果
2021/05/12 HTML / CSS
React更新渲染原理深入分析
2022/12/24 Javascript