pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用mysqldb连接数据库操作方法示例详解
Dec 03 Python
跟老齐学Python之使用Python查询更新数据库
Nov 25 Python
python爬虫headers设置后无效的解决方法
Oct 21 Python
实例讲解Python中整数的最大值输出
Mar 17 Python
python 进程间数据共享multiProcess.Manger实现解析
Sep 23 Python
详解Python绘图Turtle库
Oct 12 Python
Python namedtuple命名元组实现过程解析
Jan 08 Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 Python
python中安装django模块的方法
Mar 12 Python
基于pytorch中的Sequential用法说明
Jun 24 Python
python3代码输出嵌套式对象实例详解
Dec 03 Python
解决PyCharm无法使用lxml库的问题(图解)
Dec 22 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
解析php中获取url与物理路径的总结
2013/06/21 PHP
php使用json_encode对变量json编码
2014/04/07 PHP
PHP连接SQLServer2005的方法
2015/01/27 PHP
PHPExcel导出2003和2007的excel文档功能示例
2017/01/04 PHP
prototype 1.5 & scriptaculous 1.6.1 学习笔记
2006/09/07 Javascript
对xmlHttp对象方法和属性的理解
2011/01/17 Javascript
jquery实现点击TreeView文本父节点展开/折叠子节点
2013/01/10 Javascript
深入理解javascript动态插入技术
2013/11/12 Javascript
Labelauty?jQuery单选框/复选框美化插件分享
2015/09/26 Javascript
js/jquery控制页面动态加载数据 滑动滚动条自动加载事件的方法
2017/02/08 Javascript
基于JS实现仿百度百家主页的轮播图效果
2017/03/06 Javascript
Vue.directive自定义指令的使用详解
2017/03/10 Javascript
jquery实现tab选项卡切换效果(悬停、下方横线动画位移)
2017/05/05 jQuery
基于easyui checkbox 的一些操作处理方法
2017/07/10 Javascript
vue中各组件之间传递数据的方法示例
2017/07/27 Javascript
详解基于 Nuxt 的 Vue.js 服务端渲染实践
2017/10/24 Javascript
Thinkjs3新手入门之如何使用静态资源目录
2017/12/06 Javascript
对vue中v-if的常见使用方法详解
2018/09/28 Javascript
原生JS实现旋转轮播图+文字内容切换效果【附源码】
2018/09/29 Javascript
怎么使用javascript深度拷贝一个数组
2019/06/06 Javascript
Vue v-text指令简单使用方法示例
2019/09/19 Javascript
详解Python的Django框架中inclusion_tag的使用
2015/07/21 Python
深入解析Python的Tornado框架中内置的模板引擎
2016/07/11 Python
python 容器总结整理
2017/04/04 Python
python 连接sqlite及简单操作
2017/06/30 Python
Python2/3中urllib库的一些常见用法
2017/12/19 Python
numpy求平均值的维度设定的例子
2019/08/24 Python
感知器基础原理及python实现过程详解
2019/09/30 Python
python读文件的步骤
2019/10/08 Python
葡萄牙航空官方网站:TAP Air Portugal
2019/10/31 全球购物
澳大利亚家具商店:Freedom
2020/12/17 全球购物
班组长竞聘书
2014/03/31 职场文书
党员自我评价范文2015
2015/03/03 职场文书
《学会看病》教学反思
2016/02/17 职场文书
八年级作文之感恩
2019/11/22 职场文书
世界十大评分最高的动漫,CLANNAD上榜,第八赚足人们眼泪
2022/03/18 日漫