浅谈Pytorch torch.optim优化器个性化的使用


Posted in Python onFebruary 20, 2020

一、简化前馈网络LeNet

import torch as t
 
 
class LeNet(t.nn.Module):
 def __init__(self):
  super(LeNet, self).__init__()
  self.features = t.nn.Sequential(
   t.nn.Conv2d(3, 6, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2),
   t.nn.Conv2d(6, 16, 5),
   t.nn.ReLU(),
   t.nn.MaxPool2d(2, 2)
  )
  # 由于调整shape并不是一个class层,
  # 所以在涉及这种操作(非nn.Module操作)需要拆分为多个模型
  self.classifiter = t.nn.Sequential(
   t.nn.Linear(16*5*5, 120),
   t.nn.ReLU(),
   t.nn.Linear(120, 84),
   t.nn.ReLU(),
   t.nn.Linear(84, 10)
  )
 
 def forward(self, x):
  x = self.features(x)
  x = x.view(-1, 16*5*5)
  x = self.classifiter(x)
  return x
 
net = LeNet()

二、优化器基本使用方法

建立优化器实例

循环:

清空梯度

向前传播

计算Loss

反向传播

更新参数

from torch import optim
 
# 通常的step优化过程
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # net.zero_grad()
 
input_ = t.autograd.Variable(t.randn(1, 3, 32, 32))
output = net(input_)
output.backward(output)
 
optimizer.step()

三、网络模块参数定制

为不同的子网络参数不同的学习率,finetune常用,使分类器学习率参数更高,学习速度更快(理论上)。

1.经由构建网络时划分好的模组进行学习率设定,

# # 直接对不同的网络模块制定不同学习率
optimizer = optim.SGD([{'params': net.features.parameters()}, # 默认lr是1e-5
      {'params': net.classifiter.parameters(), 'lr': 1e-2}], lr=1e-5)

2.以网络层对象为单位进行分组,并设定学习率

