关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的subprocess模块总结
Nov 07 Python
python操作gmail实例
Jan 14 Python
使用PDB模式调试Python程序介绍
Apr 05 Python
python enumerate函数的使用方法总结
Nov 15 Python
python爬取拉勾网职位数据的方法
Jan 24 Python
Python编程在flask中模拟进行Restful的CRUD操作
Dec 28 Python
Python OpenCV对本地视频文件进行分帧保存的实例
Jan 08 Python
Python 读取用户指令和格式化打印实现解析
Sep 02 Python
Python从文件中读取指定的行以及在文件指定位置写入
Sep 06 Python
python logging添加filter教程
Dec 24 Python
借助Paramiko通过Python实现linux远程登陆及sftp的操作
Mar 16 Python
python函数指定默认值的实例讲解
Mar 29 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
用PHP来写记数器(详细介绍)
2006/10/09 PHP
使用php实现快钱支付功能(涉及到接口)
2013/07/01 PHP
javascript算法学习(直接插入排序)
2011/04/12 Javascript
js history对象简单实现返回和前进
2013/10/30 Javascript
jquery获得keycode的示例代码
2013/12/30 Javascript
jQuery中closest和parents的区别分析
2015/05/07 Javascript
js实现div拖动动画运行轨迹效果代码分享
2015/08/27 Javascript
实例代码详解jquery.slides.js
2015/11/16 Javascript
基于AngularJS+HTML+Groovy实现登录功能
2016/02/17 Javascript
全面理解JavaScript中的闭包
2016/05/12 Javascript
vue.js国际化 vue-i18n插件的使用详解
2017/07/07 Javascript
JavaScript实现音乐自动切换和轮播
2017/11/05 Javascript
jQuery层叠选择器用法实例分析
2019/06/28 jQuery
Vue路由权限控制解析
2020/11/09 Javascript
跟老齐学Python之玩转字符串(3)
2014/09/14 Python
Python3读取zip文件信息的方法
2015/05/22 Python
CentOS安装pillow报错的解决方法
2016/01/27 Python
Python自动化测试ConfigParser模块读写配置文件
2016/08/15 Python
K-近邻算法的python实现代码分享
2017/12/09 Python
Python基于pycrypto实现的AES加密和解密算法示例
2018/04/10 Python
详解如何用TensorFlow训练和识别/分类自定义图片
2019/08/05 Python
python生成随机红包的实例写法
2019/09/02 Python
python开根号实例讲解
2020/08/30 Python
Django限制API访问频率常用方法解析
2020/10/12 Python
教你如何一步一步用Canvas写一个贪吃蛇
2018/10/22 HTML / CSS
初一英语教学反思
2014/01/11 职场文书
感恩节活动方案
2014/01/27 职场文书
中学生个人自我评价
2014/02/06 职场文书
大学理论知识学习自我鉴定
2014/04/28 职场文书
校园绿化美化方案
2014/06/08 职场文书
建筑施工安全责任书
2014/07/24 职场文书
地质工程专业毕业生求职信
2014/08/08 职场文书
群众路线教育实践活动民主生活会个人检查对照思想汇报
2014/10/04 职场文书
售后服务承诺函格式
2015/01/21 职场文书
2016企业先进集体事迹材料
2016/02/25 职场文书
python爬虫selenium模块详解
2021/03/30 Python