浅谈Pytorch torch.optim优化器个性化的使用


Posted in Python onFebruary 20, 2020

一、简化前馈网络LeNet

import torch as t
 
 
class LeNet(t.nn.Module):
 def __init__(self):
  super(LeNet, self).__init__()
  self.features = t.nn.Sequential(
   t.nn.Conv2d(3, 6, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2),
   t.nn.Conv2d(6, 16, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2)
  )
  # 由于调整shape并不是一个class层,
  # 所以在涉及这种操作(非nn.Module操作)需要拆分为多个模型
  self.classifiter = t.nn.Sequential(
   t.nn.Linear(16*5*5, 120),
   t.nn.ReLU(),
   t.nn.Linear(120, 84),
   t.nn.ReLU(),
   t.nn.Linear(84, 10)
  )
 
 def forward(self, x):
  x = self.features(x)
  x = x.view(-1, 16*5*5)
  x = self.classifiter(x)
  return x
 
net = LeNet()

二、优化器基本使用方法

建立优化器实例

循环:

清空梯度

向前传播

计算Loss

反向传播

更新参数

from torch import optim
 
# 通常的step优化过程
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # net.zero_grad()
 
input_ = t.autograd.Variable(t.randn(1, 3, 32, 32))
output = net(input_)
output.backward(output)
 
optimizer.step()

三、网络模块参数定制

为不同的子网络参数不同的学习率,finetune常用,使分类器学习率参数更高,学习速度更快(理论上)。

1.经由构建网络时划分好的模组进行学习率设定,

# # 直接对不同的网络模块制定不同学习率
optimizer = optim.SGD([{'params': net.features.parameters()}, # 默认lr是1e-5
      {'params': net.classifiter.parameters(), 'lr': 1e-2}], lr=1e-5)

2.以网络层对象为单位进行分组,并设定学习率

