pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
对Python进行数据分析_关于Package的安装问题
May 22 Python
浅谈python中的__init__、__new__和__call__方法
Jul 18 Python
python中urlparse模块介绍与使用示例
Nov 19 Python
python实现人民币大写转换
Jun 20 Python
使用python中的in ,not in来检查元素是不是在列表中的方法
Jul 06 Python
详解python校验SQL脚本命名规则
Mar 22 Python
python里运用私有属性和方法总结
Jul 08 Python
python实现拼图小游戏
Feb 22 Python
详解Pycharm出现out of memory的终极解决方法
Mar 03 Python
Pytorch环境搭建与基本语法
Jun 03 Python
Python类的继承super相关原理解析
Oct 22 Python
Pytest中skip和skipif的具体使用方法
Jun 30 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
php使用curl发送json格式数据实例
2013/12/17 PHP
PHP实现多关键字加亮功能
2016/10/21 PHP
PHP实现递归目录的5种方法
2016/10/27 PHP
jQuery 1.5.1 发布,全面支持IE9 修复大量bug
2011/02/26 Javascript
javascript中注册和移除事件的4种方式
2013/03/20 Javascript
jquery插件jTimer(jquery定时器)使用方法
2013/12/23 Javascript
javascript event在FF和IE的兼容传参心得(绝对好用)
2014/07/10 Javascript
JS按回车键实现登录的方法
2014/08/25 Javascript
10个JavaScript中易犯小错误
2016/02/14 Javascript
BootStrap入门教程(二)之固定的内置样式
2016/09/19 Javascript
Bootstrap CSS组件之按钮组(btn-group)
2016/12/17 Javascript
Node.js中如何合并两个复杂对象详解
2016/12/31 Javascript
为Jquery EasyUI 组件加上清除功能的方法(详解)
2017/04/13 jQuery
Bootstrap实现的标签页内容切换显示效果示例
2017/05/25 Javascript
详解Node.js access_token的获取、存储及更新
2017/06/20 Javascript
微信小程序页面滑动屏幕加载数据效果
2020/11/16 Javascript
Express结合Webpack的全栈自动刷新
2019/05/23 Javascript
微信小程序日历插件代码实例
2019/12/04 Javascript
JavaScript回调函数callback用法解析
2020/01/14 Javascript
使用优化器来提升Python程序的执行效率的教程
2015/04/02 Python
python更新列表的方法
2015/07/28 Python
Python抽象和自定义类定义与用法示例
2018/08/23 Python
Python3实现对列表按元组指定列进行排序的方法分析
2018/12/22 Python
Django在admin后台集成TinyMCE富文本编辑器的例子
2019/08/09 Python
tensorflow的ckpt及pb模型持久化方式及转化详解
2020/02/12 Python
Python3 Click模块的使用方法详解
2020/02/12 Python
Python作用域与名字空间原理详解
2020/03/21 Python
python pymysql链接数据库查询结果转为Dataframe实例
2020/06/05 Python
python线性插值解析
2020/07/05 Python
python 监控logcat关键字功能
2020/09/04 Python
合伙经营协议书范本
2014/09/13 职场文书
教师四风问题整改措施
2014/09/25 职场文书
安全保证书格式
2015/02/28 职场文书
实习推荐信格式模板
2015/03/27 职场文书
2015年音乐教研组工作总结
2015/07/22 职场文书
Docker安装MySql8并远程访问的实现
2022/07/07 Servers