关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中断言Assertion的一些改进方案
Oct 27 Python
python通过socket实现多个连接并实现ssh功能详解
Nov 08 Python
Python随机生成均匀分布在单位圆内的点代码示例
Nov 13 Python
pytorch 调整某一维度数据顺序的方法
Dec 08 Python
Python使用python-docx读写word文档
Aug 26 Python
Python3实现发送邮件和发送短信验证码功能
Jan 07 Python
python安装dlib库报错问题及解决方法
Mar 16 Python
python实现秒杀商品的微信自动提醒功能(代码详解)
Apr 27 Python
python如何进行矩阵运算
Jun 05 Python
python根据字典的键来删除元素的方法
Aug 16 Python
python的数学算法函数及公式用法
Nov 18 Python
Python超简单容易上手的画图工具库推荐
May 10 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
用PHP查询搜索引擎排名位置的代码
2010/01/05 PHP
如何使用PHP实现javascript的escape和unescape函数
2013/06/29 PHP
PHP微信PC二维码登陆的实现思路
2017/07/13 PHP
详解PHP防止直接访问.php 文件的实现方法
2017/07/28 PHP
浅谈Laravel POST,PUT,PATCH 路由的区别
2019/10/15 PHP
又一个小巧的图片预加载类
2007/05/05 Javascript
THREE.JS入门教程(6)创建自己的全景图实现步骤
2013/01/25 Javascript
js实现连续英文字符自动换行兼容ie6 ie7和firefox
2013/09/06 Javascript
原生js仿jq判断当前浏览器是否为ie,精确到ie6~8
2014/08/30 Javascript
让angularjs支持浏览器自动填表
2014/11/10 Javascript
javascript的tab切换原理与效果实现方法
2015/01/10 Javascript
简介JavaScript中valueOf()方法的使用
2015/06/05 Javascript
基于jquery插件实现拖拽删除图片功能
2020/08/27 Javascript
Node.js中常规的文件操作总结
2016/10/13 Javascript
微信小程序 下拉菜单简单实例
2017/04/13 Javascript
JavaScript实现获取用户单击body中所有A标签内容的方法
2017/06/05 Javascript
浅谈函数调用的不同方式,以及this的指向
2017/09/17 Javascript
使用D3.js+Vue实现一个简单的柱形图
2018/08/05 Javascript
vue+echarts+datav大屏数据展示及实现中国地图省市县下钻功能
2020/11/16 Javascript
vue实现表格合并功能
2020/12/01 Vue.js
[03:18]【TI9纪实】社区大触GL与木木
2019/08/25 DOTA
python删除列表内容
2015/08/04 Python
Python错误处理操作示例
2018/07/18 Python
详解python3中的真值测试
2018/08/13 Python
Python实现的ftp服务器功能详解【附源码下载】
2019/06/26 Python
iHerb台湾:维生素、保健品和健康产品
2018/01/31 全球购物
美国保健品专家:Life Extension
2018/05/04 全球购物
JSF的标签库有哪些
2012/04/27 面试题
怎么写好自荐书
2014/03/02 职场文书
交通事故和解协议书
2015/01/27 职场文书
售后服务质量承诺书
2015/04/29 职场文书
2019广播稿怎么写
2019/04/17 职场文书
八年级作文之感悟亲情
2019/11/20 职场文书
MySQL优化之如何写出高质量sql语句
2021/05/17 MySQL
开发微信小程序之WXSS样式教程
2022/04/18 HTML / CSS
java中如何截取字符串最后一位
2022/07/07 Java/Android