关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python IDE PyCharm的基本快捷键和配置简介
Nov 04 Python
Python使用defaultdict读取文件各列的方法
May 11 Python
Python编写登陆接口的方法
Jul 10 Python
Python3实战之爬虫抓取网易云音乐的热门评论
Oct 09 Python
Python实现进程同步和通信的方法
Jan 02 Python
pandas.DataFrame 根据条件新建列并赋值的方法
Apr 08 Python
python中找出numpy array数组的最值及其索引方法
Apr 17 Python
numpy使用fromstring创建矩阵的实例
Jun 15 Python
Python 实现两个服务器之间文件的上传方法
Feb 13 Python
python之当你发现QTimer不能用时的解决方法
Jun 21 Python
Python 根据日志级别打印不同颜色的日志的方法示例
Aug 08 Python
python中xlrd模块的使用详解
Feb 01 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
php中支持多种编码的中文字符串截取函数!
2007/03/20 PHP
PHP 加密解密内部算法
2010/04/22 PHP
php获取通过http协议post提交过来xml数据及解析xml
2012/12/16 PHP
PHP实现动态创建XML文档的方法
2018/03/30 PHP
thinkPHP框架实现的无限回复评论功能示例
2018/06/09 PHP
基于thinkphp6.0的success、error实现方法
2019/11/05 PHP
PHP实现基本留言板功能原理与步骤详解
2020/03/26 PHP
js 颜色选择器(兼容firefox)
2009/03/05 Javascript
javascript innerHTML使用分析
2010/12/03 Javascript
jQuery EasyUI API 中文文档 - DataGrid数据表格
2011/11/17 Javascript
js导出table数据到excel即导出为EXCEL文档的方法
2013/10/10 Javascript
jQuery 删除/替换DOM元素的几种方式
2014/05/20 Javascript
javascript设计简单的秒表计时器
2020/09/05 Javascript
用JavaScript获取页面文档内容的实现代码
2016/06/10 Javascript
微信小程序 WXDropDownMenu组件详解及实例代码
2016/10/24 Javascript
使用OPENLAYERS3实现点选的方法
2020/09/24 Javascript
JavaScript callback回调函数用法实例分析
2018/05/08 Javascript
jQuery插件实现的日历功能示例【附源码下载】
2018/09/07 jQuery
JS使用数组实现的队列功能示例
2019/03/04 Javascript
微信小程序Echarts覆盖正常组件问题解决
2019/07/13 Javascript
Layui实现数据表格中鼠标悬浮图片放大效果,离开时恢复原图的方法
2019/09/11 Javascript
element跨分页操作选择详解
2020/06/29 Javascript
微信小程序实现底部弹出模态框
2020/11/18 Javascript
vue实现两个区域滚动条同步滚动
2020/12/13 Vue.js
Python获取Windows或Linux主机名称通用函数分享
2014/11/22 Python
tensorflow建立一个简单的神经网络的方法
2018/02/10 Python
python调用matplotlib模块绘制柱状图
2019/10/18 Python
python对XML文件的操作实现代码
2020/03/27 Python
selenium设置浏览器为headless无头模式(Chrome和Firefox)
2021/01/08 Python
Html5踩坑记之mandMobile使用小记
2020/04/02 HTML / CSS
台湾饭店和机票预订网站:Expedia台湾
2016/08/05 全球购物
《一株紫丁香》教学反思
2014/02/19 职场文书
公益广告标语
2014/06/19 职场文书
单位委托书
2014/10/15 职场文书
浅谈什么是SpringBoot异常处理自动配置的原理
2021/06/21 Java/Android
pnpm对npm及yarn降维打击详解
2022/08/05 Javascript