解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python集合用法实例分析
May 30 Python
Python中.py文件打包成exe可执行文件详解
Mar 22 Python
使用Python的package机制如何简化utils包设计详解
Dec 11 Python
numpy中索引和切片详解
Dec 15 Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 Python
Python3.5内置模块之os模块、sys模块、shutil模块用法实例分析
Apr 27 Python
python循环嵌套的多种使用方法解析
Nov 29 Python
Pytorch 之修改Tensor部分值方式
Dec 27 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 Python
Windows下实现将Pascal VOC转化为TFRecords
Feb 17 Python
Python基于百度AI实现OCR文字识别
Apr 02 Python
10行Python代码实现Web自动化管控的示例代码
Aug 14 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
用PHP函数解决SQL injection
2006/12/09 PHP
PHP中改变图片的尺寸大小的代码
2011/07/17 PHP
php&mysql 日期操作小记
2012/02/27 PHP
php实现二进制和文本相互转换的方法
2015/04/18 PHP
基于laravel Request的所有方法详解
2019/09/29 PHP
浅析JS刷新框架中的其他页面 && JS刷新窗口方法汇总
2013/07/08 Javascript
JavaScript打印iframe内容示例代码
2013/08/20 Javascript
NodeJS实现阿里大鱼短信通知发送
2016/01/17 NodeJs
BootStrap文件上传样式超好看【持续更新】
2016/05/10 Javascript
Bootstrap编写一个兼容主流浏览器的受众巨幕式风格页面
2016/07/01 Javascript
BootStrap使用file-input插件上传图片的方法
2016/09/05 Javascript
JS中实现函数return多个返回值的实例
2017/02/21 Javascript
Vue.js中的computed工作原理
2018/03/22 Javascript
vue中的router-view组件的使用教程
2018/10/23 Javascript
JavaScript遍历DOM元素的常见方式示例
2019/02/16 Javascript
浅谈对于“不用setInterval,用setTimeout”的理解
2019/08/28 Javascript
Python中__new__与__init__方法的区别详解
2015/05/04 Python
python实现解数独程序代码
2017/04/12 Python
python 去除txt文本中的空格、数字、特定字母等方法
2018/07/24 Python
python绘制地震散点图
2019/06/18 Python
python快速编写单行注释多行注释的方法
2019/07/31 Python
tensorflow没有output结点,存储成pb文件的例子
2020/01/04 Python
Python尾递归优化实现代码及原理详解
2020/10/09 Python
python实现杨辉三角的几种方法代码实例
2021/03/02 Python
CSS3 icon font完全指南(CSS3 font 会取代icon图标)
2013/01/06 HTML / CSS
HTML5语义化元素你真的用对了吗
2019/08/22 HTML / CSS
林清轩官方网站:山茶花润肤油开创者
2016/10/26 全球购物
Alexandre Birman美国官网:亚历山大·伯曼
2019/10/30 全球购物
Diesel美国网上商店:意大利牛仔时装品牌
2020/12/10 全球购物
儿科护士自我鉴定
2013/10/14 职场文书
学术会议邀请函范文
2014/01/22 职场文书
期末自我鉴定
2014/01/23 职场文书
工厂仓管员岗位职责范本
2014/07/17 职场文书
临时用工协议书范本
2014/10/29 职场文书
详解MySQL 用户权限管理
2021/04/20 MySQL
html5中sharedWorker实现多页面通信的示例代码
2021/05/07 Javascript