pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现批量转换文件编码(批转换编码示例)
Jan 23 Python
Python专用方法与迭代机制实例分析
Sep 15 Python
python在windows下创建隐藏窗口子进程的方法
Jun 04 Python
python基础知识小结之集合
Nov 25 Python
python中将字典形式的数据循环插入Excel
Jan 16 Python
python3实现windows下同名进程监控
Jun 21 Python
python去除拼音声调字母,替换为字母的方法
Nov 28 Python
Python面向对象基础入门之设置对象属性
Dec 11 Python
Python把对应格式的csv文件转换成字典类型存储脚本的方法
Feb 12 Python
pyqt5让图片自适应QLabel大小上以及移除已显示的图片方法
Jun 21 Python
基于python连接oracle导并出数据文件
Apr 28 Python
python 爬虫之selenium可视化爬虫的实现
Dec 04 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
JQuery each()嵌套使用小结
2014/04/18 Javascript
node.js中的path.isAbsolute方法使用说明
2014/12/08 Javascript
asp.net中oracle 存储过程(图文)
2015/08/12 Javascript
JS+CSS实现电子商务网站导航模板效果代码
2015/09/10 Javascript
jQuery 1.9.1源码分析系列(十三)之位置大小操作
2015/12/02 Javascript
JavaScript实现搜索框的自动完成功能(一)
2016/02/25 Javascript
jQuery实现的选择商品飞入文本框动画效果完整实例
2016/08/10 Javascript
JS实现类似百叶窗下拉菜单效果
2016/12/30 Javascript
Vue.js系列之vue-router(上)(3)
2017/01/03 Javascript
jQuery图片拖动组件Dropzone用法示例
2017/01/17 Javascript
ES6模块化的import和export用法方法总结
2017/08/08 Javascript
JS中Map和ForEach的区别
2018/02/05 Javascript
H5+C3+JS实现双人对战五子棋游戏(UI篇)
2020/05/28 Javascript
6行代码实现微信小程序页面返回顶部效果
2018/12/28 Javascript
Angular单元测试之事件触发的实现
2020/01/20 Javascript
在vue中给后台接口传的值为数组的格式代码
2020/11/12 Javascript
Python使用pip安装pySerial串口通讯模块
2018/04/20 Python
python scp 批量同步文件的实现方法
2019/01/03 Python
python爬虫 线程池创建并获取文件代码实例
2019/09/28 Python
django xadmin action兼容自定义model权限教程
2020/03/30 Python
如何配置关联Python 解释器 Anaconda的教程(图解)
2020/04/30 Python
Django 解决新建表删除后无法重新创建等问题
2020/05/21 Python
Python爬虫抓取指定网页图片代码实例
2020/07/24 Python
PHP解析URL是哪个函数?怎么用?
2013/05/09 面试题
什么叫应用程序域?什么是受管制的代码?什么是强类型系统?什么是装箱和拆箱?
2016/08/13 面试题
护理学毕业生求职信
2013/11/14 职场文书
物流管理专业应届生求职信
2013/11/21 职场文书
影视动画专业个人的自我评价
2013/12/31 职场文书
迟到早退检讨书
2014/02/10 职场文书
《最后的姿势》教学反思
2014/02/27 职场文书
2014年国庆节庆祝建国65周年比赛演讲稿
2014/09/21 职场文书
党的群众路线教育实践活动个人对照检查材料
2014/09/22 职场文书
2015年专项整治工作总结
2015/04/03 职场文书
学校安全管理制度
2015/08/06 职场文书
基于angular实现树形二级表格
2021/10/16 Javascript
MySQL 主从复制数据不一致的解决方法
2022/03/18 MySQL