# # 以层为单位,为不同层指定不同的学习率
# ## 提取指定层对象
special_layers = t.nn.ModuleList([net.classifiter[0], net.classifiter[3]])
# ## 获取指定层参数id
special_layers_params = list(map(id, special_layers.parameters()))
print(special_layers_params)
# ## 获取非指定层的参数id
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())
optimizer = t.optim.SGD([{'params': base_params},
       {'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001)

四、在训练中动态的调整学习率

'''调整学习率'''
# 新建optimizer或者修改optimizer.params_groups对应的学习率
# # 新建optimizer更简单也更推荐,optimizer十分轻量级,所以开销很小
# # 但是新的优化器会初始化动量等状态信息,这对于使用动量的优化器(momentum参数的sgd)可能会造成收敛中的震荡
# ## optimizer.param_groups:长度2的list,optimizer.param_groups[0]:长度6的字典
print(optimizer.param_groups[0]['lr'])
old_lr = 0.1
optimizer = optim.SGD([{'params': net.features.parameters()},
      {'params': net.classifiter.parameters(), 'lr': old_lr*0.1}], lr=1e-5)

可以看到optimizer.param_groups结构,[{'params','lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'},{……}],集合了优化器的各项参数。

重写sgd优化器

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
 def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay1=0, weight_decay2=0, nesterov=False):
  defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
      weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
  if nesterov and (momentum <= 0 or dampening != 0):
   raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  super(SGD, self).__init__(params, defaults)

 def __setstate__(self, state):
  super(SGD, self).__setstate__(state)
  for group in self.param_groups:
   group.setdefault('nesterov', False)

 def step(self, closure=None):
  """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """
  loss = None
  if closure is not None:
   loss = closure()

  for group in self.param_groups:
   weight_decay1 = group['weight_decay1']
   weight_decay2 = group['weight_decay2']
   momentum = group['momentum']
   dampening = group['dampening']
   nesterov = group['nesterov']

   for p in group['params']:
    if p.grad is None:
     continue
    d_p = p.grad.data
    if weight_decay1 != 0:
     d_p.add_(weight_decay1, torch.sign(p.data))
    if weight_decay2 != 0:
     d_p.add_(weight_decay2, p.data)
    if momentum != 0:
     param_state = self.state[p]
     if 'momentum_buffer' not in param_state:
      buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
      buf.mul_(momentum).add_(d_p)
     else:
      buf = param_state['momentum_buffer']
      buf.mul_(momentum).add_(1 - dampening, d_p)
     if nesterov:
      d_p = d_p.add(momentum, buf)
     else:
      d_p = buf

    p.data.add_(-group['lr'], d_p)

  return loss

以上这篇浅谈Pytorch torch.optim优化器个性化的使用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python list语法学习(带例子)
Nov 01 Python
Python二维码生成库qrcode安装和使用示例
Dec 16 Python
python实现批量下载新浪博客的方法
Jun 15 Python
Python中的二维数组实例(list与numpy.array)
Apr 13 Python
解决Python2.7中IDLE启动没有反应的问题
Nov 30 Python
Python爬虫实现验证码登录代码实例
May 10 Python
django解决订单并发问题【推荐】
Jul 31 Python
python自动生成model文件过程详解
Nov 02 Python
win10系统Anaconda和Pycharm的Tensorflow2.0之CPU和GPU版本安装教程
Dec 03 Python
Python 解析pymysql模块操作数据库的方法
Feb 18 Python
Python request使用方法及问题总结
Apr 26 Python
降低python版本的操作方法
Sep 11 Python
关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)
Feb 20 #Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
You might like
CPU步进是什么意思?i3-9100F B0步进和U0步进区别知识科普
2020/03/17 数码科技
Memcached常用命令以及使用说明详解
2013/06/27 PHP
php过滤敏感词的示例
2014/03/31 PHP
php实现约瑟夫问题的方法小结
2015/03/23 PHP
jquery延迟加载外部js实现代码
2013/01/11 Javascript
在JavaScript中实现类的方式探讨
2013/08/28 Javascript
JavaScript实现Java中Map容器的方法
2016/10/09 Javascript
基于AngularJS前端云组件最佳实践
2016/10/20 Javascript
javascript基本数据类型及类型检测常用方法小结
2016/12/14 Javascript
Vue组件tree实现树形菜单
2017/04/13 Javascript
Angular路由ui-router配置详解
2018/08/01 Javascript
vue-router命名路由和编程式路由传参讲解
2019/01/19 Javascript
react的滑动图片验证码组件的示例代码
2019/02/27 Javascript
python使用mysqldb连接数据库操作方法示例详解
2013/12/03 Python
python循环监控远程端口的方法
2015/03/14 Python
python中尾递归用法实例详解
2015/04/28 Python
浅谈Python中copy()方法的使用
2015/05/21 Python
Python实现二叉树结构与进行二叉树遍历的方法详解
2016/05/24 Python
解决TensorFlow训练内存不断增长,进程被杀死问题
2020/02/05 Python
Keras 数据增强ImageDataGenerator多输入多输出实例
2020/07/03 Python
CSS3教程(7):CSS3嵌入字体
2009/04/02 HTML / CSS
关于box-sizing的全面理解
2016/07/28 HTML / CSS
HTML5头部标签的一些常用信息小结
2016/10/23 HTML / CSS
调解协议书
2014/04/16 职场文书
食品销售计划书
2014/04/26 职场文书
2014年计生协会工作总结
2014/11/21 职场文书
2014年化工厂工作总结
2014/11/25 职场文书
中学团支部工作总结
2015/08/13 职场文书
2016年教师政治思想表现评语
2015/12/02 职场文书
应用最多的公文《通知》如何写?
2019/04/02 职场文书
如何让2019年上半年的工作总结更出色!
2019/07/01 职场文书
导游词之南迦巴瓦峰
2019/11/19 职场文书
SONY AN-LP1 短波有源天线放大器
2021/04/22 无线电
如何使用JavaScript策略模式校验表单
2021/04/29 Javascript
Mysql Online DDL的使用详解
2021/05/20 MySQL
浏览器常用基本操作之python3+selenium4自动化测试(基础篇3)
2021/05/21 Python