浅谈Pytorch torch.optim优化器个性化的使用


Posted in Python onFebruary 20, 2020

一、简化前馈网络LeNet

import torch as t
 
 
class LeNet(t.nn.Module):
 def __init__(self):
  super(LeNet, self).__init__()
  self.features = t.nn.Sequential(
   t.nn.Conv2d(3, 6, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2),
   t.nn.Conv2d(6, 16, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2)
  )
  # 由于调整shape并不是一个class层,
  # 所以在涉及这种操作(非nn.Module操作)需要拆分为多个模型
  self.classifiter = t.nn.Sequential(
   t.nn.Linear(16*5*5, 120),
   t.nn.ReLU(),
   t.nn.Linear(120, 84),
   t.nn.ReLU(),
   t.nn.Linear(84, 10)
  )
 
 def forward(self, x):
  x = self.features(x)
  x = x.view(-1, 16*5*5)
  x = self.classifiter(x)
  return x
 
net = LeNet()

二、优化器基本使用方法

建立优化器实例

循环:

清空梯度

向前传播

计算Loss

反向传播

更新参数

from torch import optim
 
# 通常的step优化过程
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # net.zero_grad()
 
input_ = t.autograd.Variable(t.randn(1, 3, 32, 32))
output = net(input_)
output.backward(output)
 
optimizer.step()

三、网络模块参数定制

为不同的子网络参数不同的学习率,finetune常用,使分类器学习率参数更高,学习速度更快(理论上)。

1.经由构建网络时划分好的模组进行学习率设定,

# # 直接对不同的网络模块制定不同学习率
optimizer = optim.SGD([{'params': net.features.parameters()}, # 默认lr是1e-5
      {'params': net.classifiter.parameters(), 'lr': 1e-2}], lr=1e-5)

2.以网络层对象为单位进行分组,并设定学习率

