pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python selenium UI自动化解决验证码的4种方法
Jan 05 Python
python实现机器人行走效果
Jan 29 Python
Python装饰器知识点补充
May 28 Python
Python爬虫实现(伪)球迷速成
Jun 10 Python
python write无法写入文件的解决方法
Jan 23 Python
Python使用get_text()方法从大段html中提取文本的实例
Aug 27 Python
Python计算机视觉里的IOU计算实例
Jan 17 Python
pytorch 修改预训练model实例
Jan 18 Python
python实现简单坦克大战
Mar 27 Python
Python 实现打印单词的菱形字符图案
Apr 12 Python
Python如何使用正则表达式爬取京东商品信息
Jun 01 Python
Python实现聚类K-means算法详解
Jul 15 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
php获取从百度搜索进入网站的关键词的详细代码
2014/01/08 PHP
php快速排序原理与实现方法分析
2016/05/26 PHP
php字符串操作针对负值的判断分析
2016/07/28 PHP
PHP 读取大文件并显示的简单实例(推荐)
2016/08/12 PHP
自制PHP框架之模型与数据库
2017/05/07 PHP
php微信开发之关注事件
2018/06/14 PHP
laravel + vue实现的数据统计绘图(今天、7天、30天数据)
2018/07/31 PHP
jquery中子元素和后代元素的区别示例介绍
2014/04/02 Javascript
js实现二代身份证号码验证详解
2014/11/20 Javascript
jQuery插件实现控制网页元素动态居中显示
2015/03/24 Javascript
jQuery插件开发精品教程让你的jQuery提升一个台阶
2016/01/27 Javascript
Javascript基础_简单比较undefined和null 值
2016/06/14 Javascript
浅谈js中对象的使用
2016/08/11 Javascript
原生JS实现获取及修改CSS样式的方法
2018/09/04 Javascript
koa-router源码学习小结
2018/09/07 Javascript
JS如何实现网站中PC端和手机端自动识别并跳转对应的代码
2020/01/08 Javascript
[04:27]DOTA2官方论坛水友赛集锦
2013/09/16 DOTA
Python获取脚本所在目录的正确方法
2014/04/15 Python
Python开发WebService系列教程之REST,web.py,eurasia,Django
2014/06/30 Python
Python使用MYSQLDB实现从数据库中导出XML文件的方法
2015/05/11 Python
详解Python各大聊天系统的屏蔽脏话功能原理
2016/12/01 Python
Python实现的简单dns查询功能示例
2017/05/24 Python
Python File readlines() 使用方法
2018/03/19 Python
python实现百度语音识别api
2018/04/10 Python
Python用字典构建多级菜单功能
2019/07/11 Python
解决使用export_graphviz可视化树报错的问题
2019/08/09 Python
wxPython实现列表增删改查功能
2019/11/19 Python
Scrapy模拟登录赶集网的实现代码
2020/07/07 Python
python Selenium 库的使用技巧
2020/10/16 Python
CSS3 特效范例整理
2011/08/22 HTML / CSS
生物技术专业毕业生求职信范文
2013/12/14 职场文书
适用于所有创业者的创业计划书
2014/02/05 职场文书
国庆横幅标语
2014/10/08 职场文书
检讨书大全
2015/01/27 职场文书
2015年感恩节活动总结
2015/03/24 职场文书
意外事故赔偿协议书
2016/03/22 职场文书