浅谈Pytorch torch.optim优化器个性化的使用


Posted in Python onFebruary 20, 2020

一、简化前馈网络LeNet

import torch as t
 
 
class LeNet(t.nn.Module):
 def __init__(self):
  super(LeNet, self).__init__()
  self.features = t.nn.Sequential(
   t.nn.Conv2d(3, 6, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2),
   t.nn.Conv2d(6, 16, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2)
  )
  # 由于调整shape并不是一个class层,
  # 所以在涉及这种操作(非nn.Module操作)需要拆分为多个模型
  self.classifiter = t.nn.Sequential(
   t.nn.Linear(16*5*5, 120),
   t.nn.ReLU(),
   t.nn.Linear(120, 84),
   t.nn.ReLU(),
   t.nn.Linear(84, 10)
  )
 
 def forward(self, x):
  x = self.features(x)
  x = x.view(-1, 16*5*5)
  x = self.classifiter(x)
  return x
 
net = LeNet()

二、优化器基本使用方法

建立优化器实例

循环:

清空梯度

向前传播

计算Loss

反向传播

更新参数

from torch import optim
 
# 通常的step优化过程
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # net.zero_grad()
 
input_ = t.autograd.Variable(t.randn(1, 3, 32, 32))
output = net(input_)
output.backward(output)
 
optimizer.step()

三、网络模块参数定制

为不同的子网络参数不同的学习率,finetune常用,使分类器学习率参数更高,学习速度更快(理论上)。

1.经由构建网络时划分好的模组进行学习率设定,

# # 直接对不同的网络模块制定不同学习率
optimizer = optim.SGD([{'params': net.features.parameters()}, # 默认lr是1e-5
      {'params': net.classifiter.parameters(), 'lr': 1e-2}], lr=1e-5)

2.以网络层对象为单位进行分组,并设定学习率

