pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用swapCase()方法转换大小写的教程
May 20 Python
Python多线程结合队列下载百度音乐的方法
Jul 27 Python
简单了解OpenCV是个什么东西
Nov 10 Python
python获取命令行输入参数列表的实例代码
Jun 23 Python
Python格式化输出字符串方法小结【%与format】
Oct 29 Python
python3中property使用方法详解
Apr 23 Python
pyqt5实现按钮添加背景图片以及背景图片的切换方法
Jun 13 Python
基于python的selenium两种文件上传操作实现详解
Sep 19 Python
TensorFlow2.0矩阵与向量的加减乘实例
Feb 07 Python
详解python logging日志传输
Jul 01 Python
Python matplotlib读取excel数据并用for循环画多个子图subplot操作
Jul 14 Python
Python 绘制可视化折线图
Jul 22 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
PHP获取数组中某元素的位置及array_keys函数应用
2013/01/29 PHP
php中Session的生成机制、回收机制和存储机制探究
2014/08/19 PHP
PHP实现的构造sql语句类实例
2016/02/03 PHP
CI框架出现mysql数据库连接资源无法释放的解决方法
2016/05/17 PHP
javascript引用对象的方法
2007/01/11 Javascript
网页中返回顶部代码(多种方法)另附注释说明
2013/04/24 Javascript
关于jquery.validate1.9.0前台验证的使用介绍
2013/04/26 Javascript
javascript右下角弹层及自动隐藏(自己编写)
2013/11/20 Javascript
JavaScript实现存储HTML字符串示例
2014/04/21 Javascript
JavaScript sup方法入门实例(把字符串显示为上标)
2014/10/20 Javascript
解析JavaScript的ES6版本中的解构赋值
2015/07/28 Javascript
Vue.js实现移动端短信验证码功能
2017/03/29 Javascript
Vue.js实现一个SPA登录页面的过程【推荐】
2017/04/29 Javascript
javascript中floor使用方法总结
2019/02/02 Javascript
使用nodejs分离html文件里的js和css详解
2019/04/12 NodeJs
nodejs读取图片返回给浏览器显示
2019/07/25 NodeJs
es6函数之严格模式用法实例分析
2020/03/17 Javascript
Python版的文曲星猜数字游戏代码
2013/09/02 Python
详解Python list 与 NumPy.ndarry 切片之间的对比
2017/07/24 Python
使用Python制作微信跳一跳辅助
2018/01/31 Python
Python为何不能用可变对象作为默认参数的值
2019/07/01 Python
Django框架基础模板标签与filter使用方法详解
2019/07/23 Python
Python JSON常用编解码方法代码实例
2020/09/05 Python
美特斯邦威官方商城:邦购网
2016/10/13 全球购物
巴西购物网站:Estrela10
2018/12/13 全球购物
凯普林包包西班牙官网:Kipling西班牙
2019/04/12 全球购物
公司面试感谢信
2014/02/01 职场文书
爱国卫生月实施方案
2014/02/21 职场文书
2015年班组长工作总结
2015/04/10 职场文书
工作自我评价范文
2019/03/21 职场文书
演讲稿:态度决定一切
2019/04/02 职场文书
《曾国藩家书》读后感——读家书,立家风
2019/08/21 职场文书
tensorflow中的数据类型dtype用法说明
2021/05/26 Python
python爬取某网站原图作为壁纸
2021/06/02 Python
Java网络编程之UDP实现原理解析
2021/09/04 Java/Android
 Python 中 logging 模块使用详情
2022/03/03 Python