pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
python设置windows桌面壁纸的实现代码
Jan 28 Python
压缩包密码破解示例分享(类似典破解)
Jan 17 Python
Python的string模块中的Template类字符串模板用法
Jun 27 Python
apache部署python程序出现503错误的解决方法
Jul 24 Python
Python面向对象编程基础解析(一)
Oct 26 Python
Python中将变量按行写入txt文本中的方法
Apr 03 Python
Django Form 实时从数据库中获取数据的操作方法
Jul 25 Python
python3 enum模块的应用实例详解
Aug 12 Python
python 解决tqdm模块不能单行显示的问题
Feb 19 Python
在Keras中实现保存和加载权重及模型结构
Jun 15 Python
python中count函数知识点浅析
Dec 17 Python
用python获取txt文件中关键字的数量
Dec 24 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
PHP 读取文本文件内容并分页显示
2016/01/02 PHP
DEDE实现转跳属性文档在模板上调用出转跳地址
2016/11/04 PHP
PHP加MySQL消息队列深入理解
2021/02/27 PHP
JQuery 自定义CircleAnimation,Animate方法学习笔记
2011/07/10 Javascript
中国地区三级联动下拉菜单效果分析
2012/11/15 Javascript
JS中Iframe之间传值的方法
2013/03/11 Javascript
jQuery插件multiScroll实现全屏鼠标滚动切换页面特效
2015/04/12 Javascript
JavaScript数组对象赋值用法实例
2015/08/04 Javascript
javascript插件开发的一些感想和心得
2016/02/28 Javascript
jQuery向父辈遍历的简单方法
2016/09/18 Javascript
jquery无法为动态生成的元素添加点击事件的解决方法(推荐)
2016/12/26 Javascript
JS编写兼容IE6,7,8浏览器无缝自动轮播
2018/10/12 Javascript
Vee-validate 父组件获取子组件表单校验结果的实例代码
2019/05/20 Javascript
利用js-cookie实现前端设置缓存数据定时失效
2019/06/18 Javascript
node-red File读取好保存实例讲解
2019/09/11 Javascript
js实现右键弹出自定义菜单
2020/09/08 Javascript
Python使用cx_Oracle模块将oracle中数据导出到csv文件的方法
2015/05/16 Python
Python实现求最大公约数及判断素数的方法
2015/05/26 Python
Python函数式编程指南(三):迭代器详解
2015/06/24 Python
Python的IDEL增加清屏功能实例
2017/06/19 Python
python2 与python3的print区别小结
2018/01/16 Python
Python中生成器和迭代器的区别详解
2018/02/10 Python
Linux CentOS Python开发环境搭建教程
2018/11/28 Python
Tensorflow 模型转换 .pb convert to .lite实例
2020/02/12 Python
Doyoueven官网:澳大利亚健身服饰和配饰品牌
2019/03/24 全球购物
欧缇丽加拿大官方网站:Caudalie加拿大
2019/07/18 全球购物
施华洛世奇巴西官网:SWAROVSKI巴西
2019/12/03 全球购物
工程监理应届生求职信
2013/11/09 职场文书
大学生就业推荐信范文
2013/11/29 职场文书
医疗纠纷协议书
2014/04/16 职场文书
市场营销工作计划书
2014/05/06 职场文书
2014年社区矫正工作总结
2014/11/18 职场文书
2016医师资格考试考生诚信考试承诺书
2016/03/25 职场文书
关于党风廉政建设宣传教育月的活动总结!
2019/08/08 职场文书
电脑只能进入安全模式无法正常启动的解决办法
2022/04/08 数码科技
Li list-style-image 图片垂直居中实现方法
2023/05/21 HTML / CSS