解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
微信跳一跳python辅助脚本(总结)
Jan 11 Python
一道python走迷宫算法题
Jan 22 Python
Python数据分析之双色球统计两个红和蓝球哪组合比例高的方法
Feb 03 Python
Python 读取某个目录下所有的文件实例
Jun 23 Python
Python后台开发Django的教程详解(启动)
Apr 08 Python
如何运行带参数的python脚本
Nov 15 Python
wxPython实现文本框基础组件
Nov 18 Python
python二分法查找算法实现方法【递归与非递归】
Dec 06 Python
python与c语言的语法有哪些不一样的
Sep 13 Python
python如何利用traceback获取详细的异常信息
Jun 05 Python
2021年最新用于图像处理的Python库总结
Jun 15 Python
python 网络编程要点总结
Jun 18 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
php中计算中文字符串长度、截取中文字符串的函数代码
2011/08/09 PHP
PHP输出数组中重名的元素的几种处理方法
2012/09/05 PHP
PHP 数据结构队列(SplQueue)和优先队列(SplPriorityQueue)简单使用实例
2015/05/12 PHP
YII动态模型(动态表名)支持分析
2016/03/29 PHP
jQuery中的.bind()、.live()和.delegate()之间区别分析
2011/06/08 Javascript
js 函数调用模式小结
2011/12/26 Javascript
JQuery UI的拖拽功能实现方法小结
2012/03/14 Javascript
extjs表格文本启用选择复制功能具体实现
2013/10/11 Javascript
jQuery圆形统计图开发实例
2015/01/04 Javascript
jquery通过load获取文件的内容并跳到锚点的方法
2015/01/29 Javascript
jQuery异步上传文件插件ajaxFileUpload详细介绍
2015/05/19 Javascript
由浅入深讲解Javascript继承机制与simple-inheritance源码分析
2015/12/13 Javascript
AngularJs IE Compatibility 兼容老版本IE
2016/09/01 Javascript
使用纯JS代码判断字符串中有多少汉字的实现方法(超简单实用)
2016/11/12 Javascript
js计算两个日期间的天数月的实例代码
2018/09/20 Javascript
elementUI中Table表格问题的解决方法
2018/12/04 Javascript
vue 详情跳转至列表页实现列表页缓存
2019/03/27 Javascript
微信小程序事件 bindtap bindinput代码实例
2019/08/26 Javascript
jquery使用echarts实现有向图可视化功能示例
2019/11/25 jQuery
vue实现路由懒加载的3种方法示例
2020/09/01 Javascript
vue3.0自定义指令(drectives)知识点总结
2020/12/27 Vue.js
利用python获取某年中每个月的第一天和最后一天
2016/12/15 Python
pycharm修改文件的默认打开方式的步骤
2019/07/29 Python
Pytorch保存模型用于测试和用于继续训练的区别详解
2020/01/10 Python
python求最大公约数和最小公倍数的简单方法
2020/02/13 Python
python数据预处理 :数据抽样解析
2020/02/24 Python
tensorflow使用freeze_graph.py将ckpt转为pb文件的方法
2020/04/22 Python
德国体育用品网上商店:SC24.com
2016/08/01 全球购物
见习期自我鉴定
2013/11/07 职场文书
企事业单位求职者的自我评价
2013/12/28 职场文书
播音主持专业个人自我评价
2014/01/09 职场文书
学习十八届三中全会精神实施方案
2014/02/17 职场文书
七一党日活动总结
2014/07/08 职场文书
群众路线剖析材料怎么写
2014/10/09 职场文书
党员三严三实心得体会
2014/10/13 职场文书
行政处罚听证告知书
2015/07/01 职场文书