解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现求最大公约数及判断素数的方法
May 26 Python
python中类和实例如何绑定属性与方法示例详解
Aug 18 Python
tensorflow创建变量以及根据名称查找变量
Mar 10 Python
使用Python和xlwt向Excel文件中写入中文的实例
Apr 21 Python
对numpy中shape的深入理解
Jun 15 Python
python逆序打印各位数字的方法
Jun 25 Python
python 检查文件mime类型的方法
Dec 08 Python
Python如何访问字符串中的值
Feb 09 Python
在pycharm中实现删除bookmark
Feb 14 Python
Python xml、字典、json、类四种数据类型如何实现互相转换
May 27 Python
Python定时任务框架APScheduler原理及常用代码
Oct 05 Python
Python实现查询剪贴板自动匹配信息的思路详解
Jul 09 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
nginx+php-fpm配置文件的组织结构介绍
2012/11/07 PHP
PHP 实现explort() 功能的详解
2013/06/20 PHP
Thinkphp将二维数组变为标签适用的一维数组方法总结
2014/10/30 PHP
Yii框架中memcache用法实例
2014/12/03 PHP
神奇的代码 通杀各种网站-可随意修改复制页面内容
2008/07/17 Javascript
JQUERY复选框CHECKBOX全选,取消全选
2008/08/30 Javascript
Jquery AJAX 用于计算点击率(统计)
2010/06/30 Javascript
jquery validation验证身份证号,护照,电话号码,email(实例代码)
2013/11/06 Javascript
jquery鼠标放上去显示悬浮层即弹出定位的div层
2014/04/25 Javascript
js+html5绘制图片到canvas的方法
2015/06/05 Javascript
javaScript实现可缩放的显示区效果代码
2015/10/26 Javascript
基于jQuery实现左右图片轮播(原理通用)
2015/12/24 Javascript
javascript点击按钮实现隐藏显示切换效果
2016/02/03 Javascript
深入理解jQuery之防止冒泡事件
2016/05/24 Javascript
jquery实现图片切换代码
2016/10/13 Javascript
Bootstrap Table使用整理(四)之工具栏
2017/06/09 Javascript
浅谈ES6 模板字符串的具体使用方法
2017/11/07 Javascript
关于vue的npm run dev和npm run build的区别介绍
2019/01/14 Javascript
深入分析jQuery.one() 函数
2020/06/03 jQuery
vue tab滚动到一定高度,固定在顶部,点击tab切换不同的内容操作
2020/07/22 Javascript
基于JS实现操作成功之后自动跳转页面
2020/09/25 Javascript
Python解析nginx日志文件
2015/05/11 Python
python2与python3共存问题的解决方法
2018/09/18 Python
Python实现操纵控制windows注册表的方法分析
2019/05/24 Python
python join方法使用详解
2019/07/30 Python
PyTorch的SoftMax交叉熵损失和梯度用法
2020/01/15 Python
Tensorflow--取tensorf指定列的操作方式
2020/06/30 Python
html5 postMessage解决跨域、跨窗口消息传递方案
2016/12/20 HTML / CSS
YOOX美国官方网站:全球著名的多品牌时尚网络概念店
2016/09/11 全球购物
Cinque网上商店:德国服装品牌
2019/03/17 全球购物
高级3D打印市场:Gambody
2019/12/26 全球购物
学校安全工作制度
2014/01/19 职场文书
房屋转让协议书
2014/04/11 职场文书
大学生学雷锋活动总结
2014/06/26 职场文书
反对四风问题自我剖析材料
2014/09/29 职场文书
个人总结格式范文
2015/03/09 职场文书