Pytorch 中的optimizer使用说明


Posted in Python onMarch 03, 2021

与优化函数相关的部分在torch.optim模块中,其中包含了大部分现在已有的流行的优化方法。

如何使用Optimizer

要想使用optimizer,需要创建一个optimizer 对象,这个对象会保存当前状态,并根据梯度更新参数。

怎样构造Optimizer

要构造一个Optimizer,需要使用一个用来包含所有参数(Tensor形式)的iterable,把相关参数(如learning rate、weight decay等)装进去。

注意,如果想要使用.cuda()方法来将model移到GPU中,一定要确保这一步在构造Optimizer之前。因为调用.cuda()之后,model里面的参数已经不是之前的参数了。

示例代码如下:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

常用参数

last_epoch代表上一次的epoch的值,初始值为-1。

单独指定参数

也可以用一个dict的iterable指定参数。这里的每个dict都必须要params这个key,params包含它所属的参数列表。除此之外的key必须它的Optimizer(如SGD)里面有的参数。

You can still pass options as keyword arguments. They will be used as defaults, in the groups that didn't override them. This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

这在针对特定部分进行操作时很有用。比如只希望给指定的几个层单独设置学习率:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 0.001}
  ],
  
  lr = 0.01, momentum = 0.9)

在上面这段代码中model.base将会使用默认学习率0.01,而model.classifier的参数蒋欢使用0.001的学习率。

怎样进行单次优化

所有optimizer都实现了step()方法,调用这个方法可以更新参数,这个方法有以下两种使用方法:

optimizer.step()

多数optimizer里都可以这么做,每次用backward()这类的方法计算出了梯度后,就可以调用一次这个方法来更新参数。

示例程序:

for input, target in dataset:
 optimizer.zero_grad()
 ouput = model(input)
 loss = loss_fn(output, target)
 loss.backward()
 optimizer.step()

optimizer.step(closure)

有些优化算法会多次重新计算函数(比如Conjugate Gradient、LBFGS),这样的话你就要使用一个闭包(closure)来支持多次计算model的操作。

这个closure的运行过程是,清除梯度,计算loss,返回loss。

(这个我不太理解,因为这些优化算法不熟悉)

示例程序:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

优化算法

这里就不完整介绍documentation中的内容了,只介绍基类。具体的算法的参数需要理解它们的原理才能明白,这个改天单独来一篇文章介绍。

Optimizer

class torch.optim.Optimizer(params, defaults)

这是所有optimizer的基类。

注意,各参数的顺序必须保证每次运行都一致。有些数据结构就不满足这个条件,比如dictionary的iterator和set。

参数

params(iterable)是torch.Tensor或者dict的iterable。这个参数指定了需要更新的Tensor。

defaults(dict)是一个dict,它包含了默认的的优化选项。

方法

add_param_group(param_group)

这个方法的作用是增加一个参数组,在fine tuning一个预训练的网络时有用。

load_state_dict(state_dict)

这个方法的作用是加载optimizer的状态。

state_dict()

获取一个optimizer的状态(一个dict)。

zero_grad()方法用于清空梯度。

step(closure)用于进行单次更新。

Adam

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

补充:pytorch里面的Optimizer和optimizer.step()用法

当我们想指定每一层的学习率时:

optim.SGD([
          {'params': model.base.parameters()},
          {'params': model.classifier.parameters(), 'lr': 1e-3}
        ], lr=1e-2, momentum=0.9)

这意味着model.base的参数将会使用1e-2的学习率,model.classifier的参数将会使用1e-3的学习率,并且0.9的momentum将会被用于所有的参数。

进行单次优化

所有的optimizer都实现了step()方法,这个方法会更新所有的参数。它能按两种方式来使用:

optimizer.step()

这是大多数optimizer所支持的简化版本。一旦梯度被如backward()之类的函数计算好后,我们就可以调用这个函数。

例子

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
optimizer.step(closure)

