pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python脚本对Linux服务器进行监控的教程
Apr 02 Python
python中assert用法实例分析
Apr 30 Python
详解Python中for循环的使用方法
May 14 Python
用生成器来改写直接返回列表的函数方法
May 25 Python
python3.6连接MySQL和表的创建与删除实例代码
Dec 28 Python
Python实现的rsa加密算法详解
Jan 24 Python
python微信跳一跳游戏辅助代码解析
Jan 29 Python
详解python OpenCV学习笔记之直方图均衡化
Feb 08 Python
深入理解Django的中间件middleware
Mar 14 Python
在Qt5和PyQt5中设置支持高分辨率屏幕自适应的方法
Jun 18 Python
如何利用Pyecharts可视化微信好友
Jul 04 Python
解决Opencv+Python cv2.imshow闪退问题
Apr 24 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
php UTF-8、Unicode和BOM问题
2010/05/18 PHP
php使用Jpgraph绘制复杂X-Y坐标图的方法
2015/06/10 PHP
PHP微信开发之微信消息自动回复下所遇到的坑
2016/05/09 PHP
PHP中strpos、strstr和stripos、stristr函数分析
2016/06/11 PHP
php类的自动加载操作实例详解
2016/09/28 PHP
一个非常实用的php文件上传类
2017/07/04 PHP
解析arp病毒背后利用的Javascript技术附解密方法
2007/08/06 Javascript
在js中单选框和复选框获取值的方式
2009/11/06 Javascript
滚动图片效果 jquery实现回旋滚动效果
2013/01/08 Javascript
js实现简单选项卡与自动切换效果的方法
2015/04/10 Javascript
高性能JavaScript 重排与重绘(2)
2015/08/11 Javascript
JavaScript简单下拉菜单实例代码
2015/09/07 Javascript
微信支付 JS API支付接口详解
2016/07/11 Javascript
JavaScript实现水平进度条拖拽效果
2017/01/18 Javascript
bootstrap table动态加载数据示例代码
2017/03/25 Javascript
微信小程序实战之自定义toast(6)
2017/04/18 Javascript
node.js中fs文件系统目录操作与文件信息操作
2018/02/24 Javascript
详解vue指令与$nextTick 操作DOM的不同之处
2018/08/02 Javascript
Node.js Buffer模块功能及常用方法实例分析
2019/01/05 Javascript
利用Electron简单撸一个Markdown编辑器的方法
2019/06/10 Javascript
Vue基于iview实现登录密码的显示与隐藏功能
2020/03/06 Javascript
[45:59]完美世界DOTA2联赛PWL S2 FTD vs GXR 第二场 11.22
2020/11/24 DOTA
python 生成器生成杨辉三角的方法(必看)
2017/04/10 Python
CentOS 6.5下安装Python 3.5.2(与Python2并存)
2017/06/05 Python
Python实现修改文件内容的方法分析
2018/03/25 Python
python实现简易动态时钟
2018/11/19 Python
python网络应用开发知识点浅析
2019/05/28 Python
python如何提取英语pdf内容并翻译
2020/03/03 Python
详解Python GUI编程之PyQt5入门到实战
2020/12/10 Python
Shopee印度尼西亚:东南亚与台湾市场最大电商平台
2018/06/17 全球购物
十八届三中全会报告学习材料
2014/02/17 职场文书
法学专业毕业生自荐信
2014/06/11 职场文书
九寨沟导游词
2015/02/02 职场文书
《全神贯注》教学反思
2016/02/22 职场文书
员工安全责任协议书
2016/03/22 职场文书
2019让人心动的商业计划书
2019/06/27 职场文书