pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现调度算法代码详解
Dec 01 Python
Python使用pickle模块存储数据报错解决示例代码
Jan 26 Python
利用Python如何实现数据驱动的接口自动化测试
May 11 Python
Python中list查询及所需时间计算操作示例
Jun 21 Python
python tornado微信开发入门代码
Aug 24 Python
使用TensorFlow实现二分类的方法示例
Feb 05 Python
Pycharm中出现ImportError:DLL load failed:找不到指定模块的解决方法
Sep 17 Python
python打印直角三角形与等腰三角形实例代码
Oct 20 Python
django使用xadmin的全局配置详解
Nov 15 Python
python进程的状态、创建及使用方法详解
Dec 06 Python
Pytorch中Tensor与各种图像格式的相互转化详解
Dec 26 Python
python转化excel数字日期为标准日期操作
Jul 14 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
PHP_MySQL教程-第一天
2007/03/18 PHP
解析PHPExcel使用的常用说明以及把PHPExcel整合进CI框架的介绍
2013/06/24 PHP
PHP如何将log信息写入服务器中的log文件
2015/07/29 PHP
bindParam和bindValue的区别以及在Yii2中的使用详解
2018/03/12 PHP
070823更新的一个[消息提示框]组件 兼容ie7
2007/08/29 Javascript
javascript 多级checkbox选择效果
2009/08/20 Javascript
js png图片(有含有透明)在IE6中为什么不透明了
2010/02/07 Javascript
editable.js 基于jquery的表格的编辑插件
2011/10/24 Javascript
用jquery方法操作radio使其默认选项是否
2013/09/10 Javascript
YUI模块开发原理详解
2013/11/18 Javascript
jQuery EasyUI基础教程之EasyUI常用组件(推荐)
2016/07/15 Javascript
浅谈toLowerCase和toLocaleLowerCase的区别
2016/08/15 Javascript
JS获取当前使用的浏览器名字以及版本号实现方法
2016/08/19 Javascript
详解JS几种变量交换方式以及性能分析对比
2016/11/25 Javascript
Bootstrap CSS组件之按钮组(btn-group)
2016/12/17 Javascript
js实现1,2,3,5数字按照概率生成
2017/09/12 Javascript
vue mint-ui 实现省市区街道4级联动示例(仿淘宝京东收货地址4级联动)
2017/10/16 Javascript
详解webpack多页面配置记录
2018/01/22 Javascript
Vuex中mutations与actions的区别详解
2018/03/01 Javascript
使用webpack4编译并压缩ES6代码的方法示例
2019/04/24 Javascript
Vue.js原理分析之nextTick实现详解
2020/09/07 Javascript
详解Python中的各种函数的使用
2015/05/24 Python
python ddt实现数据驱动
2018/03/14 Python
Pyecharts绘制全球流向图的示例代码
2020/01/08 Python
python梯度下降算法的实现
2020/02/24 Python
Python爬虫爬取微信朋友圈
2020/08/06 Python
Python求区间正整数内所有素数之和的方法实例
2020/10/13 Python
CSS3贝塞尔曲线示例:创建链接悬停动画效果
2020/11/19 HTML / CSS
一级方程式赛车官方网上商店:F1 Store(支持中文)
2018/01/12 全球购物
Qoo10台湾站:亚洲领先的在线市场
2018/05/15 全球购物
莫斯科大型旅游休闲商品超市:Camping.ru
2020/09/16 全球购物
自动化专业个人求职信范文
2013/11/29 职场文书
大学生学业生涯规划
2014/01/05 职场文书
建筑投标担保书
2014/05/20 职场文书
施工安全生产承诺书
2014/05/23 职场文书
一年级语文教学随笔
2015/08/14 职场文书