pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
Python提取Linux内核源代码的目录结构实现方法
Jun 24 Python
python 提取tuple类型值中json格式的key值方法
Dec 31 Python
python 发送和接收ActiveMQ消息的实例
Jan 30 Python
python使用Plotly绘图工具绘制水平条形图
Mar 25 Python
python 实现二维字典的键值合并等函数
Dec 06 Python
解决使用python print打印函数返回值多一个None的问题
Apr 09 Python
IDLE下Python文件编辑和运行操作
Apr 25 Python
如何在mac版pycharm选择python版本
Jul 21 Python
Python 操作 MySQL数据库
Sep 18 Python
python 如何对logging日志封装
Dec 02 Python
基于Python-Pycharm实现的猴子摘桃小游戏(源代码)
Feb 20 Python
Python常遇到的错误和异常
Nov 02 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
异步加载技术实现当滚动条到最底部的瀑布流效果
2014/09/16 PHP
php微信高级接口群发 多客服
2016/06/23 PHP
浅谈PHP的反射API
2017/02/26 PHP
php5.3/5.4/5.5/5.6/7常见新增特性汇总整理
2020/02/27 PHP
Laravel 框架基于自带的用户系统实现登录注册及错误处理功能分析
2020/04/14 PHP
解决jquery .ajax 在IE下卡死问题的解决方法
2009/10/26 Javascript
Jquery幻灯片特效代码分享--打开页面随机选择切换方式(3)
2015/08/15 Javascript
跟我学习javascript的arguments对象
2015/11/16 Javascript
通过命令行生成vue项目框架的方法
2017/07/12 Javascript
vue实现单选和多选功能
2017/08/11 Javascript
详解HTML5 使用video标签实现选择摄像头功能
2017/10/25 Javascript
基于vue.js实现分页查询功能
2018/12/29 Javascript
vue单页面在微信下只能分享落地页的解决方案
2019/04/15 Javascript
Node.js API详解之 dns模块用法实例分析
2020/05/15 Javascript
在vue中axios设置timeout超时的操作
2020/09/04 Javascript
[52:07]完美世界DOTA2联赛PWL S3 LBZS vs access 第二场 12.10
2020/12/13 DOTA
python实现百度关键词排名查询
2014/03/30 Python
Python使用正则表达式实现文本替换的方法
2017/04/18 Python
Flask框架响应、调度方法和蓝图操作实例分析
2018/07/24 Python
详解python中的hashlib模块的使用
2019/04/22 Python
python opencv 二值化 计算白色像素点的实例
2019/07/03 Python
Python使用正则表达式分割字符串的实现方法
2019/07/16 Python
python绘制随机网络图形示例
2019/11/21 Python
使用Pandas将inf, nan转化成特定的值
2019/12/19 Python
linux环境下安装python虚拟环境及注意事项
2020/01/07 Python
python求一个字符串的所有排列的实现方法
2020/02/04 Python
Python如何生成xml文件
2020/06/04 Python
摩托车和ATV零件、配件和服装的首选在线零售商:MotoSport
2017/12/22 全球购物
彪马香港官方网上商店:PUMA香港
2020/12/06 全球购物
如何获取某个日期是当月的最后一天
2013/12/05 面试题
区域销售经理职责
2013/12/22 职场文书
八一演出活动方案
2014/02/03 职场文书
环境科学专业优秀毕业生自荐书
2014/02/03 职场文书
四年级数学教学反思
2016/02/16 职场文书
Win11 引入 Windows 365 云操作系统,适应疫情期间混合办公模式:启动时直接登录、模
2022/04/06 数码科技
使用pd.merge表连接出现多余行的问题解决
2022/06/16 Python