一些优化算法例如Conjugate Gradient和LBFGS需要重复多次计算函数,因此你需要传入一个闭包去允许它们重新计算你的模型。

这个闭包应当清空梯度,计算损失,然后返回。

例子:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python 使用get_argument获取url query参数
Apr 28 Python
教你利用Python玩转histogram直方图的五种方法
Jul 30 Python
Face++ API实现手势识别系统设计
Nov 21 Python
python 使用正则表达式按照多个空格分割字符的实例
Dec 20 Python
Django2.1集成xadmin管理后台所遇到的错误集锦(填坑)
Dec 20 Python
Python3直接爬取图片URL并保存示例
Dec 18 Python
安装多个版本的TensorFlow的方法步骤
Apr 21 Python
基于Python+QT的gui程序开发实现
Jul 03 Python
用python给csv里的数据排序的具体代码
Jul 17 Python
Windows下PyCharm配置Anaconda环境(超详细教程)
Jul 31 Python
python实现快速文件格式批量转换的方法
Oct 16 Python
Python 中random 库的详细使用
Jun 03 Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
You might like
php采用curl访问域名返回405 method not allowed提示的解决方法
2014/06/26 PHP
基于PHP-FPM进程池探秘
2017/10/17 PHP
PHP中递归的实现实例详解
2017/11/14 PHP
laravel 输出最后执行sql 附:whereIn的使用方法
2019/10/10 PHP
Laravel等框架模型关联的可用性浅析
2019/12/15 PHP
通过JAVAScript实现页面自适应
2007/01/19 Javascript
ExtJS 2.0 实用简明教程之布局概述
2009/04/29 Javascript
打豆豆小游戏 用javascript编写的[打豆豆]小游戏
2013/01/08 Javascript
Jquery 过滤器(first,last,not,even,odd)的使用
2014/01/22 Javascript
jquery遍历checkbox介绍
2014/02/21 Javascript
20条学习javascript的编程规范的建议
2014/11/28 Javascript
学习javascript文件加载优化
2016/02/19 Javascript
jQuery+Ajax+PHP弹出层异步登录效果(附源码下载)
2016/05/27 Javascript
Angular ng-repeat 对象和数组遍历实例
2016/09/14 Javascript
canvas+gif.js打造自己的数字雨头像的示例代码
2017/10/26 Javascript
web前端vue之vuex单独一文件使用方式实例详解
2018/01/11 Javascript
js判断输入框不能为空格或null值的实现方法
2018/03/02 Javascript
Vue iview-admin框架二级菜单改为三级菜单的方法
2018/07/03 Javascript
vue 解决循环引用组件报错的问题
2018/09/06 Javascript
vue项目从node8.x升级到12.x后的问题解决
2019/10/25 Javascript
JavaScript享元模式原理与用法实例详解
2020/03/09 Javascript
Python下rrdtool模块的基本使用方法
2015/11/13 Python
Python中实现最小二乘法思路及实现代码
2018/01/04 Python
批量将ppt转换为pdf的Python代码 只要27行!
2018/02/26 Python
python3通过qq邮箱发送邮件以及附件
2020/05/20 Python
css3和jquery实现自定义checkbox和radiobox组件
2014/04/22 HTML / CSS
电子商务专业实习生自我鉴定
2013/09/24 职场文书
《伯牙绝弦》教学反思
2014/03/02 职场文书
大学军训感言1500字
2014/03/09 职场文书
2014年预备党员端正入党动机思想汇报
2014/09/13 职场文书
2014年污水处理厂工作总结
2014/12/19 职场文书
老乡聚会通知
2015/04/23 职场文书
《时代广场的蟋蟀》读后感:真挚友情,温暖世界!
2020/01/08 职场文书
Python列表删除重复元素与图像相似度判断及删除实例代码
2021/05/07 Python
一文带你理解vue创建一个后台管理系统流程(Vue+Element)
2021/05/18 Vue.js
Python函数中的不定长参数相关知识总结
2021/06/24 Python