Pytorch 中的optimizer使用说明


Posted in Python onMarch 03, 2021

与优化函数相关的部分在torch.optim模块中,其中包含了大部分现在已有的流行的优化方法。

如何使用Optimizer

要想使用optimizer,需要创建一个optimizer 对象,这个对象会保存当前状态,并根据梯度更新参数。

怎样构造Optimizer

要构造一个Optimizer,需要使用一个用来包含所有参数(Tensor形式)的iterable,把相关参数(如learning rate、weight decay等)装进去。

注意,如果想要使用.cuda()方法来将model移到GPU中,一定要确保这一步在构造Optimizer之前。因为调用.cuda()之后,model里面的参数已经不是之前的参数了。

示例代码如下:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

常用参数

last_epoch代表上一次的epoch的值,初始值为-1。

单独指定参数

也可以用一个dict的iterable指定参数。这里的每个dict都必须要params这个key,params包含它所属的参数列表。除此之外的key必须它的Optimizer(如SGD)里面有的参数。

You can still pass options as keyword arguments. They will be used as defaults, in the groups that didn't override them. This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

这在针对特定部分进行操作时很有用。比如只希望给指定的几个层单独设置学习率:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 0.001}
  ],
  
  lr = 0.01, momentum = 0.9)

在上面这段代码中model.base将会使用默认学习率0.01,而model.classifier的参数蒋欢使用0.001的学习率。

怎样进行单次优化

所有optimizer都实现了step()方法,调用这个方法可以更新参数,这个方法有以下两种使用方法:

optimizer.step()

多数optimizer里都可以这么做,每次用backward()这类的方法计算出了梯度后,就可以调用一次这个方法来更新参数。

示例程序:

for input, target in dataset:
 optimizer.zero_grad()
 ouput = model(input)
 loss = loss_fn(output, target)
 loss.backward()
 optimizer.step()

optimizer.step(closure)

有些优化算法会多次重新计算函数(比如Conjugate Gradient、LBFGS),这样的话你就要使用一个闭包(closure)来支持多次计算model的操作。

这个closure的运行过程是,清除梯度,计算loss,返回loss。

(这个我不太理解,因为这些优化算法不熟悉)

示例程序:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

优化算法

这里就不完整介绍documentation中的内容了,只介绍基类。具体的算法的参数需要理解它们的原理才能明白,这个改天单独来一篇文章介绍。

Optimizer

class torch.optim.Optimizer(params, defaults)

这是所有optimizer的基类。

注意,各参数的顺序必须保证每次运行都一致。有些数据结构就不满足这个条件,比如dictionary的iterator和set。

参数

params(iterable)是torch.Tensor或者dict的iterable。这个参数指定了需要更新的Tensor。

defaults(dict)是一个dict,它包含了默认的的优化选项。

方法

add_param_group(param_group)

这个方法的作用是增加一个参数组,在fine tuning一个预训练的网络时有用。

load_state_dict(state_dict)

这个方法的作用是加载optimizer的状态。

state_dict()

获取一个optimizer的状态(一个dict)。

zero_grad()方法用于清空梯度。

step(closure)用于进行单次更新。

Adam

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

补充:pytorch里面的Optimizer和optimizer.step()用法

当我们想指定每一层的学习率时:

optim.SGD([
          {'params': model.base.parameters()},
          {'params': model.classifier.parameters(), 'lr': 1e-3}
        ], lr=1e-2, momentum=0.9)

这意味着model.base的参数将会使用1e-2的学习率,model.classifier的参数将会使用1e-3的学习率,并且0.9的momentum将会被用于所有的参数。

进行单次优化

所有的optimizer都实现了step()方法,这个方法会更新所有的参数。它能按两种方式来使用:

optimizer.step()

这是大多数optimizer所支持的简化版本。一旦梯度被如backward()之类的函数计算好后,我们就可以调用这个函数。

例子

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
optimizer.step(closure)

