浅谈Pytorch torch.optim优化器个性化的使用


Posted in Python onFebruary 20, 2020

一、简化前馈网络LeNet

import torch as t
 
 
class LeNet(t.nn.Module):
 def __init__(self):
  super(LeNet, self).__init__()
  self.features = t.nn.Sequential(
   t.nn.Conv2d(3, 6, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2),
   t.nn.Conv2d(6, 16, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2)
  )
  # 由于调整shape并不是一个class层,
  # 所以在涉及这种操作(非nn.Module操作)需要拆分为多个模型
  self.classifiter = t.nn.Sequential(
   t.nn.Linear(16*5*5, 120),
   t.nn.ReLU(),
   t.nn.Linear(120, 84),
   t.nn.ReLU(),
   t.nn.Linear(84, 10)
  )
 
 def forward(self, x):
  x = self.features(x)
  x = x.view(-1, 16*5*5)
  x = self.classifiter(x)
  return x
 
net = LeNet()

二、优化器基本使用方法

建立优化器实例

循环:

清空梯度

向前传播

计算Loss

反向传播

更新参数

from torch import optim
 
# 通常的step优化过程
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # net.zero_grad()
 
input_ = t.autograd.Variable(t.randn(1, 3, 32, 32))
output = net(input_)
output.backward(output)
 
optimizer.step()

三、网络模块参数定制

为不同的子网络参数不同的学习率,finetune常用,使分类器学习率参数更高,学习速度更快(理论上)。

1.经由构建网络时划分好的模组进行学习率设定,

# # 直接对不同的网络模块制定不同学习率
optimizer = optim.SGD([{'params': net.features.parameters()}, # 默认lr是1e-5
      {'params': net.classifiter.parameters(), 'lr': 1e-2}], lr=1e-5)

2.以网络层对象为单位进行分组,并设定学习率

