pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
Python的高级Git库 Gittle
Sep 22 Python
python实现杨辉三角思路
Jul 14 Python
tensorflow saver 保存和恢复指定 tensor的实例讲解
Jul 26 Python
python基础梳理(一)(推荐)
Apr 06 Python
Python两台电脑实现TCP通信的方法示例
May 06 Python
python3.6使用tkinter实现弹跳小球游戏
May 09 Python
浅析Python3中的对象垃圾收集机制
Jun 06 Python
PyCharm搭建Spark开发环境实现第一个pyspark程序
Jun 13 Python
Django 过滤器汇总及自定义过滤器使用详解
Jul 19 Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 Python
Python钉钉报警及Zabbix集成钉钉报警的示例代码
Aug 17 Python
Python OpenCV之常用滤波器使用详解
Apr 07 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
php的header和asp中的redirect比较
2006/10/09 PHP
dedecms模板标签代码官方参考
2007/03/17 PHP
php使用glob函数遍历文件和目录详解
2016/09/23 PHP
CI框架封装的常用图像处理方法(缩略图,水印,旋转,上传等)
2016/11/22 PHP
基于PHP实现栈数据结构和括号匹配算法示例
2017/08/10 PHP
PHP实现的服务器一致性hash分布算法示例
2018/08/09 PHP
PHP进阶学习之依赖注入与Ioc容器详解
2019/06/19 PHP
js构造函数、索引数组和属性的实现方式和使用
2014/11/16 Javascript
node.js下when.js 的异步编程实践
2014/12/03 Javascript
JS中产生标识符方式的演变
2015/06/12 Javascript
分享两段简单的JS代码防止SQL注入
2016/04/12 Javascript
Javascript闭包与函数柯里化浅析
2016/06/22 Javascript
解决拦截器对ajax请求的拦截实例详解
2016/12/21 Javascript
bootstrap日期控件问题(双日期、清空等问题解决)
2017/04/19 Javascript
vue+swiper实现侧滑菜单效果
2017/12/28 Javascript
JS实现将链接生成二维码并转为图片的方法
2018/03/17 Javascript
使用Vue.js和Flask来构建一个单页的App的示例
2018/03/21 Javascript
微信小程序之onLaunch与onload异步问题详解
2019/03/28 Javascript
vue使用keep-alive保持滚动条位置的实现方法
2019/04/09 Javascript
使用jQuery mobile NuGet让你的网站在移动设备上同样精彩
2019/06/18 jQuery
JS实现移动端在线签协议功能
2019/08/22 Javascript
手动实现vue2.0的双向数据绑定原理详解
2021/02/06 Vue.js
python ip正则式
2009/05/07 Python
pyqt5实现俄罗斯方块游戏
2019/01/11 Python
css3实现多个元素依次显示效果
2017/12/12 HTML / CSS
Bibloo荷兰:女士、男士和儿童的服装、鞋子和配饰
2019/02/25 全球购物
C#公司笔试题
2014/03/28 面试题
vue项目实现分页效果
2021/03/24 Vue.js
大学生优秀自荐信范文
2014/02/25 职场文书
医师定期考核实施方案
2014/05/07 职场文书
上课迟到检讨书范文
2015/05/06 职场文书
关于法制教育的宣传语
2015/07/13 职场文书
运动会加油稿30字
2015/07/21 职场文书
JS ES6异步解决方案
2021/04/29 Javascript
JavaCV实现照片马赛克效果
2022/01/22 Java/Android
Python sklearn分类决策树方法详解
2022/09/23 Python