PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python编程之requests在网络请求中添加cookies参数方法详解
Oct 25 Python
利用Tkinter(python3.6)实现一个简单计算器
Dec 21 Python
python获取酷狗音乐top500的下载地址 MP3格式
Apr 17 Python
Python根据已知邻接矩阵绘制无向图操作示例
Jun 23 Python
解决PySide+Python子线程更新UI线程的问题
Jan 11 Python
Python面向对象程序设计构造函数和析构函数用法分析
Apr 12 Python
python for和else语句趣谈
Jul 02 Python
python把ipynb文件转换成pdf文件过程详解
Jul 09 Python
在django中,关于session的通用设置方法
Aug 06 Python
在python中计算ssim的方法(与Matlab结果一致)
Dec 19 Python
Django media static外部访问Django中的图片设置教程
Apr 07 Python
python 制作一个gui界面的翻译工具
May 14 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
深入解析php中的foreach函数
2013/08/31 PHP
ThinkPHP中的系统常量和预定义常量集合
2014/07/01 PHP
php封装的图片(缩略图)处理类完整实例
2016/10/19 PHP
Nginx环境下PHP flush失效的解决方法
2016/10/19 PHP
通过javascript设置css属性的代码
2009/12/28 Javascript
js对象关系图 方便dom操作
2012/03/18 Javascript
Jquery实现的角色左右选择特效
2014/05/21 Javascript
javascript动态创建链接的方法
2015/05/13 Javascript
JS封装cookie操作函数实例(设置、读取、删除)
2015/11/17 Javascript
JavaScript encodeURI 和encodeURIComponent
2015/12/04 Javascript
javascript实现的左右无缝滚动效果
2016/09/19 Javascript
JS中LocalStorage与SessionStorage五种循序渐进的使用方法
2017/07/12 Javascript
Webpack 服务器端代码打包的示例代码
2017/09/19 Javascript
详解基于vue的服务端渲染框架NUXT
2018/06/20 Javascript
JS实现用特殊符号替换字符串的中间部分区域的实例代码
2018/07/24 Javascript
vue-rx的初步使用教程
2018/09/21 Javascript
node和vue实现商城用户地址模块
2018/12/05 Javascript
VUE前后端学习tab写法实例
2019/08/06 Javascript
Vue-cli3项目引入Typescript的实现方法
2019/10/18 Javascript
vue+elementui 对话框取消 表单验证重置示例
2019/10/29 Javascript
Python查看多台服务器进程的脚本分享
2014/06/11 Python
解决Python中由于logging模块误用导致的内存泄露
2015/04/23 Python
python函数装饰器用法实例详解
2015/06/04 Python
Python的time模块中的常用方法整理
2015/06/18 Python
在Python的Django框架中创建语言文件
2015/07/27 Python
Django自定义manage命令实例代码
2018/02/11 Python
一行python实现树形结构的方法
2019/08/09 Python
Python 利用邮件系统完成远程控制电脑的实现(关机、重启等)
2019/11/19 Python
python 实现保存最新的三份文件,其余的都删掉
2019/12/22 Python
python实现五子棋程序
2020/04/24 Python
python字典的值可以修改吗
2020/06/29 Python
使用Python爬虫爬取小红书完完整整的全过程
2021/01/19 Python
京东奢侈品:全球奢侈品牌
2018/03/17 全球购物
Java和Javasciprt的区别
2012/09/02 面试题
不同浏览器创建XMLHttpRequest方法有什么不同
2014/11/17 面试题
基于PyQt5制作一个群发邮件工具
2022/04/08 Python