pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python爬虫教程之爬取百度贴吧并下载的示例
Mar 07 Python
使用 Python 获取 Linux 系统信息的代码
Jul 13 Python
hmac模块生成加入了密钥的消息摘要详解
Jan 11 Python
pygame游戏之旅 计算游戏中躲过的障碍数量
Nov 20 Python
自定义django admin model表单提交的例子
Aug 23 Python
Python破解BiliBili滑块验证码的思路详解(完美避开人机识别)
Feb 17 Python
手把手教你安装Windows版本的Tensorflow
Mar 26 Python
pyinstaller打包成无控制台程序时运行出错(与popen冲突的解决方法)
Apr 15 Python
哈工大自然语言处理工具箱之ltp在windows10下的安装使用教程
May 07 Python
Python分析最近大火的网剧《隐秘的角落》
Jul 02 Python
图解Python中深浅copy(通俗易懂)
Sep 03 Python
Python 虚拟环境工作原理解析
Dec 24 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
《DOTA3》开发工作已经开始 《DOTA3》将代替《DOTA2》
2021/03/06 DOTA
一个简易需要注册的留言版程序
2006/10/09 PHP
PHP Session_Regenerate_ID函数双释放内存破坏漏洞
2011/01/27 PHP
PHP递归返回值时出现的问题解决办法
2013/02/19 PHP
常见php数据文件缓存类汇总
2014/12/05 PHP
PHP中ajax无刷新上传图片与图片下载功能
2017/02/21 PHP
PHP使用PDO实现mysql防注入功能详解
2019/12/20 PHP
gearman中worker常驻后台,导致MySQL server has gone away的解决方法
2020/02/27 PHP
Nigma vs AM BO3 第二场2.13
2021/03/10 DOTA
基于jsTree的无限级树JSON数据的转换代码
2010/07/27 Javascript
Dom 是什么的详细说明
2010/10/25 Javascript
jquery 实现checkbox全选,反选,全不选等功能代码(奇数)
2012/10/24 Javascript
Javascript中的for in循环和hasOwnProperty结合使用
2013/06/05 Javascript
javascript设计简单的秒表计时器
2020/09/05 Javascript
JavaScript函数内部属性和函数方法实例详解
2016/03/17 Javascript
详解React中的组件通信问题
2017/07/31 Javascript
vue-cli构建项目使用 less的方法
2017/10/04 Javascript
使用vuex的state状态对象的5种方式
2018/04/19 Javascript
灵活使用console让js调试更简单的方法步骤
2019/04/23 Javascript
Vue仿Bibibili首页的问题
2021/01/21 Vue.js
[58:21]DOTA2亚洲邀请赛 4.3 突围赛 Liquid vs VGJ.T 第二场
2018/04/04 DOTA
pycharm 使用心得(九)解决No Python interpreter selected的问题
2014/06/06 Python
python读写ini配置文件方法实例分析
2015/06/30 Python
python 日志增量抓取实现方法
2018/04/28 Python
python用BeautifulSoup库简单爬虫实例分析
2018/07/30 Python
使用Python进行中文繁简转换的实现代码
2019/10/18 Python
用Python实现校园通知更新提醒功能
2019/11/23 Python
css3 background属性调整增强介绍
2010/12/18 HTML / CSS
HTML5实现直播间评论滚动效果的代码
2020/05/27 HTML / CSS
Linux管理员面试经常问道的相关命令
2013/04/29 面试题
疾病捐款倡议书
2014/05/13 职场文书
机械工程师岗位职责
2014/06/16 职场文书
2016北大自主招生自荐信模板
2016/01/28 职场文书
Java 实战项目之家居购物商城系统详解流程
2021/11/11 Java/Android
Golang入门之计时器
2022/05/04 Golang
苹果macOS 13开发者预览版Beta 8发布 正式版10月发布
2022/09/23 数码科技