pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
使用优化器来提升Python程序的执行效率的教程
Apr 02 Python
python中字典dict常用操作方法实例总结
Apr 04 Python
使用Django的模版来配合字符串翻译工作
Jul 27 Python
读取json格式为DataFrame(可转为.csv)的实例讲解
Jun 05 Python
Python实现的网页截图功能【PyQt4与selenium组件】
Jul 12 Python
对python多线程SSH登录并发脚本详解
Feb 14 Python
python 处理telnet返回的More,以及get想要的那个参数方法
Feb 14 Python
Python decimal模块使用方法详解
Jun 08 Python
python切片作为占位符使用实例讲解
Feb 17 Python
python执行js代码的方法
May 13 Python
jupyter notebook保存文件默认路径更改方法汇总(亲测可以)
Jun 09 Python
Python自动化工具之实现Excel转Markdown表格
Apr 08 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
PHP新手上路(九)
2006/10/09 PHP
常用的PHP数据库操作方法(MYSQL版)
2011/06/08 PHP
一个php短网址的生成代码(仿微博短网址)
2014/05/07 PHP
PHP获取二维数组中某一列的值集合
2015/12/25 PHP
zend framework中使用memcache的方法
2016/03/04 PHP
可自定义速度的js图片无缝滚动示例分享
2014/01/20 Javascript
JS实现简单的顶部定时关闭层效果
2014/06/15 Javascript
《JavaScript DOM 编程艺术》读书笔记之JavaScript 简史
2015/01/09 Javascript
基于Jquery代码实现支持PC端手机端幻灯片代码
2015/11/17 Javascript
基于Jquery和html5实现炫酷的3D焦点图动画
2016/03/02 Javascript
如何用js判断dom是否有存在某class的值
2017/02/13 Javascript
jQuery中Chosen三级联动功能实例代码
2017/03/07 Javascript
javascript 动态生成css代码的两种方法
2017/03/17 Javascript
基于JS实现仿京东搜索栏随滑动透明度渐变效果
2017/07/10 Javascript
Node.js中环境变量process.env的一些事详解
2017/10/26 Javascript
vue-scroller记录滚动位置的示例代码
2018/01/17 Javascript
JavaScript中Object基础内部方法图
2018/02/05 Javascript
js作用域和作用域链及预解析
2019/04/11 Javascript
微信小程序加载机制及运行机制图解
2019/11/27 Javascript
微信小程序学习总结(一)项目创建与目录结构分析
2020/06/04 Javascript
VUE中V-IF条件判断改变元素的样式操作
2020/08/09 Javascript
[37:45]完美世界DOTA2联赛PWL S3 LBZS vs Phoenix 第二场 12.09
2020/12/11 DOTA
怎样使用Python脚本日志功能
2016/08/14 Python
详解重置Django migration的常见方式
2019/02/15 Python
Python3.6.x中内置函数总结及讲解
2019/02/22 Python
python-opencv获取二值图像轮廓及中心点坐标的代码
2019/08/27 Python
python中自带的三个装饰器的实现
2019/11/08 Python
python等待10秒执行下一命令的方法
2020/07/19 Python
Erwin Müller穆勒家居瑞士官网:您整个家庭的邮购公司
2019/12/28 全球购物
英文简历中的自我评价
2013/10/06 职场文书
行政人员工作职责
2013/12/05 职场文书
上班早退检讨书
2014/01/09 职场文书
我的求职计划书
2014/01/10 职场文书
电子专业毕业生自荐信
2014/05/25 职场文书
Go语言基础函数基本用法及示例详解
2021/11/17 Golang
Go gorilla/sessions库安装使用
2022/08/14 Golang