解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python os模块介绍
Nov 30 Python
python 链接和操作 memcache方法
Mar 04 Python
给你选择Python语言实现机器学习算法的三大理由
Nov 15 Python
详解Django之auth模块(用户认证)
Apr 17 Python
关于python中密码加盐的学习体会小结
Jul 15 Python
Pytorch反向求导更新网络参数的方法
Aug 17 Python
Python大数据之从网页上爬取数据的方法详解
Nov 16 Python
python实现画出e指数函数的图像
Nov 21 Python
Python:二维列表下标互换方式(矩阵转置)
Dec 02 Python
Python imageio读取视频并进行编解码详解
Dec 10 Python
将pytorch转成longtensor的简单方法
Feb 18 Python
Python函数对象与闭包函数
Apr 13 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
php ckeditor上传图片文件名乱码解决方法
2013/11/15 PHP
php中执行系统命令的方法
2015/03/21 PHP
PHP统一页面编码避免乱码问题
2015/04/09 PHP
Yii 框架使用Forms操作详解
2020/05/18 PHP
JQuery 动态扩展对象之另类视角
2010/05/25 Javascript
jQuery链式操作如何实现以及为什么要用链式操作
2013/01/17 Javascript
Jquery封装tab自动切换效果的具体实现
2013/07/13 Javascript
Knockout text绑定DOM的使用方法
2013/11/15 Javascript
jquery对ajax的支持介绍
2013/12/10 Javascript
ExtJS4如何给同一个formpanel不同的url
2014/05/02 Javascript
判断复选框是否被选中的两种方法
2014/06/04 Javascript
基于jquery实现的可编辑下拉框实现代码
2014/08/02 Javascript
让图片跳跃起来  javascript图片轮播特效
2016/02/16 Javascript
Backbone View 之间通信的三种方式
2016/08/09 Javascript
js上传图片预览的实现方法
2017/05/09 Javascript
浅谈webpack打包之后的文件过大的解决方法
2018/03/07 Javascript
echarts鼠标覆盖高亮显示节点及关系名称详解
2018/03/17 Javascript
webpack4.x下babel的安装、配置及使用详解
2019/03/07 Javascript
微信小程序如何播放腾讯视频的实现
2019/09/20 Javascript
javascript canvas API内容整理
2020/02/16 Javascript
[03:40]DOTA2英雄梦之声_第01期_炼金术士
2014/06/23 DOTA
Python二维码生成库qrcode安装和使用示例
2014/12/16 Python
浅谈Python基础之I/O模型
2017/05/11 Python
Python pyinotify日志监控系统处理日志的方法
2018/03/08 Python
django 自定义过滤器(filter)处理较为复杂的变量方法
2019/08/12 Python
Python集合基本概念与相关操作实例分析
2019/10/30 Python
python socket通信编程实现文件上传代码实例
2019/12/14 Python
Pycharm debug调试时带参数过程解析
2020/02/03 Python
Tensorflow训练模型越来越慢的2种解决方案
2020/02/07 Python
英国休闲奢华的缩影:Crew Clothing
2019/05/05 全球购物
商场促销活动方案
2014/02/08 职场文书
数学与统计学院学生个人职业生涯规划书
2014/02/10 职场文书
医学专业毕业生推荐信
2014/07/12 职场文书
计划生育诚信协议书
2014/11/02 职场文书
2014年英语教师工作总结
2014/12/03 职场文书
MYSQL数据库使用UTF-8中文编码乱码的解决办法
2021/05/26 MySQL