# # 以层为单位,为不同层指定不同的学习率
# ## 提取指定层对象
special_layers = t.nn.ModuleList([net.classifiter[0], net.classifiter[3]])
# ## 获取指定层参数id
special_layers_params = list(map(id, special_layers.parameters()))
print(special_layers_params)
# ## 获取非指定层的参数id
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())
optimizer = t.optim.SGD([{'params': base_params},
       {'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001)

四、在训练中动态的调整学习率

'''调整学习率'''
# 新建optimizer或者修改optimizer.params_groups对应的学习率
# # 新建optimizer更简单也更推荐,optimizer十分轻量级,所以开销很小
# # 但是新的优化器会初始化动量等状态信息,这对于使用动量的优化器(momentum参数的sgd)可能会造成收敛中的震荡
# ## optimizer.param_groups:长度2的list,optimizer.param_groups[0]:长度6的字典
print(optimizer.param_groups[0]['lr'])
old_lr = 0.1
optimizer = optim.SGD([{'params': net.features.parameters()},
      {'params': net.classifiter.parameters(), 'lr': old_lr*0.1}], lr=1e-5)

可以看到optimizer.param_groups结构,[{'params','lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'},{……}],集合了优化器的各项参数。

重写sgd优化器

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
 def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay1=0, weight_decay2=0, nesterov=False):
  defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
      weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
  if nesterov and (momentum <= 0 or dampening != 0):
   raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  super(SGD, self).__init__(params, defaults)

 def __setstate__(self, state):
  super(SGD, self).__setstate__(state)
  for group in self.param_groups:
   group.setdefault('nesterov', False)

 def step(self, closure=None):
  """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """
  loss = None
  if closure is not None:
   loss = closure()

  for group in self.param_groups:
   weight_decay1 = group['weight_decay1']
   weight_decay2 = group['weight_decay2']
   momentum = group['momentum']
   dampening = group['dampening']
   nesterov = group['nesterov']

   for p in group['params']:
    if p.grad is None:
     continue
    d_p = p.grad.data
    if weight_decay1 != 0:
     d_p.add_(weight_decay1, torch.sign(p.data))
    if weight_decay2 != 0:
     d_p.add_(weight_decay2, p.data)
    if momentum != 0:
     param_state = self.state[p]
     if 'momentum_buffer' not in param_state:
      buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
      buf.mul_(momentum).add_(d_p)
     else:
      buf = param_state['momentum_buffer']
      buf.mul_(momentum).add_(1 - dampening, d_p)
     if nesterov:
      d_p = d_p.add(momentum, buf)
     else:
      d_p = buf

    p.data.add_(-group['lr'], d_p)

  return loss

以上这篇浅谈Pytorch torch.optim优化器个性化的使用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于wxpython实现的windows GUI程序实例
May 30 Python
一行代码让 Python 的运行速度提高100倍
Oct 08 Python
python 检查文件mime类型的方法
Dec 08 Python
python中时间模块的基本使用教程
May 14 Python
python tkinter实现彩球碰撞屏保
Jul 30 Python
python实现小世界网络生成
Nov 21 Python
在flask中使用python-dotenv+flask-cli自定义命令(推荐)
Jan 05 Python
PyPDF2读取PDF文件内容保存到本地TXT实例
May 12 Python
python mongo 向数据中的数组类型新增数据操作
Dec 05 Python
通过python-pptx模块操作ppt文件的方法
Dec 26 Python
python使用openpyxl库读写Excel表格的方法(增删改查操作)
May 02 Python
关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)
Feb 20 #Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
You might like
安健A254立体声随身听的分析与打磨
2021/03/02 无线电
php 正则表达式小结
2009/08/31 PHP
PHP 读取文件内容代码(txt,js等)
2009/12/06 PHP
php 不使用js实现页面跳转
2014/02/11 PHP
php去掉URL网址中带有PHPSESSID的配置方法
2014/07/08 PHP
jQuery+CSS 实现的超Sexy下拉菜单
2010/01/17 Javascript
javascript 节点排序 2
2011/01/31 Javascript
jQuery + Flex 通过拖拽方式动态改变图片的代码
2011/08/03 Javascript
table对象中的insertRow与deleteRow使用示例
2014/01/26 Javascript
js中的hasOwnProperty和isPrototypeOf方法使用实例
2014/06/06 Javascript
jQuery源码分析之jQuery中的循环技巧详解
2014/09/06 Javascript
JavaScript实现标题栏文字轮播效果代码
2015/10/24 Javascript
完美解决IE9浏览器出现的对象未定义问题
2016/09/29 Javascript
vue获取元素宽、高、距离左边距离,右,上距离等还有XY坐标轴的方法
2018/09/05 Javascript
vue 表单之通过v-model绑定单选按钮radio
2019/05/13 Javascript
vue2配置scss的方法步骤
2019/06/06 Javascript
从零撸一个pc端vue的ui组件库( 计数器组件 )
2019/08/08 Javascript
Webpack中loader打包各种文件的方法实例
2019/09/03 Javascript
create-react-app中添加less支持的实现
2019/11/15 Javascript
springboot+vue+对接支付宝接口+二维码扫描支付功能(沙箱环境)
2020/10/15 Javascript
vite2.0+vue3移动端项目实战详解
2021/03/03 Vue.js
python格式化字符串实例总结
2014/09/28 Python
python通过自定义isnumber函数判断字符串是否为数字的方法
2015/04/23 Python
一篇文章了解Python中常见的序列化操作
2019/06/20 Python
python logging模块的使用总结
2019/07/09 Python
python实现切割url得到域名、协议、主机名等各个字段的例子
2019/07/25 Python
关于PyTorch 自动求导机制详解
2019/08/18 Python
python实现图片素描效果
2020/09/26 Python
巴西电子产品购物网站:Saldão da Informática
2018/01/09 全球购物
店长职务说明书
2014/02/04 职场文书
小学语文业务学习材料
2014/06/02 职场文书
教师工作失职检讨书
2014/09/18 职场文书
学校食品安全责任书
2015/01/29 职场文书
微信小程序实现聊天室功能
2021/06/14 Javascript
Python Matplotlib绘制等高线图与渐变色扇形图
2022/04/14 Python
一文了解Java动态代理的原理及实现
2022/07/07 Java/Android