一些优化算法例如Conjugate Gradient和LBFGS需要重复多次计算函数,因此你需要传入一个闭包去允许它们重新计算你的模型。

这个闭包应当清空梯度,计算损失,然后返回。

例子:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python中的错误处理
Apr 10 Python
Scrapy框架CrawlSpiders的介绍以及使用详解
Nov 29 Python
python中实现控制小数点位数的方法
Jan 24 Python
python+os根据文件名自动生成文本
Mar 21 Python
pyqt实现.ui文件批量转换为对应.py文件脚本
Jun 19 Python
python与mysql数据库交互的实现
Jan 06 Python
pip install 使用国内镜像的方法示例
Apr 03 Python
jupyter notebook 多环境conda kernel配置方式
Apr 10 Python
如何理解python面向对象编程
Jun 01 Python
Python flask框架如何显示图像到web页面
Jun 03 Python
python自动从arxiv下载paper的示例代码
Dec 05 Python
Python 发送SMTP邮件的简单教程
Jun 24 Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
You might like
ThinkPHP学习笔记(一)ThinkPHP部署
2014/06/22 PHP
php版微信公众平台之微信网页登陆授权示例
2016/09/23 PHP
Javascript延迟执行实现方法(setTimeout)
2010/12/30 Javascript
JavaScript打开word文档的实现代码(c#)
2012/04/16 Javascript
用Jquery选择器计算table中的某一列某一行的合计
2014/08/13 Javascript
Javscript调用iframe框架页面中函数的方法
2014/11/01 Javascript
javascript面向对象之访问对象属性的两种方式分析
2015/01/13 Javascript
JS根据生日算年龄的方法
2015/05/05 Javascript
vue,angular,avalon这三种MVVM框架优缺点
2016/04/27 Javascript
JS 清除字符串数组中,重复元素的实现方法
2016/05/24 Javascript
js微信分享API
2020/10/11 Javascript
jQuery Easyui使用(二)之可折叠面板动态加载无效果的解决方法
2016/08/17 Javascript
javascript验证内容为数字以及长度为10的简单实例
2016/08/20 Javascript
微信小程序 wxapp导航 navigator详解
2016/10/31 Javascript
Javascript this 函数深入详解
2016/12/13 Javascript
jQuery无刷新上传之uploadify简单代码
2017/01/17 Javascript
NodeJS实现图片上传代码(Express)
2017/06/30 NodeJs
基于zepto.js实现手机相册功能
2017/07/11 Javascript
vue 配置多页面应用的示例代码
2018/10/22 Javascript
layui富文本编辑器前端无法取值的解决方法
2019/09/18 Javascript
Element InfiniteScroll无限滚动的具体使用方法
2020/07/27 Javascript
[02:49]DOTA2完美大师赛首日观众采访
2017/11/23 DOTA
python中的列表推导浅析
2014/04/26 Python
对变量赋值的理解--Pyton中让两个值互换的实现方法
2017/11/29 Python
Python文件常见操作实例分析【读写、遍历】
2018/12/10 Python
PyQt5 实现给窗口设置背景图片的方法
2019/06/13 Python
解决tensorflow添加ptb库的问题
2020/02/10 Python
python爬取”顶点小说网“《纯阳剑尊》的示例代码
2020/10/16 Python
CSS3 3D位移translate效果实例介绍
2016/05/03 HTML / CSS
canvas实现圆形进度条动画的示例代码
2017/12/26 HTML / CSS
印度尼西亚综合购物网站:Lazada印尼
2016/09/07 全球购物
Genny意大利官网:意大利高级时装品牌
2020/04/15 全球购物
大学生创业感言
2014/01/25 职场文书
2015年学生会纪检部工作总结
2015/03/31 职场文书
酒店工程部的岗位职责汇总大全
2019/10/23 职场文书
根德5570型九灯四波段立体声收音机是电子管收音机的楷模 ? 再论5570
2022/04/05 无线电