pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Python __dict__与dir()的区别详解
Oct 30 Python
python生成excel的实例代码
Nov 08 Python
python中利用zfill方法自动给数字前面补0
Apr 10 Python
Python深拷贝与浅拷贝用法实例分析
May 05 Python
python如何实现代码检查
Jun 28 Python
python顺序执行多个py文件的方法
Jun 29 Python
pandas DataFrame的修改方法(值、列、索引)
Aug 02 Python
关于Python3 类方法、静态方法新解
Aug 30 Python
Python 最强编辑器详细使用指南(PyCharm )
Sep 16 Python
Python中使用socks5设置全局代理的方法示例
Apr 15 Python
Python使用scapy模块发包收包
May 07 Python
Python办公自动化之教你用Python批量识别发票并录入到Excel表格中
Jun 26 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
用穿越火线快速入门php面向对象
2012/02/22 PHP
初步介绍PHP扩展开发经验分享
2012/09/06 PHP
在win7中搭建Linux+PHP 开发环境
2014/10/08 PHP
thinkphp获取栏目和文章当前位置的方法
2014/10/29 PHP
Symfony2函数用法实例分析
2016/03/18 PHP
PHP基于GD库实现的生成图片缩略图函数示例
2017/07/05 PHP
php 广告点击统计代码(php+mysql)
2018/02/21 PHP
js查找父节点的简单方法
2008/06/28 Javascript
无缝滚动改进版支持上下左右滚动(封装成函数)
2012/12/04 Javascript
JS特权方法定义作用以及与公有方法的区别
2013/03/18 Javascript
JS 页面计时器示例代码
2013/10/28 Javascript
JQuery获取与设置HTML元素的内容或文本的实现代码
2014/06/20 Javascript
Javascript前端UI框架Kit使用指南之Kitjs简介
2014/11/28 Javascript
浅谈jquery.fn.extend与jquery.extend区别
2015/07/13 Javascript
Bootstrap框架的学习教程详解(二)
2016/10/18 Javascript
高性能js数组去重(12种方法,史上最全)
2019/12/21 Javascript
微信小程序onShareTimeline()实现分享朋友圈
2021/01/07 Javascript
[00:43]DOTA2小紫本全民票选福利PA至宝全方位展示
2014/11/25 DOTA
Python中使用select模块实现非阻塞的IO
2015/02/03 Python
Python 描述符(Descriptor)入门
2016/11/20 Python
python-xpath获取html文档的部分内容
2020/03/06 Python
pycharm 代码自动补全的实现方法(图文)
2020/09/18 Python
python读取excel数据并且画图的实现示例
2021/02/08 Python
英国领先的名牌服装折扣零售商:Brown Bag Clothing
2019/01/08 全球购物
教师自我鉴定
2013/12/13 职场文书
外贸采购员岗位职责
2014/03/08 职场文书
大跃进口号
2014/06/16 职场文书
五好家庭事迹材料
2014/12/20 职场文书
2015年世界环境日活动总结
2015/02/11 职场文书
端午节寄语2015
2015/03/23 职场文书
走进科学观后感
2015/06/18 职场文书
诚实守信主题班会
2015/08/13 职场文书
《假如》教学反思
2016/02/17 职场文书
因个人工作失误检讨书
2019/06/21 职场文书
MySQL查询学习之基础查询操作
2021/05/08 MySQL
MySQL系列之一 MariaDB-server安装
2021/07/02 MySQL