解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python判断IP地址合法性的方法实例
Mar 13 Python
11个并不被常用但对开发非常有帮助的Python库
Mar 31 Python
python打开url并按指定块读取网页内容的方法
Apr 29 Python
使用python实现rsa算法代码
Feb 17 Python
判断网页编码的方法python版
Aug 12 Python
使用Python对Csv文件操作实例代码
May 12 Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 Python
记一次python 内存泄漏问题及解决过程
Nov 29 Python
python实现文件的分割与合并
Aug 29 Python
TensorBoard 计算图的查看方式
Feb 15 Python
浅析pip安装第三方库及pycharm中导入第三方库的问题
Mar 10 Python
Python数据分析之pandas读取数据
Jun 02 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
example2.php
2006/10/09 PHP
php递归方法实现无限分类实例代码
2014/02/28 PHP
ThinkPHP学习笔记(一)ThinkPHP部署
2014/06/22 PHP
PHP移动文件指针ftell()、fseek()、rewind()函数总结
2014/11/18 PHP
PHP中is_dir()函数使用指南
2015/05/08 PHP
php 实现301重定向跳转实例代码
2016/07/18 PHP
php curl批处理实现可控并发异步操作示例
2018/05/09 PHP
Laravel5.1框架路由分组用法实例分析
2020/01/04 PHP
Jquery的each里用return true或false代替break或continue
2014/05/21 Javascript
一款基于jQuery的图片场景标注提示弹窗特效
2015/01/05 Javascript
深入浅析javascript立即执行函数
2015/10/23 Javascript
无缝滚动的简单实现代码(推荐)
2016/06/07 Javascript
js实现百度地图定位于地址逆解析,显示自己当前的地理位置
2016/12/08 Javascript
AngularJS入门教程之路由机制ngRoute实例分析
2016/12/13 Javascript
vue.js学习之vue-cli定制脚手架详解
2017/07/02 Javascript
关于Angularjs中自定义指令一些有价值的细节和技巧小结
2018/04/22 Javascript
Vue手把手教你撸一个 beforeEnter 钩子函数
2018/04/24 Javascript
layui动态绑定事件的方法
2019/09/20 Javascript
vue多页面项目中路由使用history模式的方法
2019/09/23 Javascript
vue-cli中实现响应式布局的方法
2021/03/02 Vue.js
[57:59]完美世界DOTA2联赛循环赛 Ink Ice vs LBZS BO2第一场 11.05
2020/11/05 DOTA
Python中每次处理一个字符的5种方法
2015/05/21 Python
python文件操作之批量修改文件后缀名的方法
2018/08/10 Python
python编写俄罗斯方块
2020/03/13 Python
django实现将修改好的新模型写入数据库
2020/03/31 Python
Python多线程:主线程等待所有子线程结束代码
2020/04/25 Python
html5 canvas手势解锁源码分享
2020/01/07 HTML / CSS
Ancheer官方户外和运动商店:销售电动自行车
2019/08/07 全球购物
银行会计职员个人的自我评价
2013/09/29 职场文书
学期研究性学习个人的自我评价
2014/01/09 职场文书
求职信模板标准格式范文
2014/02/23 职场文书
个人授权委托书格式
2014/08/30 职场文书
研究生论文答辩开场白
2015/05/27 职场文书
新教师教学工作总结
2015/08/14 职场文书
SQL 尚未定义空闲 CPU 条件 - OnIdle 作业计划将不起任何作用
2021/06/30 SQL Server
vue如何使用模拟的json数据查看效果
2022/03/31 Vue.js