关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python学习笔记:字典的使用示例详解
Jun 13 Python
Python实现的彩票机选器实例
Jun 17 Python
Django中反向生成models.py的实例讲解
May 30 Python
Python3用tkinter和PIL实现看图工具
Jun 21 Python
python 列表,数组和矩阵sum的用法及区别介绍
Jun 28 Python
Python日期时间模块datetime详解与Python 日期时间的比较,计算实例代码
Sep 14 Python
python实现连续图文识别
Dec 18 Python
Pycharm激活码激活两种快速方式(附最新激活码和插件)
Mar 12 Python
jupyter notebook清除输出方式
Apr 10 Python
一篇文章带你搞定Ubuntu中打开Pycharm总是卡顿崩溃
Nov 02 Python
学会迭代器设计模式,帮你大幅提升python性能
Jan 03 Python
上手简单,功能强大的Python爬虫框架——feapder
Apr 27 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
从零开始学YII2框架(三)扩展插件yii2-gird
2014/08/20 PHP
Laravel5中Cookie的使用详解
2017/05/03 PHP
php-fpm重启导致的程序执行中断问题详解
2019/04/29 PHP
Add a Table to a Word Document
2007/06/15 Javascript
JS URL传中文参数引发的乱码问题
2009/09/02 Javascript
JavaScript arguments 多参传值函数
2010/10/24 Javascript
解析javascript 实用函数的使用详解
2013/05/10 Javascript
Jquery和JS用外部变量获取Ajax返回的参数值的方法实例(超简单)
2013/06/17 Javascript
浅析LigerUi开发中谨慎载入common.css文件
2013/07/09 Javascript
js中string转int把String类型转化成int类型
2014/08/13 Javascript
JavaScript获取DOM元素的11种方法总结
2015/04/25 Javascript
JS中产生标识符方式的演变
2015/06/12 Javascript
Immutable 在 JavaScript 中的应用
2016/05/02 Javascript
js实现无缝循环滚动
2020/06/23 Javascript
js数组操作方法总结(必看篇)
2016/11/22 Javascript
使用JavaScript解决网页图片拉伸问题(推荐)
2016/11/25 Javascript
Vue2单一事件管理组件通信
2017/05/09 Javascript
Vue 使用中的小技巧
2018/04/26 Javascript
JavaScript常见JSON操作实例分析
2018/08/08 Javascript
详解在vue-test-utils中mock全局对象
2018/11/07 Javascript
使用jquery-easyui的布局layout写后台管理页面的代码详解
2019/06/19 jQuery
Python中pip安装非PyPI官网第三方库的方法
2015/06/02 Python
对pandas replace函数的使用方法小结
2018/05/18 Python
python3实现单目标粒子群算法
2019/11/14 Python
Python实现多线程下载脚本的示例代码
2020/04/03 Python
联想墨西哥官方网站:Lenovo墨西哥
2016/08/17 全球购物
Clarks鞋法国官方网站:英国其乐鞋品牌
2018/02/11 全球购物
金融专业大学生职业生涯规划范文
2014/01/16 职场文书
函授毕业生自我鉴定范文
2014/03/25 职场文书
活动宣传策划方案
2014/05/23 职场文书
2014年民政局关于保密工作整改措施
2014/09/19 职场文书
2014年调度员工作总结
2014/11/19 职场文书
销售2014年度工作总结
2014/12/08 职场文书
作文评语集锦
2014/12/25 职场文书
公司周年庆典致辞
2015/07/30 职场文书
2019年公司快递收发管理制度模板
2019/11/20 职场文书