关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用PIPE操作Linux管道
Feb 04 Python
Python实现简单截取中文字符串的方法
Jun 15 Python
Python中matplotlib中文乱码解决办法
May 12 Python
Python 实现数据库(SQL)更新脚本的生成方法
Jul 09 Python
Python实现将文本生成二维码的方法示例
Jul 18 Python
python查找指定文件夹下所有文件并按修改时间倒序排列的方法
Oct 21 Python
python实现一个简单的udp通信的示例代码
Feb 01 Python
python读取并写入mat文件的方法
Jul 12 Python
matplotlib绘制多个子图(subplot)的方法
Dec 03 Python
python tkinter之 复选、文本、下拉的实现
Mar 04 Python
python爬虫请求头设置代码
Jul 28 Python
PyTorch 中的傅里叶卷积实现示例
Dec 11 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
Ajax+PHP 边学边练 之二 实例
2009/11/24 PHP
zf框架的数据库追踪器使用示例
2014/03/13 PHP
重新认识php array_merge函数
2014/08/31 PHP
关于laravel 数据库迁移中integer类型是无法指定长度的问题
2019/10/09 PHP
用jscript启动sqlserver
2007/06/21 Javascript
JavaScript DOM 学习第七章 表单的扩展
2010/02/19 Javascript
基于jquery实现一张图片点击鼠标放大再点缩小
2013/09/29 Javascript
javascript实现继承的简单实例
2015/07/26 Javascript
Node.js 实现简单的接口服务器的实例代码
2017/05/23 Javascript
Vue实现virtual-dom的原理简析
2017/07/10 Javascript
一个简易的js图片轮播效果
2017/07/22 Javascript
JS计算两个数组的交集、差集、并集、补集(多种实现方式)
2019/05/21 Javascript
JS/CSS实现字符串单词首字母大写功能
2019/09/03 Javascript
nodejs实现聊天机器人功能
2019/09/19 NodeJs
在SSM框架下用laypage和ajax实现分页和数据交互的方法
2019/09/27 Javascript
浅析vue-cli3配置webpack-bundle-analyzer插件【推荐】
2019/10/23 Javascript
jquery实现聊天机器人
2020/02/08 jQuery
js+canvas实现纸牌游戏
2020/03/16 Javascript
Element el-button 按钮组件的使用详解
2021/02/01 Javascript
Python实现的二维码生成小软件
2014/07/11 Python
python requests 使用快速入门
2017/08/31 Python
Python实现读取字符串按列分配后按行输出示例
2018/04/17 Python
Python数据可视化之画图
2019/01/15 Python
python定时复制远程文件夹中所有文件
2019/04/30 Python
Python 实现将大图切片成小图,将小图组合成大图的例子
2020/03/14 Python
python3+openCV 获取图片中文本区域的最小外接矩形实例
2020/06/02 Python
pandas创建DataFrame的7种方法小结
2020/06/14 Python
CSS3 渐变(Gradients)之CSS3 径向渐变
2016/07/08 HTML / CSS
编写html5时调试发现脚本php等网页js、css等失效
2013/12/31 HTML / CSS
酒店服务实习自我鉴定
2013/09/22 职场文书
电钳专业个人求职信
2014/01/04 职场文书
志愿者活动总结
2014/04/28 职场文书
个人三严三实对照检查材料
2014/09/25 职场文书
法学专业毕业实习自我鉴定2014
2014/09/27 职场文书
违反工作规定检讨书范文
2014/12/14 职场文书
2015年社区计生工作总结
2015/04/21 职场文书