pytorch 实现模型不同层设置不同的学习率方式


Posted in Python onJanuary 06, 2020

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用pngquant压缩png图片的教程
Apr 09 Python
redis之django-redis的简单缓存使用
Jun 07 Python
python matplotlib画图库学习绘制常用的图
Mar 19 Python
PyQt5 实现字体大小自适应分辨率的方法
Jun 18 Python
python中时间、日期、时间戳的转换的实现方法
Jul 06 Python
python如何删除文件中重复的字段
Jul 16 Python
在PyTorch中使用标签平滑正则化的问题
Apr 03 Python
查看jupyter notebook每个单元格运行时间实例
Apr 22 Python
python利用Excel读取和存储测试数据完成接口自动化教程
Apr 30 Python
matplotlib subplot绘制多个子图的方法示例
Jul 28 Python
python使用bs4爬取boss直聘静态页面
Oct 10 Python
详解pandas中利用DataFrame对象的.loc[]、.iloc[]方法抽取数据
Dec 13 Python
浅析Python3 pip换源问题
Jan 06 #Python
通过实例学习Python Excel操作
Jan 06 #Python
pytorch载入预训练模型后,实现训练指定层
Jan 06 #Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
You might like
php文本转图片自动换行的方法
2013/03/13 PHP
php获取文件内容最后一行示例
2014/01/09 PHP
php5.4传引用时报错问题分析
2016/01/22 PHP
php实现批量修改文件名称的方法
2016/07/23 PHP
PHP的CURL方法curl_setopt()函数案例介绍(抓取网页,POST数据)
2016/12/14 PHP
PHP实现提高SESSION响应速度的几种方法详解
2019/08/09 PHP
php设计模式之观察者模式定义与用法经典示例
2019/09/19 PHP
PHP实现文件上传后台处理脚本
2020/03/04 PHP
一个封装js代码-----展开收起效果示例
2013/07/03 Javascript
jquery事件重复绑定的快速解决方法
2014/01/03 Javascript
个人总结的一些JavaScript技巧、实用函数、简洁方法、编程细节
2015/06/10 Javascript
基于javascript html5实现多文件上传
2016/03/03 Javascript
Bootstrap CSS布局之按钮
2016/12/17 Javascript
JS基于面向对象实现的选项卡效果示例
2016/12/20 Javascript
微信小程序scroll-view实现横向滚动和上拉加载示例
2017/03/06 Javascript
js 获取图像缩放后的实际宽高,位置等信息
2017/03/07 Javascript
jQuery简易时光轴实现方法示例
2017/03/13 Javascript
vue实例中data使用return包裹的方法
2018/08/27 Javascript
在Vue组件中获取全局的点击事件方法
2018/09/06 Javascript
vue使用map代替Aarry数组循环遍历的方法
2020/04/30 Javascript
python实现聚类算法原理
2018/02/12 Python
Django 实现购物车功能的示例代码
2018/10/08 Python
解决项目pycharm能运行,在终端却无法运行的问题
2019/01/19 Python
Python简易版停车管理系统
2019/08/12 Python
python的time模块和datetime模块实例解析
2019/11/29 Python
Django 批量插入数据的实现方法
2020/01/12 Python
Python利用逻辑回归分类实现模板
2020/02/15 Python
使用python-Jenkins批量创建及修改jobs操作
2020/05/12 Python
Python如何给你的程序做性能测试
2020/07/29 Python
详解CSS中iconfont的使用
2015/08/04 HTML / CSS
意大利在线购买隐形眼镜网站:VisionDirect.it
2019/03/18 全球购物
年度安全生产目标责任书
2014/07/23 职场文书
北京离婚协议书范文2014
2014/09/29 职场文书
淘宝客服专员岗位职责
2015/04/07 职场文书
pytest实现多进程与多线程运行超好用的插件
2022/07/15 Python
解决ubuntu安装软件时,status-code=409报错的问题
2022/12/24 Servers