解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python提取字典key列表的方法
Jul 11 Python
python语言使用技巧分享
May 31 Python
Python字典简介以及用法详解
Nov 15 Python
python的构建工具setup.py的方法使用示例
Oct 23 Python
python中dir()与__dict__属性的区别浅析
Dec 10 Python
Python实现简易过滤删除数字的方法小结
Jan 09 Python
python实现批量注册网站用户的示例
Feb 22 Python
pow在python中的含义及用法
Jul 11 Python
python 基于卡方值分箱算法的实现示例
Jul 17 Python
Django Auth用户认证组件实现代码
Oct 13 Python
使用Python脚本对GiteePages进行一键部署的使用说明
May 27 Python
Python中22个万用公式的小结
Jul 21 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
php中配置文件操作 如config.php文件的读取修改等操作
2012/07/07 PHP
比较discuz和ecshop的截取字符串函数php版
2012/09/03 PHP
详解WordPress中分类函数wp_list_categories的使用
2016/01/04 PHP
PHP怎样用正则抓取页面中的网址
2016/08/09 PHP
php安全配置记录和常见错误梳理(总结)
2017/03/28 PHP
Thinkphp5.0自动生成模块及目录的方法详解
2017/04/17 PHP
Thinkphp5 微信公众号token验证不成功的原因及解决方法
2017/11/12 PHP
PHP bin2hex()函数基础实例讲解
2019/02/11 PHP
JavaScript Event学习第二章 Event浏览器兼容性
2010/02/07 Javascript
在IE和VB中支持png图片透明效果的实现方法(vb源码打包)
2011/04/01 Javascript
键盘上一张下一张兼容IE/google/firefox等浏览器
2014/01/28 Javascript
jquery ajax应用中iframe自适应高度问题解决方法
2014/04/12 Javascript
javascript event在FF和IE的兼容传参心得(绝对好用)
2014/07/10 Javascript
简述AngularJS相关的一些编程思想
2015/06/23 Javascript
AngularJS+Node.js实现在线聊天室
2015/08/28 Javascript
jQuery实现分章节锚点“回到顶部”动画特效代码
2015/10/23 Javascript
jQuery增加和删除表格项目及实现表格项目排序的方法
2016/05/30 Javascript
JavaScript中Number对象的toFixed() 方法详解
2016/09/02 Javascript
js实现日历的简单算法
2017/01/24 Javascript
JavaScript实现form表单的多文件上传
2020/03/27 Javascript
JavaScript之map reduce_动力节点Java学院整理
2017/06/29 Javascript
vue使用 better-scroll的参数和方法详解
2018/01/25 Javascript
vue实现图片预览组件封装与使用
2019/07/13 Javascript
vue 使用 vue-pdf 实现pdf在线预览的示例代码
2020/04/26 Javascript
Python 获得13位unix时间戳的方法
2017/10/20 Python
python 爬虫一键爬取 淘宝天猫宝贝页面主图颜色图和详情图的教程
2018/05/22 Python
PyQt5+requests实现车票查询工具
2019/01/21 Python
Python使用import导入本地脚本及导入模块的技巧总结
2019/08/07 Python
Python数据库小程序源代码
2019/09/15 Python
python日志模块logbook使用方法
2019/09/19 Python
办理生育手续介绍信
2014/01/14 职场文书
人力资源求职信
2014/05/25 职场文书
合唱兴趣小组活动总结
2014/07/10 职场文书
汽车机电维修工求职信
2014/09/30 职场文书
小学英语教师2015年度个人工作总结
2015/10/14 职场文书
浅谈mysql返回Boolean类型的几种情况
2021/06/04 MySQL