解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据结构之图的实现方法
Jul 08 Python
python函数局部变量用法实例分析
Aug 04 Python
详谈Python2.6和Python3.0中对除法操作的异同
Apr 28 Python
Python3的介绍、安装和命令行的认识(推荐)
Oct 20 Python
Python实现简单的列表冒泡排序和反转列表操作示例
Jul 10 Python
Django 路由控制的实现
Jul 17 Python
django多种支付、并发订单处理实例代码
Dec 13 Python
pycharm 激活码及使用方式的详细教程
May 12 Python
python实现mask矩阵示例(根据列表所给元素)
Jul 30 Python
python pymysql库的常用操作
Oct 16 Python
python向xls写入数据(包括合并,边框,对齐,列宽)
Feb 02 Python
Python使用OpenCV和K-Means聚类对毕业照进行图像分割
Jun 11 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
模板引擎Smarty深入浅出介绍
2006/12/06 PHP
php 将bmp图片转为jpg等其他任意格式的图片
2009/06/21 PHP
php递归创建和删除文件夹的代码小结
2012/04/13 PHP
使用PHP实现Mysql读写分离
2013/06/28 PHP
php支付宝接口用法分析
2015/01/04 PHP
PHP实现的随机IP函数【国内IP段】
2016/07/20 PHP
PHP实现的无限分类类库定义与用法示例【基于thinkPHP】
2018/08/06 PHP
PHP5.6读写excel表格文件操作示例
2019/02/26 PHP
Nigma vs AM BO3 第一场2.13
2021/03/10 DOTA
iframe子页面获取父页面元素的方法
2013/11/05 Javascript
JavaScript实现存储HTML字符串示例
2014/04/21 Javascript
JavaScript实现信用卡校验方法
2015/04/07 Javascript
第六章之辅组类与响应式工具
2016/04/25 Javascript
jQuery封装的屏幕居中提示信息代码
2016/06/08 Javascript
vue拖拽排序插件vuedraggable使用方法详解
2020/08/21 Javascript
解决Vue router-link绑定事件不生效的问题
2020/07/22 Javascript
VUE实时监听元素距离顶部高度的操作
2020/07/29 Javascript
[00:35]DOTA2上海特级锦标赛 Newbee战队宣传片
2016/03/03 DOTA
[03:41]DOTA2上海特锦赛小组赛第三日recap精彩回顾
2016/02/28 DOTA
Python实现多线程抓取妹子图
2015/08/08 Python
python实现生命游戏的示例代码(Game of Life)
2018/01/24 Python
对python中的*args与**kwgs的含义与作用详解
2019/08/28 Python
pytorch之inception_v3的实现案例
2020/01/06 Python
Matplotlib自定义坐标轴刻度的实现示例
2020/06/18 Python
请介绍一下Ant
2016/07/22 面试题
法人委托书范本
2014/04/04 职场文书
房地产推广策划方案
2014/05/19 职场文书
工作时间擅自离岗检讨书
2014/10/24 职场文书
谢师宴答谢词
2015/01/05 职场文书
婚礼庆典答谢词
2015/01/20 职场文书
化验员岗位职责
2015/02/14 职场文书
幼儿园见习总结
2015/06/23 职场文书
校运会新闻稿
2015/07/17 职场文书
安全责任协议书范本
2016/03/23 职场文书
幼儿园教师辞职信
2019/06/21 职场文书
MySQL 外键约束和表关系相关总结
2021/06/20 MySQL