pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
python中常用的各种数据库操作模块和连接实例
May 29 Python
django模型中的字段和model名显示为中文小技巧分享
Nov 18 Python
详解Python中的strftime()方法的使用
May 22 Python
浅谈python中的面向对象和类的基本语法
Jun 13 Python
python生成二维码的实例详解
Oct 29 Python
Python图形绘制操作之正弦曲线实现方法分析
Dec 25 Python
Python使用Windows API创建窗口示例【基于win32gui模块】
May 09 Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 Python
tensorflow 利用expand_dims和squeeze扩展和压缩tensor维度方式
Feb 07 Python
Keras - GPU ID 和显存占用设定步骤
Jun 22 Python
只用50行Python代码爬取网络美女高清图片
Jun 02 Python
PYTHON基于Pyecharts绘制常见的直角坐标系图表
Apr 28 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
php selectradio和checkbox默认选择的实现方法详解
2013/06/29 PHP
php 判断是否是中文/英文/数字示例代码
2013/09/30 PHP
php解决抢购秒杀抽奖等大流量并发入库导致的库存负数的问题
2014/06/19 PHP
详解Yii2 定制表单输入字段的标签和样式
2017/01/04 PHP
PHP基于mcript扩展实现对称加密功能示例
2019/02/21 PHP
php经典趣味算法实例代码
2020/01/21 PHP
PHP实现Snowflake生成分布式唯一ID的方法示例
2020/08/30 PHP
js函数使用技巧之 setTimeout(function(){},0)
2009/02/09 Javascript
jQuery的deferred对象使用详解
2011/08/20 Javascript
JS保存、读取、换行、转Json报错处理方法
2013/06/14 Javascript
js将当前时间格式转换成时间搓(自写)
2013/09/26 Javascript
javascript得到当前页的来路即前一页地址的方法
2014/02/18 Javascript
用JavaScript实现用一个DIV来包装文本元素节点
2014/09/09 Javascript
jQuery Timelinr实现垂直水平时间轴插件(附源码下载)
2016/02/16 Javascript
js实现的简单图片浮动效果完整实例
2016/05/10 Javascript
JS实现含有中文字符串的友好截取功能分析
2017/03/13 Javascript
JavaScript实现树的遍历算法示例【广度优先与深度优先】
2017/10/26 Javascript
详解js的作用域、预解析机制
2018/02/05 Javascript
详解vue中this.$emit()的返回值是什么
2019/04/07 Javascript
详解python里使用正则表达式的全匹配功能
2017/10/19 Python
基于Python 装饰器装饰类中的方法实例
2018/04/21 Python
python实现五子棋人机对战游戏
2020/03/25 Python
python 抓包保存为pcap文件并解析的实例
2019/07/23 Python
Django框架反向解析操作详解
2019/11/28 Python
How TDD works
2012/09/30 面试题
专科文秘应届生求职信
2013/11/18 职场文书
读书活动总结
2014/04/28 职场文书
幼儿教师演讲稿
2014/05/06 职场文书
奉献家乡演讲稿
2014/09/13 职场文书
三八妇女节标语
2014/10/09 职场文书
2014年安全管理工作总结
2014/12/01 职场文书
预备党员表决心的话
2015/09/22 职场文书
大学生心理健康教育心得体会
2016/01/12 职场文书
MySQL快速插入一亿测试数据
2021/06/23 MySQL
Python初识逻辑与if语句及用法大全
2021/08/07 Python
JavaScript小技巧带你提升你的代码技能
2021/09/15 Javascript