pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
Python中字符编码简介、方法及使用建议
Jan 08 Python
在Python中操作字典之fromkeys()方法的使用
May 21 Python
Python环境搭建之OpenCV的步骤方法
Oct 20 Python
Python编程给numpy矩阵添加一列方法示例
Dec 04 Python
解决PyCharm不运行脚本,而是运行单元测试的问题
Jan 17 Python
Python设计模式之代理模式实例详解
Jan 19 Python
python远程邮件控制电脑升级版
May 23 Python
如何在Django项目中引入静态文件
Jul 26 Python
使用django和vue进行数据交互的方法步骤
Nov 11 Python
Tensorflow实现将标签变为one-hot形式
May 22 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 Python
据Python爬虫不靠谱预测可知今年双十一销售额将超过6000亿元
Nov 11 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
我的论坛源代码(九)
2006/10/09 PHP
亲密接触PHP之PHP语法学习笔记1
2006/12/17 PHP
php模拟asp中的XmlHttpRequest实现http请求的代码
2011/03/24 PHP
php 下载保存文件保存到本地的两种实现方法
2013/08/12 PHP
PHP使用逆波兰式计算工资的方法
2015/07/29 PHP
浅谈使用 Yii2 AssetBundle 中 $publishOptions 的正确姿势
2017/11/08 PHP
Ecshop 后台添加新功能栏目及管理权限设置教程
2017/11/21 PHP
一个简单的js鼠标划过切换效果
2010/06/30 Javascript
Javascript中的window.event.keyCode使用介绍
2011/04/26 Javascript
js获取本机的外网/广域网ip地址完整源码
2013/08/12 Javascript
js生成随机数之random函数随机示例
2013/12/20 Javascript
无刷新预览所选择的图片示例代码
2014/04/02 Javascript
简介JavaScript中setUTCSeconds()方法的使用
2015/06/12 Javascript
JavaScript实现点击按钮切换网页背景色的方法
2015/10/17 Javascript
基于BootStrap Metronic开发框架经验小结【五】Bootstrap File Input文件上传插件的用法详解
2016/05/12 Javascript
jQuery插件FusionCharts绘制的2D双面积图效果示例【附demo源码】
2017/04/11 jQuery
node.js中EJS 模板快速入门教程
2017/05/08 Javascript
layui弹出层效果实现代码
2017/05/19 Javascript
VUE使用vuex解决模块间传值问题的方法
2017/06/01 Javascript
Layui点击图片弹框预览的实现方法
2019/09/16 Javascript
Vue+scss白天和夜间模式切换功能的实现方法
2021/01/05 Vue.js
vue3.0中使用element的完整步骤
2021/03/04 Vue.js
在Python中利用Into包整洁地进行数据迁移的教程
2015/03/30 Python
Python中内置的日志模块logging用法详解
2016/07/12 Python
使用Python进行目录的对比方法
2018/11/01 Python
django 邮件发送模块smtp使用详解
2019/07/22 Python
初学者学习Python好还是Java好
2020/05/26 Python
Python如何把字典写入到CSV文件的方法示例
2020/08/23 Python
2014年四风问题个人对照自查剖析材料
2014/09/15 职场文书
大学生考试作弊检讨书
2014/09/21 职场文书
安全生产工作汇报
2014/10/28 职场文书
入团申请书格式
2019/06/20 职场文书
创业计划书之o2o水果店
2019/08/30 职场文书
Jupyter notebook 不自动弹出网页的解决方案
2021/05/21 Python
vue项目支付功能代码详解
2022/02/18 Vue.js
OpenStack虚拟机快照和增量备份实现方法
2022/04/04 Servers