解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之用while来循环
Oct 02 Python
详解django中自定义标签和过滤器
Jul 03 Python
python list删除元素时要注意的坑点分享
Apr 18 Python
python脚本监控Tomcat服务器的方法
Jul 06 Python
pandas DataFrame索引行列的实现
Jun 04 Python
python 设置输出图像的像素大小方法
Jul 04 Python
Python 进程之间共享数据(全局变量)的方法
Jul 16 Python
python flask几分钟实现web服务的例子
Jul 26 Python
Python完成哈夫曼树编码过程及原理详解
Jul 29 Python
Python简单实现区域生长方式
Jan 16 Python
Python HTMLTestRunner如何下载生成报告
Sep 04 Python
Matplotlib配色之Colormap详解
Jan 05 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
DOTA2 无惧惊涛骇浪 昆卡大型水友攻略
2020/04/20 DOTA
通过ICQ网关发送手机短信的PHP源程序
2006/10/09 PHP
thinkphp路由规则使用示例详解和伪静态功能实现(apache重写)
2014/02/24 PHP
PHP中SimpleXML函数用法分析
2014/11/26 PHP
PHP遍历数组的三种方法及效率对比分析
2015/02/12 PHP
PHP5.5安装PHPRedis扩展及连接测试方法
2017/01/22 PHP
thinkPHP5框架闭包函数与子查询传参用法示例
2018/08/02 PHP
laravel框架模型中非静态方法也能静态调用的原理分析
2019/11/23 PHP
HR vs CL BO3 第二场 2.13
2021/03/10 DOTA
选择TreeView控件的树状数据节点的JS方法(jquery)
2010/02/06 Javascript
jQuery Selector选择器小结
2010/05/06 Javascript
5个JavaScript经典面试题
2014/10/13 Javascript
jquery实现标签上移、下移、置顶
2015/04/26 Javascript
javascript实现textarea中tab键的缩排处理方法
2015/06/26 Javascript
Vue-cli项目获取本地json文件数据的实例
2018/03/07 Javascript
微信小程序项目总结之点赞 删除列表 分享功能
2018/06/25 Javascript
详解nodejs解压版安装和配置(带有搭建前端项目脚手架)
2018/12/06 NodeJs
使用Vue中 v-for循环列表控制按钮隐藏显示功能
2019/04/23 Javascript
基于layPage插件实现两种分页方式浅析
2019/07/27 Javascript
vue中配置scss全局变量的步骤
2020/12/28 Vue.js
python数据结构树和二叉树简介
2014/04/29 Python
初步介绍Python中的pydoc模块和distutils模块
2015/04/13 Python
Python聚类算法之基本K均值实例详解
2015/11/20 Python
python shell根据ip获取主机名代码示例
2017/11/25 Python
Python实现连接postgresql数据库的方法分析
2017/12/27 Python
python实现单链表中删除倒数第K个节点的方法
2018/09/28 Python
Python使用pyserial进行串口通信的实例
2019/07/02 Python
django中瀑布流写法实例代码
2019/10/14 Python
Python面向对象之多态原理与用法案例分析
2019/12/30 Python
python随机模块random使用方法详解
2020/02/14 Python
css3编写浏览器背景渐变背景色的方法
2018/03/05 HTML / CSS
KIEHL’S科颜氏官方旗舰店:源自美国的顶级护肤品牌
2018/06/07 全球购物
《草原的早晨》教学反思
2014/04/08 职场文书
经销商年会策划方案
2014/05/29 职场文书
小学六年级毕业感言
2015/07/30 职场文书
世界上超棒的8种逻辑思维
2019/08/06 职场文书