# # 以层为单位,为不同层指定不同的学习率
# ## 提取指定层对象
special_layers = t.nn.ModuleList([net.classifiter[0], net.classifiter[3]])
# ## 获取指定层参数id
special_layers_params = list(map(id, special_layers.parameters()))
print(special_layers_params)
# ## 获取非指定层的参数id
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())
optimizer = t.optim.SGD([{'params': base_params},
       {'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001)

四、在训练中动态的调整学习率

'''调整学习率'''
# 新建optimizer或者修改optimizer.params_groups对应的学习率
# # 新建optimizer更简单也更推荐,optimizer十分轻量级,所以开销很小
# # 但是新的优化器会初始化动量等状态信息,这对于使用动量的优化器(momentum参数的sgd)可能会造成收敛中的震荡
# ## optimizer.param_groups:长度2的list,optimizer.param_groups[0]:长度6的字典
print(optimizer.param_groups[0]['lr'])
old_lr = 0.1
optimizer = optim.SGD([{'params': net.features.parameters()},
      {'params': net.classifiter.parameters(), 'lr': old_lr*0.1}], lr=1e-5)

可以看到optimizer.param_groups结构,[{'params','lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'},{……}],集合了优化器的各项参数。

重写sgd优化器

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
 def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay1=0, weight_decay2=0, nesterov=False):
  defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
      weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
  if nesterov and (momentum <= 0 or dampening != 0):
   raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  super(SGD, self).__init__(params, defaults)

 def __setstate__(self, state):
  super(SGD, self).__setstate__(state)
  for group in self.param_groups:
   group.setdefault('nesterov', False)

 def step(self, closure=None):
  """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """
  loss = None
  if closure is not None:
   loss = closure()

  for group in self.param_groups:
   weight_decay1 = group['weight_decay1']
   weight_decay2 = group['weight_decay2']
   momentum = group['momentum']
   dampening = group['dampening']
   nesterov = group['nesterov']

   for p in group['params']:
    if p.grad is None:
     continue
    d_p = p.grad.data
    if weight_decay1 != 0:
     d_p.add_(weight_decay1, torch.sign(p.data))
    if weight_decay2 != 0:
     d_p.add_(weight_decay2, p.data)
    if momentum != 0:
     param_state = self.state[p]
     if 'momentum_buffer' not in param_state:
      buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
      buf.mul_(momentum).add_(d_p)
     else:
      buf = param_state['momentum_buffer']
      buf.mul_(momentum).add_(1 - dampening, d_p)
     if nesterov:
      d_p = d_p.add(momentum, buf)
     else:
      d_p = buf

    p.data.add_(-group['lr'], d_p)

  return loss

以上这篇浅谈Pytorch torch.optim优化器个性化的使用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常见文件操作的函数示例代码
Nov 15 Python
python实现随机密码字典生成器示例
Apr 09 Python
Python greenlet实现原理和使用示例
Sep 24 Python
实践Python的爬虫框架Scrapy来抓取豆瓣电影TOP250
Jan 20 Python
python 数据清洗之数据合并、转换、过滤、排序
Feb 12 Python
Python md5与sha1加密算法用法分析
Jul 14 Python
python使用itchat库实现微信机器人(好友聊天、群聊天)
Jan 04 Python
python获取代理IP的实例分享
May 07 Python
完美解决在oj中Python的循环输入问题
Jun 25 Python
Python3 assert断言实现原理解析
Mar 02 Python
pycharm如何使用anaconda中的各种包(操作步骤)
Jul 31 Python
python matlab库简单用法讲解
Dec 31 Python
关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)
Feb 20 #Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
You might like
php计算两个坐标(经度,纬度)之间距离的方法
2015/04/17 PHP
Json_decode 解析json字符串为NULL的解决方法(必看)
2017/02/17 PHP
PHP实现的简单组词算法示例
2018/04/10 PHP
javascript 函数调用的对象和方法
2010/07/01 Javascript
js获取字符串最后一位方法汇总
2014/11/13 Javascript
vue.js开发环境搭建教程
2017/05/04 Javascript
JS交互点击WKWebView中的图片实现预览效果
2018/01/05 Javascript
vue小白入门教程
2018/04/02 Javascript
基于Vue+element-ui 的Table二次封装的实现
2018/07/20 Javascript
JS html事件冒泡和事件捕获操作示例
2019/05/01 Javascript
微信小程序 轮播图实现原理及优化详解
2019/09/29 Javascript
vue cli3 配置proxy代理无效的解决
2019/10/30 Javascript
Python中使用socket发送HTTP请求数据接收不完整问题解决方法
2015/02/04 Python
python使用multiprocessing模块实现带回调函数的异步调用方法
2015/04/18 Python
python实现的简单RPG游戏流程实例
2015/06/28 Python
Python实现对百度云的文件上传(实例讲解)
2017/10/21 Python
python+opencv实现动态物体识别
2018/01/09 Python
Python设计模式之桥接模式原理与用法实例分析
2019/01/10 Python
python pyinstaller 加载ui路径方法
2019/06/10 Python
详解Python3之数据指纹MD5校验与对比
2019/06/11 Python
Python 用matplotlib画以时间日期为x轴的图像
2019/08/06 Python
python多线程与多进程及其区别详解
2019/08/08 Python
python3 deque 双向队列创建与使用方法分析
2020/03/24 Python
python 录制系统声音的示例
2020/12/21 Python
英国最大的割草机购买网站:Just Lawnmowers
2019/11/02 全球购物
彪马香港官方网上商店:PUMA香港
2020/12/06 全球购物
node中使用shell脚本的方法步骤
2021/03/23 Javascript
珍珠奶茶店创业计划书
2014/01/11 职场文书
工商管理专业大学生职业生涯规划范文
2014/03/09 职场文书
小学国旗下的演讲稿
2014/08/28 职场文书
2014幼儿园大班工作总结
2014/11/10 职场文书
平安家庭事迹材料
2014/12/20 职场文书
2019年工作总结范文
2019/05/21 职场文书
MySQL去除重叠时间求时间差和的实现
2021/08/23 MySQL
Win11怎么把合并的任务栏分开 Win11任务栏合并分开教程
2022/04/06 数码科技
SqlServer常用函数及时间处理小结
2023/05/08 SQL Server