解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python连接池实现示例程序
Nov 26 Python
Django中几种重定向方法
Apr 28 Python
Python读取指定目录下指定后缀文件并保存为docx
Apr 23 Python
理解Python中的绝对路径和相对路径
Aug 30 Python
Python数据结构之单链表详解
Sep 12 Python
python的文件操作方法汇总
Nov 10 Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 Python
python-opencv颜色提取分割方法
Dec 08 Python
Python控制键盘鼠标pynput的详细用法
Jan 28 Python
python画图--输出指定像素点的颜色值方法
Jul 03 Python
win10安装tesserocr配置 Python使用tesserocr识别字母数字验证码
Jan 16 Python
Python 实现日志同时输出到屏幕和文件
Feb 19 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
提升PHP执行速度全攻略(下)
2006/10/09 PHP
PHP获得用户使用的代理服务器ip即真实ip
2006/12/31 PHP
PHP mkdir()定义和用法
2009/01/14 PHP
php设计模式 Builder(建造者模式)
2011/06/26 PHP
PHP运用foreach神奇的转换数组(实例讲解)
2018/02/01 PHP
YII2框架中添加自定义模块的方法实例分析
2020/03/18 PHP
2020最新版 PhpStudy V8.1版本下载安装使用详解
2020/10/30 PHP
JavaScript中的Window窗口对象
2008/01/16 Javascript
动态表格Table类的实现
2009/08/26 Javascript
Javascript无阻塞加载具体方式
2013/06/28 Javascript
JavaScript四种调用模式和this示例介绍
2014/01/02 Javascript
JS实现网页表格自动变大缩小的方法
2015/03/09 Javascript
JS实现控制表格单元格垂直对齐的方法
2015/03/30 Javascript
浅谈String.valueOf()方法的使用
2016/06/06 Javascript
基于JavaScript实现多级菜单效果
2017/07/25 Javascript
vue+Vue Router多级侧导航切换路由(页面)的实现代码
2018/12/20 Javascript
JavaScript异步操作的几种常见处理方法实例总结
2020/05/11 Javascript
基于Vue3.0开发轻量级手机端弹框组件V3Popup的场景分析
2020/12/30 Vue.js
[03:01]DOTA2英雄基础教程 露娜
2014/01/07 DOTA
python定时采集摄像头图像上传ftp服务器功能实现
2013/12/23 Python
Python内建数据结构详解
2016/02/03 Python
详解tensorflow实现迁移学习实例
2018/02/10 Python
PyTorch学习笔记之回归实战
2018/05/28 Python
对numpy中shape的深入理解
2018/06/15 Python
Python之列表实现栈的工作功能
2019/01/28 Python
numpy基础教程之np.linalg
2019/02/12 Python
Python实现最常见加密方式详解
2019/07/13 Python
利用Python将图片中扭曲矩形的复原
2020/09/07 Python
英国剑桥包官网:The Cambridge Satchel Company
2016/08/01 全球购物
中国茶叶、茶具一站式网上购物商城:醉品茶城
2018/07/03 全球购物
军训 自我鉴定
2014/02/03 职场文书
护士自我鉴定总结
2014/03/24 职场文书
小学优秀辅导员事迹材料
2014/05/11 职场文书
中学生关于梦想的演讲稿
2014/08/22 职场文书
教你快速开启Apache SkyWalking的自监控
2021/04/25 Servers
Python语言内置数据类型
2022/02/24 Python