# # 以层为单位,为不同层指定不同的学习率
# ## 提取指定层对象
special_layers = t.nn.ModuleList([net.classifiter[0], net.classifiter[3]])
# ## 获取指定层参数id
special_layers_params = list(map(id, special_layers.parameters()))
print(special_layers_params)
# ## 获取非指定层的参数id
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())
optimizer = t.optim.SGD([{'params': base_params},
       {'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001)

四、在训练中动态的调整学习率

'''调整学习率'''
# 新建optimizer或者修改optimizer.params_groups对应的学习率
# # 新建optimizer更简单也更推荐,optimizer十分轻量级,所以开销很小
# # 但是新的优化器会初始化动量等状态信息,这对于使用动量的优化器(momentum参数的sgd)可能会造成收敛中的震荡
# ## optimizer.param_groups:长度2的list,optimizer.param_groups[0]:长度6的字典
print(optimizer.param_groups[0]['lr'])
old_lr = 0.1
optimizer = optim.SGD([{'params': net.features.parameters()},
      {'params': net.classifiter.parameters(), 'lr': old_lr*0.1}], lr=1e-5)

可以看到optimizer.param_groups结构,[{'params','lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'},{……}],集合了优化器的各项参数。

重写sgd优化器

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
 def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay1=0, weight_decay2=0, nesterov=False):
  defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
      weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
  if nesterov and (momentum <= 0 or dampening != 0):
   raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  super(SGD, self).__init__(params, defaults)

 def __setstate__(self, state):
  super(SGD, self).__setstate__(state)
  for group in self.param_groups:
   group.setdefault('nesterov', False)

 def step(self, closure=None):
  """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """
  loss = None
  if closure is not None:
   loss = closure()

  for group in self.param_groups:
   weight_decay1 = group['weight_decay1']
   weight_decay2 = group['weight_decay2']
   momentum = group['momentum']
   dampening = group['dampening']
   nesterov = group['nesterov']

   for p in group['params']:
    if p.grad is None:
     continue
    d_p = p.grad.data
    if weight_decay1 != 0:
     d_p.add_(weight_decay1, torch.sign(p.data))
    if weight_decay2 != 0:
     d_p.add_(weight_decay2, p.data)
    if momentum != 0:
     param_state = self.state[p]
     if 'momentum_buffer' not in param_state:
      buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
      buf.mul_(momentum).add_(d_p)
     else:
      buf = param_state['momentum_buffer']
      buf.mul_(momentum).add_(1 - dampening, d_p)
     if nesterov:
      d_p = d_p.add(momentum, buf)
     else:
      d_p = buf

    p.data.add_(-group['lr'], d_p)

  return loss

以上这篇浅谈Pytorch torch.optim优化器个性化的使用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python IDE PyCharm的基本快捷键和配置简介
Nov 04 Python
Python基于select实现的socket服务器
Apr 13 Python
Python使用functools模块中的partial函数生成偏函数
Jul 02 Python
使用PIL(Python-Imaging)反转图像的颜色方法
Jan 24 Python
Python hexstring-list-str之间的转换方法
Jun 12 Python
Python实现微信小程序支付功能
Jul 25 Python
python实现树的深度优先遍历与广度优先遍历详解
Oct 26 Python
PyCharm下载和安装详细步骤
Dec 17 Python
keras获得某一层或者某层权重的输出实例
Jan 24 Python
Python常用库Numpy进行矩阵运算详解
Jul 21 Python
Python使用Pygame绘制时钟
Nov 29 Python
python中mongodb包操作数据库
Apr 19 Python
关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)
Feb 20 #Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
You might like
PHP字符串 ==比较运算符的副作用
2009/10/21 PHP
超级好用的一个php上传图片类(随机名,缩略图,加水印)
2010/06/30 PHP
解析php函数method_exists()与is_callable()的区别
2013/06/21 PHP
简单实用的.net DataTable导出Execl
2013/10/28 PHP
PHP使用CURL_MULTI实现多线程采集的例子
2014/07/29 PHP
PHP生成短网址方法汇总
2016/07/12 PHP
js中的push和join方法使用介绍
2013/10/08 Javascript
JS小功能(onmouseover实现选择月份)实例代码
2013/11/28 Javascript
JS逆序遍历实现代码
2014/12/02 Javascript
JavaScript的代码编写格式规范指南
2015/12/07 Javascript
JS中的forEach、$.each、map方法推荐
2016/04/05 Javascript
bootstrap模态框垂直居中效果
2016/12/03 Javascript
angularjs中ng-bind-html的用法总结
2017/05/23 Javascript
详解javascript设计模式三:代理模式
2019/03/25 Javascript
VUE 解决mode为history页面为空白的问题
2019/11/01 Javascript
浅谈vue项目利用Hbuilder打包成APP流程,以及遇到的坑
2020/09/12 Javascript
VUE实现吸底按钮
2021/03/04 Vue.js
[54:43]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第一场 2月22日
2021/03/11 DOTA
python获取当前运行函数名称的方法实例代码
2017/04/06 Python
python3 破解 geetest(极验)的滑块验证码功能
2018/02/24 Python
python生成器与迭代器详解
2019/01/01 Python
对Python模块导入时全局变量__all__的作用详解
2019/01/11 Python
pip install python 快速安装模块的教程图解
2019/10/08 Python
Python vtk读取并显示dicom文件示例
2020/01/13 Python
解决jupyter notebook 出现In[*]的问题
2020/04/13 Python
django ORM之values和annotate使用详解
2020/05/19 Python
Python Pandas数据分析工具用法实例
2020/11/05 Python
英国优质家居用品网上品牌:URBANARA
2018/06/01 全球购物
Auguste The Label官网:澳大利亚一家精品女装时尚品牌
2020/06/14 全球购物
JavaScript实现前端网页版倒计时
2021/03/24 Javascript
大学生涯自我鉴定
2014/01/16 职场文书
诚信考试承诺书
2014/03/27 职场文书
安全生产月宣传标语
2014/10/06 职场文书
数学教师求职信范文
2015/03/20 职场文书
青春雷锋观后感
2015/06/10 职场文书
浅谈 JavaScript 沙箱Sandbox
2021/11/02 Javascript