pytorch 实现模型不同层设置不同的学习率方式


Posted in Python onJanuary 06, 2020

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python写的Socks5协议代理服务器
Aug 06 Python
python制作爬虫爬取京东商品评论教程
Dec 16 Python
Python使用PyCrypto实现AES加密功能示例
May 22 Python
Python2.7编程中SQLite3基本操作方法示例
Aug 09 Python
Python中Scrapy爬虫图片处理详解
Nov 29 Python
python 删除指定时间间隔之前的文件实例
Apr 24 Python
python利用微信公众号实现报警功能
Jun 10 Python
python实现排序算法解析
Sep 08 Python
Python3实现计算两个数组的交集算法示例
Apr 03 Python
python2.7使用scapy发送syn实例
May 05 Python
Python hashlib模块的使用示例
Oct 09 Python
python 用Matplotlib作图中有多个Y轴
Nov 28 Python
浅析Python3 pip换源问题
Jan 06 #Python
通过实例学习Python Excel操作
Jan 06 #Python
pytorch载入预训练模型后,实现训练指定层
Jan 06 #Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
You might like
javascript,php获取函数参数对象的代码
2011/02/03 PHP
php ckeditor上传图片文件名乱码解决方法
2013/11/15 PHP
PHP5.2下preg_replace函数的问题
2015/05/08 PHP
Zend Framework教程之Zend_Config_Ini用法分析
2016/03/23 PHP
PHP的Yii框架中View视图的使用进阶
2016/03/29 PHP
PHP对象相关知识总结
2017/04/09 PHP
extjs fckeditor集成代码
2009/05/10 Javascript
起点页面传值js,有空研究学习下
2010/01/25 Javascript
CSS(js)限制页面显示的文本字符长度
2012/12/27 Javascript
jQuery的Ajax的自动完成功能控件简要说明
2013/02/22 Javascript
枚举的实现求得1-1000所有出现1的数字并计算出现1的个数
2013/09/10 Javascript
JS根据key值获取URL中的参数值及把URL的参数转换成json对象
2015/08/26 Javascript
深入剖析JavaScript编程中的对象概念
2015/10/21 Javascript
详解JavaScript时间处理之几个月前或几个月后的指定日期
2016/12/21 Javascript
js实现hashtable的赋值、取值、遍历操作实例详解
2016/12/25 Javascript
BootStrap学习系列之布局组件(下拉,按钮组[toolbar],上拉)
2017/01/03 Javascript
获取IE浏览器Cookie信息的方法
2017/01/23 Javascript
ES6新增数据结构WeakSet的用法详解
2017/08/07 Javascript
Vue组件模板及组件互相引用代码实例
2020/03/11 Javascript
基于vue 动态菜单 刷新空白问题的解决
2020/08/06 Javascript
python spyder中读取txt为图片的方法
2018/04/27 Python
深入理解python中sort()与sorted()的区别
2018/08/29 Python
python实现年会抽奖程序
2019/01/22 Python
使用tensorboard可视化loss和acc的实例
2020/01/21 Python
Django实现文章详情页面跳转代码实例
2020/09/16 Python
绝对令人的惊叹的CSS3折叠效果(3D效果)整理
2012/12/30 HTML / CSS
css3使网页、图片变成灰色兼容大多数浏览器
2014/07/02 HTML / CSS
英国标志性奢侈品牌:Burberry
2016/07/28 全球购物
旅游网创业计划书
2014/01/31 职场文书
2014年小班保育员工作总结
2014/12/23 职场文书
员工辞职信范文
2015/03/02 职场文书
污水处理保证书
2015/05/09 职场文书
超级礼物观后感
2015/06/15 职场文书
Nginx服务器如何设置url链接
2021/03/31 Servers
HTML+CSS制作心跳特效的实现
2021/05/26 HTML / CSS
Spring Data JPA框架持久化存储数据到数据库
2022/04/28 Java/Android