pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解决字典中的值是列表问题的方法
Mar 04 Python
python实现基于两张图片生成圆角图标效果的方法
Mar 26 Python
在Python中使用pngquant压缩png图片的教程
Apr 09 Python
浅析Python中将单词首字母大写的capitalize()方法
May 18 Python
web.py 十分钟创建简易博客实现代码
Apr 22 Python
Python中set与frozenset方法和区别详解
May 23 Python
Python3使用正则表达式爬取内涵段子示例
Apr 22 Python
Python多重继承的方法解析执行顺序实例分析
May 26 Python
使用Python向C语言的链接库传递数组、结构体、指针类型的数据
Jan 29 Python
python实现的按要求生成手机号功能示例
Oct 08 Python
python 五子棋如何获得鼠标点击坐标
Nov 04 Python
Python Matplotlib绘制动画的代码详解
May 30 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
php面向对象全攻略 (七) 继承性
2009/09/30 PHP
php日期操作技巧小结
2016/06/25 PHP
基于Laravel5.4实现多字段登录功能方法示例
2017/08/11 PHP
Javascript 更新 JavaScript 数组的 uniq 方法
2008/01/23 Javascript
HTML5之lang属性与dir属性的详解
2013/06/19 Javascript
js的2种继承方式详解
2014/03/04 Javascript
JS+CSS实现带有碰撞缓冲效果的竖向导航条代码
2015/09/15 Javascript
CSS中position属性之fixed实现div居中
2015/12/14 Javascript
javascript运动效果实例总结(放大缩小、滑动淡入、滚动)
2016/01/08 Javascript
JavaScript驾驭网页-获取网页元素
2016/03/24 Javascript
再谈javascript常见错误及解决方法
2016/09/16 Javascript
详解AngularJS中的表单验证(推荐)
2016/11/17 Javascript
Webpack实现按需打包Lodash的几种方法详解
2017/05/08 Javascript
Angular指令之restict匹配模式的详解
2017/07/27 Javascript
nodejs 图解express+supervisor+ejs的用法(推荐)
2017/09/08 NodeJs
seaJs使用心得之exports与module.exports的区别实例分析
2017/10/13 Javascript
javascript自定义右键菜单插件
2019/12/16 Javascript
vue实现防抖的实例代码
2021/01/11 Vue.js
Python多线程同步Lock、RLock、Semaphore、Event实例
2014/11/21 Python
django的ORM操作 增加和查询
2019/07/26 Python
Django中使用极验Geetest滑动验证码过程解析
2019/07/31 Python
python sitk.show()与imageJ结合使用常见的问题
2020/04/20 Python
python使用for...else跳出双层嵌套循环的方法实例
2020/05/17 Python
Django Path转换器自定义及正则代码实例
2020/05/29 Python
Python常用模块函数代码汇总解析
2020/08/31 Python
html5-Canvas可以在web中绘制各种图形
2012/12/26 HTML / CSS
土建资料员岗位职责
2014/01/04 职场文书
物流管理毕业生自荐信范文
2014/03/15 职场文书
三年级小学生评语
2014/04/22 职场文书
党员服务承诺书
2014/05/28 职场文书
电力工程合作意向书
2015/05/11 职场文书
2016廉洁从政心得体会
2016/01/19 职场文书
Nginx如何配置Http、Https、WS、WSS的方法步骤
2021/05/11 Servers
Vue和Flask通信的实现
2021/05/19 Vue.js
叶县这家生产军用电台的兵工厂,人称“四机部”,走出一上将
2022/02/18 无线电
springmvc直接不经过controller访问WEB-INF中的页面问题
2022/02/24 Java/Android