# # 以层为单位,为不同层指定不同的学习率
# ## 提取指定层对象
special_layers = t.nn.ModuleList([net.classifiter[0], net.classifiter[3]])
# ## 获取指定层参数id
special_layers_params = list(map(id, special_layers.parameters()))
print(special_layers_params)
# ## 获取非指定层的参数id
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())
optimizer = t.optim.SGD([{'params': base_params},
       {'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001)

四、在训练中动态的调整学习率

'''调整学习率'''
# 新建optimizer或者修改optimizer.params_groups对应的学习率
# # 新建optimizer更简单也更推荐,optimizer十分轻量级,所以开销很小
# # 但是新的优化器会初始化动量等状态信息,这对于使用动量的优化器(momentum参数的sgd)可能会造成收敛中的震荡
# ## optimizer.param_groups:长度2的list,optimizer.param_groups[0]:长度6的字典
print(optimizer.param_groups[0]['lr'])
old_lr = 0.1
optimizer = optim.SGD([{'params': net.features.parameters()},
      {'params': net.classifiter.parameters(), 'lr': old_lr*0.1}], lr=1e-5)

可以看到optimizer.param_groups结构,[{'params','lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'},{……}],集合了优化器的各项参数。

重写sgd优化器

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
 def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay1=0, weight_decay2=0, nesterov=False):
  defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
      weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
  if nesterov and (momentum <= 0 or dampening != 0):
   raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  super(SGD, self).__init__(params, defaults)

 def __setstate__(self, state):
  super(SGD, self).__setstate__(state)
  for group in self.param_groups:
   group.setdefault('nesterov', False)

 def step(self, closure=None):
  """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """
  loss = None
  if closure is not None:
   loss = closure()

  for group in self.param_groups:
   weight_decay1 = group['weight_decay1']
   weight_decay2 = group['weight_decay2']
   momentum = group['momentum']
   dampening = group['dampening']
   nesterov = group['nesterov']

   for p in group['params']:
    if p.grad is None:
     continue
    d_p = p.grad.data
    if weight_decay1 != 0:
     d_p.add_(weight_decay1, torch.sign(p.data))
    if weight_decay2 != 0:
     d_p.add_(weight_decay2, p.data)
    if momentum != 0:
     param_state = self.state[p]
     if 'momentum_buffer' not in param_state:
      buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
      buf.mul_(momentum).add_(d_p)
     else:
      buf = param_state['momentum_buffer']
      buf.mul_(momentum).add_(1 - dampening, d_p)
     if nesterov:
      d_p = d_p.add(momentum, buf)
     else:
      d_p = buf

    p.data.add_(-group['lr'], d_p)

  return loss

以上这篇浅谈Pytorch torch.optim优化器个性化的使用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动翻译实现方法
May 28 Python
Python设计足球联赛赛程表程序的思路与简单实现示例
Jun 28 Python
Python Json模块中dumps、loads、dump、load函数介绍
May 15 Python
numpy中以文本的方式存储以及读取数据方法
Jun 04 Python
浅谈Python接口对json串的处理方法
Dec 19 Python
python 判断矩阵中每行非零个数的方法
Jan 26 Python
pytorch 在sequential中使用view来reshape的例子
Aug 20 Python
Django 创建后台,配置sqlite3教程
Nov 18 Python
Python3实现监控新型冠状病毒肺炎疫情的示例代码
Feb 13 Python
Python gevent协程切换实现详解
Sep 14 Python
PyTorch预训练Bert模型的示例
Nov 17 Python
Python matplotlib 利用随机函数生成变化图形
Apr 26 Python
关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)
Feb 20 #Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
You might like
php中将一个对象保存到Session中的方法
2015/03/13 PHP
php实现TCP端口检测的方法
2015/04/01 PHP
php生成与读取excel文件
2016/10/14 PHP
PHP+iframe图片上传实现即时刷新效果
2016/11/18 PHP
轻松实现php文件上传功能
2017/02/17 PHP
PHP实现的多维数组排序算法分析
2018/02/10 PHP
laravel 关联关系遍历数组的例子
2019/10/10 PHP
PHP $O00OO0=urldecode &amp; eval 解密,记一次商业源码的去后门
2020/09/13 PHP
js读写cookie实现一个底部广告浮层效果的两种方法
2013/12/29 Javascript
javascript日期对象格式化为字符串的实现方法
2014/01/14 Javascript
jquery动态调整div大小使其宽度始终为浏览器宽度
2014/06/06 Javascript
jquery引用方法时传递参数原理分析
2014/10/13 Javascript
分享几种比较简单实用的JavaScript tabel切换
2015/12/31 Javascript
jQuery的Cookie封装,与PHP交互的简单实现
2016/10/05 Javascript
ui-router中使用ocLazyLoad和resolve的具体方法
2017/10/18 Javascript
angularjs中$http异步上传Excel文件方法
2018/02/23 Javascript
基于vue实现可搜索下拉框定制组件
2020/03/26 Javascript
vue多个元素的样式选择器问题
2019/11/29 Javascript
NUXT SSR初级入门笔记(小结)
2019/12/16 Javascript
jQuery实现推拉门效果
2020/10/19 jQuery
[43:58]DOTA2上海特级锦标赛C组败者赛 Newbee VS Archon第二局
2016/02/27 DOTA
[01:10]DOTA2英雄背景故事第四期之混沌法则混沌骑士
2020/07/16 DOTA
Python随机生成数模块random使用实例
2015/04/13 Python
tensorflow: variable的值与variable.read_value()的值区别详解
2018/07/30 Python
在Python中增加和插入元素的示例
2018/11/01 Python
python实现在函数图像上添加文字和标注的方法
2019/07/08 Python
基于python实现文件加密功能
2020/01/06 Python
python+gdal+遥感图像拼接(mosaic)的实例
2020/03/10 Python
如何开发一款堪比APP的微信小程序(腾讯内部团队分享)
2016/12/22 HTML / CSS
英国领先的在线药房:Pharmacy First
2017/09/10 全球购物
数控技术专科生自我评价
2014/01/08 职场文书
维修工先进事迹
2014/05/29 职场文书
校园会短篇的广播稿
2014/10/21 职场文书
运动会班级前导词
2015/07/20 职场文书
Python字典和列表性能之间的比较
2021/06/07 Python
Python可视化学习之seaborn调色盘
2022/02/24 Python