pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用append合并两个数组的方法
Apr 28 Python
Python实现读取文件最后n行的方法
Feb 23 Python
Python实现SSH远程登陆,并执行命令的方法(分享)
May 08 Python
Python实现读取txt文件并画三维图简单代码示例
Dec 09 Python
对numpy和pandas中数组的合并和拆分详解
Apr 11 Python
Django中更改默认数据库为mysql的方法示例
Dec 05 Python
python 反编译exe文件为py文件的实例代码
Jun 27 Python
Python CSS选择器爬取京东网商品信息过程解析
Jun 01 Python
什么是python类属性
Jun 10 Python
Python如何解除一个装饰器
Aug 07 Python
基于Python实现体育彩票选号器功能代码实例
Sep 16 Python
如何在Python中创建二叉树
Mar 30 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
PHP导出MySQL数据到Excel文件(fputcsv)
2011/07/03 PHP
PHP error_log()将错误信息写入一个文件(定义和用法)
2013/10/25 PHP
php面向对象重点知识分享
2019/09/27 PHP
PHP7移除的扩展和SAPI
2021/03/09 PHP
javascript 放大镜效果js组件 qsoft.PopBigImage.v0.35 加入了chrome支持
2009/04/07 Javascript
jquery 3D球状导航的文章分类
2010/07/06 Javascript
Javascript拓展String方法小结
2013/07/08 Javascript
JavaScript中的包装对象介绍
2015/01/27 Javascript
jquery中EasyUI实现异步树
2015/03/01 Javascript
JavaScript的Vue.js库入门学习教程
2016/05/23 Javascript
判断数组的最佳方法(推荐)
2016/10/11 Javascript
原生js实现返回顶部缓冲效果
2017/01/18 Javascript
深入解析js轮播插件核心代码的实现过程
2017/04/14 Javascript
JS身份证信息验证正则表达式
2017/06/12 Javascript
在vue中使用v-bind:class的选项卡方法
2018/09/27 Javascript
Vue scrollBehavior 滚动行为实现后退页面显示在上次浏览的位置
2019/05/27 Javascript
js实现3D旋转效果
2020/08/18 Javascript
[03:28]2014DOTA2国际邀请赛 走近EG战队天才中单Arteezy
2014/07/12 DOTA
[01:33]完美世界DOTA2联赛PWL S3 集锦第二期
2020/12/21 DOTA
Python实现屏幕截图的代码及函数详解
2016/10/01 Python
python urllib urlopen()对象方法/代理的补充说明
2017/06/29 Python
python3利用Dlib19.7实现人脸68个特征点标定
2018/02/26 Python
对pandas中时间窗函数rolling的使用详解
2018/11/28 Python
Django实现文件上传下载功能
2019/10/06 Python
python深copy和浅copy区别对比解析
2019/12/26 Python
python numpy实现rolling滚动案例
2020/06/08 Python
python的pip有什么用
2020/06/17 Python
大学同学聚会邀请函
2014/01/19 职场文书
药品促销活动方案
2014/02/14 职场文书
暑期培训随笔感言
2014/03/10 职场文书
经典禁毒标语
2014/06/16 职场文书
党员群众路线对照检查材料
2014/08/31 职场文书
领导干部群众路线剖析材料
2014/10/09 职场文书
2015年消防工作总结
2015/04/24 职场文书
大学生社会服务心得体会
2016/01/22 职场文书
Java 常见的限流算法详细分析并实现
2022/04/07 Java/Android