pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
使用Node.js和Socket.IO扩展Django的实时处理功能
Apr 20 Python
Python合并两个字典的常用方法与效率比较
Jun 17 Python
浅谈python为什么不需要三目运算符和switch
Jun 17 Python
python接口自动化(十七)--Json 数据处理---一次爬坑记(详解)
Apr 18 Python
python中break、continue 、exit() 、pass终止循环的区别详解
Jul 08 Python
Python循环实现n的全排列功能
Sep 16 Python
Tensorflow进行多维矩阵的拆分与拼接实例
Feb 07 Python
使用Python+selenium实现第一个自动化测试脚本
Mar 17 Python
Python要求O(n)复杂度求无序列表中第K的大元素实例
Apr 02 Python
Windows下Anaconda和PyCharm的安装与使用详解
Apr 23 Python
python装饰器代码深入讲解
Mar 01 Python
Python学习之异常中的finally使用详解
Mar 16 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
PHP 开源AJAX框架14种
2009/08/24 PHP
第七章 php自定义函数实现代码
2011/12/30 PHP
获取php页面执行时间,数据库读写次数,函数调用次数等(THINKphp)
2013/06/03 PHP
PHP的一个完美GIF等比缩放类,附带去除缩放黑背景
2014/04/01 PHP
php将数组转换成csv格式文件输出的方法
2015/03/14 PHP
解决PHP里大量数据循环时内存耗尽的方法
2015/10/10 PHP
golang与PHP输出excel示例
2016/07/22 PHP
浅谈php中urlencode与rawurlencode的区别
2016/09/05 PHP
javascript 文档的编码问题解决
2009/03/01 Javascript
Dom与浏览器兼容性说明
2010/10/25 Javascript
JS,Jquery获取select,dropdownlist,checkbox下拉列表框的值(示例代码)
2014/01/11 Javascript
微信小程序 swiper制作tab切换实现附源码
2017/01/21 Javascript
vue绑定设置属性的多种方式(5)
2017/08/16 Javascript
使用element-ui +Vue 解决 table 里包含表单验证的问题
2020/07/17 Javascript
js实现批量删除功能
2020/08/27 Javascript
python实现得到一个给定类的虚函数
2014/09/28 Python
浅谈python可视化包Bokeh
2018/02/07 Python
django解决跨域请求的问题详解
2019/01/20 Python
python实现的爬取电影下载链接功能示例
2019/08/26 Python
pytorch中的卷积和池化计算方式详解
2020/01/03 Python
CSS3 box-sizing属性
2009/04/17 HTML / CSS
一款利用纯css3实现的超炫3D表单的实例教程
2014/12/01 HTML / CSS
阿根廷票务网站:StubHub阿根廷
2018/04/13 全球购物
水上运动奥特莱斯:Wasterports Outlet
2018/08/08 全球购物
世嘉游戏英国官方商店:SEGA Shop UK
2019/09/20 全球购物
父母寄语大全
2014/04/12 职场文书
《白鹅》教学反思
2014/04/13 职场文书
终止劳动合同协议书
2014/04/14 职场文书
《海伦?凯勒》教学反思
2014/04/17 职场文书
八一建军节演讲稿
2014/09/10 职场文书
走群众路线学习心得体会
2014/10/31 职场文书
个人优缺点总结
2015/02/28 职场文书
困难补助申请报告
2015/05/19 职场文书
2016年度创先争优活动总结
2016/04/05 职场文书
用 Python 元类的特性实现 ORM 框架
2021/05/19 Python
python_tkinter事件类型详情
2022/03/20 Python