PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Fabric 应用案例
Aug 28 Python
python 剪切移动文件的实现代码
Aug 02 Python
kafka-python批量发送数据的实例
Dec 27 Python
Python设计模式之桥接模式原理与用法实例分析
Jan 10 Python
python实现ip地址查询经纬度定位详解
Aug 30 Python
python在OpenCV里实现投影变换效果
Aug 30 Python
Python计算两个矩形重合面积代码实例
Sep 16 Python
python获取全国城市pm2.5、臭氧等空气质量过程解析
Oct 12 Python
Python imageio读取视频并进行编解码详解
Dec 10 Python
python logging设置level失败的解决方法
Feb 19 Python
基于python实现复制文件并重命名
Sep 16 Python
python实现测试工具(二)——简单的ui测试工具
Oct 19 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
php 获取远程网页内容的函数
2009/09/08 PHP
PHP答题类应用接口实例
2015/02/09 PHP
使用PHP处理数据库数据如何将数据返回客户端并显示当前状态
2016/02/16 PHP
PHP数组相加操作及与array_merge的区别浅析
2016/11/26 PHP
PHP实现生成模糊图片的方法示例
2017/12/21 PHP
JavaScript 比较时间大小的代码
2010/04/24 Javascript
javascript 文本框水印/占位符(watermark/placeholder)实现方法
2012/01/15 Javascript
jquery限定文本框只能输入数字即整数和小数
2013/11/29 Javascript
Jquery easyUI 更新行示例
2014/03/06 Javascript
jQuery filter函数使用方法
2014/05/19 Javascript
jQuery调用ajax请求的常见方法汇总
2015/03/24 Javascript
WordPress中鼠标悬停显示和隐藏评论及引用按钮的实现
2016/01/12 Javascript
JavaScipt中栈的实现方法
2016/02/17 Javascript
微信小程序 image组件binderror使用例子与js中的onerror区别
2017/02/15 Javascript
详解vue中computed 和 watch的异同
2017/06/30 Javascript
Angular2之二级路由详解
2018/08/31 Javascript
vue组件中节流函数的失效的原因和解决方法
2020/12/02 Vue.js
Python ORM框架SQLAlchemy学习笔记之数据添加和事务回滚介绍
2014/06/10 Python
Python函数中的函数(闭包)用法实例
2016/03/15 Python
python初学之用户登录的实现过程(实例讲解)
2017/12/23 Python
解决python3 安装完Pycurl在import pycurl时报错的问题
2018/10/15 Python
Selenium启动Chrome时配置选项详解
2020/03/18 Python
python解释器安装教程的方法步骤
2020/07/02 Python
详解Python中的文件操作
2021/01/14 Python
各大浏览器 CSS3 和 HTML5 兼容速查表 图文
2010/04/01 HTML / CSS
html5各种页面切换效果和模态对话框用法总结
2014/12/15 HTML / CSS
智能旅行箱:Horizn Studios
2018/04/30 全球购物
电子商务专业个人的自我评价
2013/12/19 职场文书
2014新年寄语
2014/01/20 职场文书
骨干教师培训方案
2014/05/06 职场文书
普通党员对照检查材料
2014/08/28 职场文书
关于运动会的广播稿
2014/09/22 职场文书
Python编解码问题及文本文件处理方法详解
2021/06/20 Python
Python多线程 Queue 模块常见用法
2021/07/04 Python
Python 数据可视化神器Pyecharts绘制图像练习
2022/02/28 Python
Java Lambda表达式常用的函数式接口
2022/04/07 Java/Android