pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Sql数据库增删改查操作简单封装
Apr 18 Python
Python实现树莓派WiFi断线自动重连的实例代码
Mar 16 Python
Python多层装饰器用法实例分析
Feb 09 Python
Python3实现的简单验证码识别功能示例
May 02 Python
Python常用模块之requests模块用法分析
May 15 Python
用Python实现BP神经网络(附代码)
Jul 10 Python
Python CVXOPT模块安装及使用解析
Aug 01 Python
对Python中一维向量和一维向量转置相乘的方法详解
Aug 26 Python
python3 mmh3安装及使用方法
Oct 09 Python
用Python去除图像的黑色或白色背景实例
Dec 12 Python
Python模拟登入的N种方式(建议收藏)
May 31 Python
python实现PolynomialFeatures多项式的方法
Jan 06 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
将word转化为swf 如同百度文库般阅读实现思路及代码
2013/08/09 PHP
深入理解Yii2.0乐观锁与悲观锁的原理与使用
2017/07/26 PHP
php操作redis命令及代码实例大全
2020/11/19 PHP
javascript一点特殊用法
2008/05/28 Javascript
jQuery 使用手册(一)
2009/09/23 Javascript
兼容IE和Firefox的javascript获取iframe文档内容的函数
2011/08/15 Javascript
javascript右下角弹层及自动隐藏(自己编写)
2013/11/20 Javascript
jQuery遍历对象、数组、集合实例
2014/11/08 Javascript
浅谈js的setInterval事件
2014/12/05 Javascript
javascript实现详细时间提醒信息效果的方法
2015/03/11 Javascript
检测一个函数是否是JavaScript原生函数的小技巧
2015/03/13 Javascript
浅谈Jquery核心函数
2015/06/18 Javascript
JavaScript简单修改窗口大小的方法
2015/08/03 Javascript
AngularJS中的Directive实现延迟加载
2016/01/25 Javascript
jQuery实现模糊查询的方法分析
2018/05/10 jQuery
vue.js图片转Base64上传图片并预览的实现方法
2018/08/02 Javascript
CKEditor 4.4.1 添加代码高亮显示插件功能教程【使用官方推荐Code Snippet插件】
2019/06/14 Javascript
跟老齐学Python之dict()的操作方法
2014/09/24 Python
python快速查找算法应用实例
2014/09/26 Python
pandas 根据列的值选取所有行的示例
2018/11/07 Python
python自定义函数实现一个数的三次方计算方法
2019/01/20 Python
打包python 加icon 去掉cmd黑窗口方法
2019/06/24 Python
浅谈python已知元素,获取元素索引(numpy,pandas)
2019/11/26 Python
python如何使用socketserver模块实现并发聊天
2019/12/14 Python
Django配置Bootstrap, js实现过程详解
2020/10/13 Python
在阿尔卑斯山或希腊度过快乐假期:Alpine Elements
2019/12/28 全球购物
新西兰最大的天然保健及护肤品网站:HealthPost(直邮中国)
2021/02/13 全球购物
物业管理公司实习生自我鉴定
2013/09/19 职场文书
爱国卫生月活动总结范文
2014/04/25 职场文书
事业单位考核材料
2014/05/21 职场文书
民族团结先进集体事迹材料
2014/05/22 职场文书
真诚的求职信
2014/07/04 职场文书
道路交通事故赔偿协议书
2014/10/24 职场文书
导游词之上海东方明珠塔
2019/09/25 职场文书
详解CSS3.0(Cascading Style Sheet) 层叠级联样式表
2021/07/16 HTML / CSS
Pygame Time时间控制的具体使用详解
2021/11/17 Python