pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
python错误处理详解
Sep 28 Python
简单说明Python中的装饰器的用法
Apr 24 Python
python在windows下创建隐藏窗口子进程的方法
Jun 04 Python
Win7下搭建python开发环境图文教程(安装Python、pip、解释器)
May 17 Python
Python 中的with关键字使用详解
Sep 11 Python
Python2和Python3中print的用法示例总结
Oct 25 Python
PyQt5每天必学之事件与信号
Apr 20 Python
Python编写通讯录通过数据库存储实现模糊查询功能
Jul 18 Python
Python将string转换到float的实例方法
Jul 29 Python
Python爬取破解无线网络wifi密码过程解析
Sep 17 Python
python 用opencv实现霍夫线变换
Nov 27 Python
Python打包为exe详细教程
May 18 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
关于二级目录拖拽排序的实现(源码示例下载)
2013/04/26 PHP
PHP函数引用返回的实例详解
2016/09/11 PHP
PHP实现链表的定义与反转功能示例
2018/06/09 PHP
brook javascript框架介绍
2011/10/10 Javascript
jQuery学习之prop和attr的区别示例介绍
2013/11/15 Javascript
JavaScript编程的10个实用小技巧
2014/04/18 Javascript
javascript事件模型实例分析
2015/01/30 Javascript
Node.js中的缓冲与流模块详细介绍
2015/02/11 Javascript
jquery模拟实现鼠标指针停止运动事件
2016/01/12 Javascript
jQuery删除当前节点元素
2016/12/07 Javascript
COM组件中调用JavaScript函数详解及实例
2017/02/23 Javascript
Angularjs实现下拉框联动的示例代码
2017/08/22 Javascript
vue-cli项目中怎么使用mock数据
2017/09/27 Javascript
详解js访问对象的属性和方法
2018/10/25 Javascript
js逆向解密之网络爬虫
2019/05/30 Javascript
electron实现静默打印的示例代码
2019/08/12 Javascript
layer.js open 隐藏滚动条的例子
2019/09/05 Javascript
原生js实现放大镜组件
2021/01/22 Javascript
[50:02]完美世界DOTA2联赛循环赛 Magma vs IO BO2第一场 11.01
2020/11/02 DOTA
Python生成随机密码
2015/03/10 Python
python简单贪吃蛇开发
2019/01/28 Python
pyqt5 使用cv2 显示图片,摄像头的实例
2019/06/27 Python
简单了解Python读取大文件代码实例
2019/12/18 Python
如何在Windows中安装多个python解释器
2020/06/16 Python
python3的pip路径在哪
2020/06/23 Python
python 密码学示例——理解哈希(Hash)算法
2020/09/21 Python
远程Wi-Fi宠物监控相机:Petcube
2017/04/26 全球购物
香港最大的洋酒零售连锁店:屈臣氏酒窖(Watson’s Wine)
2018/12/10 全球购物
英国和爱尔兰最大的地毯零售商:Kukoon
2018/12/17 全球购物
Pamela Love官网:纽约设计师Pamela Love的精美、时尚和穿孔珠宝
2020/10/19 全球购物
自荐信包含哪些内容
2013/10/30 职场文书
煤矿班组长的职责
2013/12/25 职场文书
师德学习感言
2014/01/31 职场文书
学习党的群众路线教育实践活动剖析材料
2014/10/13 职场文书
2015年银行大堂经理工作总结
2015/04/24 职场文书
中学感恩教育活动总结
2015/05/05 职场文书