关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python访问mysql数据库的实现方法(2则示例)
Jan 06 Python
总结python爬虫抓站的实用技巧
Aug 09 Python
Python 装饰器深入理解
Mar 16 Python
Python的爬虫框架scrapy用21行代码写一个爬虫
Apr 24 Python
matplotlib在python上绘制3D散点图实例详解
Dec 09 Python
Python整数对象实现原理详解
Jul 01 Python
Python 使用folium绘制leaflet地图的实现方法
Jul 05 Python
Python实现二叉搜索树BST的方法示例
Jul 30 Python
python实现输入任意一个大写字母生成金字塔的示例
Oct 27 Python
Python子进程subpocess原理及用法解析
Jul 16 Python
python使用glob检索文件的操作
May 20 Python
Python手拉手教你爬取贝壳房源数据的实战教程
May 21 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
BBS(php &amp; mysql)完整版(二)
2006/10/09 PHP
php 在文件指定行插入数据的代码
2010/05/08 PHP
二招解决php乱码问题
2012/03/25 PHP
php计算当前程序执行时间示例
2014/04/24 PHP
PHP strtotime函数用法、实现原理和源码分析
2015/02/04 PHP
使用prototype.js进行异步操作
2007/02/07 Javascript
下拉菜单点击实现连接跳转功能的js代码
2013/05/19 Javascript
jQuery实现的多级下拉菜单效果代码
2015/08/24 Javascript
原生JS实现旋转木马式图片轮播插件
2016/04/25 Javascript
Bootstrap模态对话框的简单使用
2016/04/29 Javascript
js从数组中删除指定值(不是指定位置)的元素实现代码
2016/09/13 Javascript
JS中检测数据类型的几种方式及优缺点小结
2016/12/12 Javascript
通过命令行生成vue项目框架的方法
2017/07/12 Javascript
VUE 实现滚动监听 导航栏置顶的方法
2018/09/11 Javascript
angularJs中json数据转换与本地存储的实例
2018/10/08 Javascript
Seajs源码详解分析
2019/04/02 Javascript
VUE实现移动端列表筛选功能
2019/08/23 Javascript
使用vue cli4.x搭建vue项目的过程详解
2020/05/08 Javascript
Vue+Element UI 树形控件整合下拉功能菜单(tree + dropdown +input)
2020/08/28 Javascript
python 列表删除所有指定元素的方法
2018/04/19 Python
python实现微信小程序自动回复
2018/09/10 Python
Python利用命名空间解析XML文档
2020/08/10 Python
CSS3教程(9):设置RGB颜色
2009/04/02 HTML / CSS
eDreams德国:南欧领先的在线旅游公司
2020/12/07 全球购物
WINDOWS域的具体实现方式是什么
2014/02/20 面试题
部队学习十八大感言
2014/01/11 职场文书
春节晚会主持词
2014/03/24 职场文书
施工单位安全责任书
2014/07/24 职场文书
个人总结与自我评价
2014/09/18 职场文书
群众路线教育实践活动心得体会(四风)
2014/11/03 职场文书
慈善募捐倡议书
2015/04/27 职场文书
观看《信仰》心得体会
2016/01/15 职场文书
2016领导干部廉洁从政心得体会
2016/01/19 职场文书
穷人该怎么创业?谨记以下几点
2019/07/11 职场文书
Redis特殊数据类型Geospatial地理空间
2022/06/01 Redis
win10此电脑打不开怎么办 win10双击此电脑无响应的解决办法
2022/07/23 数码科技