解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用CasperJS获取JS渲染生成的HTML内容的教程
Apr 09 Python
Python算法应用实战之栈详解
Feb 04 Python
Python决策树分类算法学习
Dec 22 Python
python pandas dataframe 行列选择,切片操作方法
Apr 10 Python
PyCharm设置SSH远程调试的方法
Jul 17 Python
解决PyCharm的Python.exe已经停止工作的问题
Nov 29 Python
Tensorflow 卷积的梯度反向传播过程
Feb 10 Python
在tensorflow下利用plt画论文中loss,acc等曲线图实例
Jun 15 Python
详解python logging日志传输
Jul 01 Python
python进度条显示-tqmd模块的实现示例
Aug 23 Python
基于python获取本地时间并转换时间戳和日期格式
Oct 27 Python
python入门之算法学习
Apr 22 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
php xml-rpc远程调用
2008/12/19 PHP
PHP 图片文件上传实现代码
2010/12/29 PHP
PHP代码判断设备是手机还是平板电脑(两种方法)
2015/10/19 PHP
Zend Framework教程之模型Model用法简单实例
2016/03/04 PHP
smarty学习笔记之常见代码段用法总结
2016/03/19 PHP
基于json的jquery地区联动效果代码
2011/07/06 Javascript
你必须知道的JavaScript 中字符串连接的性能的一些问题
2013/05/07 Javascript
使用jquery的ajax需要注意的地方dataType的设置
2013/08/12 Javascript
JS实现从表格中动态删除指定行的方法
2015/03/31 Javascript
Jquery1.9.1源码分析系列(六)延时对象应用之jQuery.ready
2015/11/24 Javascript
JavaScript ParseFloat()方法
2015/12/18 Javascript
深入理解jQuery中的事件冒泡
2016/05/24 Javascript
基于gulp合并压缩Seajs模块的方式说明
2016/06/14 Javascript
微信小程序 教程之数据绑定
2016/10/18 Javascript
JavaScript队列、优先队列与循环队列
2016/11/14 Javascript
JS实现点击表头表格自动排序(含数字、字符串、日期)
2017/01/22 Javascript
用jquery的attr方法实现图片切换效果
2017/02/05 Javascript
JS实现加载和读取XML文件的方法详解
2017/04/24 Javascript
Vue自定义属性实例分析
2019/02/23 Javascript
js图数据结构处理 迪杰斯特拉算法代码实例
2019/09/11 Javascript
JS常用排序方法实例代码解析
2020/03/03 Javascript
Nodejs文件上传、监听上传进度的代码
2020/03/27 NodeJs
Vue中关闭弹窗组件时销毁并隐藏操作
2020/09/01 Javascript
Python中利用sqrt()方法进行平方根计算的教程
2015/05/15 Python
Python matplotlib绘制饼状图功能示例
2019/09/10 Python
娇韵诗加拿大官网:Clarins加拿大
2017/11/20 全球购物
菲律宾购物网站:Lazada菲律宾
2018/04/05 全球购物
程序员机试试题汇总
2012/03/07 面试题
如何执行一个shell程序
2012/11/23 面试题
生物制药毕业生自荐信
2013/10/16 职场文书
学生打架检讨书1000字
2014/01/16 职场文书
促销活动总结报告
2014/04/26 职场文书
竞选班干部演讲稿500字
2014/08/20 职场文书
补充协议书
2015/01/28 职场文书
医务人员医德考评自我评价
2015/03/03 职场文书
Python使用Opencv打开笔记本电脑摄像头报错解问题及解决
2022/06/21 Python