pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python打开网页和暂停实例
Sep 30 Python
详解Python编程中time模块的使用
Nov 20 Python
Python模拟脉冲星伪信号频率实例代码
Jan 03 Python
Python 实现在文件中的每一行添加一个逗号
Apr 29 Python
python查看模块,对象的函数方法
Oct 16 Python
Django ManyToManyField 跨越中间表查询的方法
Dec 18 Python
对Django项目中的ORM映射与模糊查询的使用详解
Jul 18 Python
Python中的 sort 和 sorted的用法与区别
Aug 10 Python
pytorch 归一化与反归一化实例
Dec 31 Python
详解Flask前后端分离项目案例
Jul 24 Python
Pycharm新手使用教程(图文详解)
Sep 17 Python
详解Python调用系统命令的六种方法
Jan 28 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
完美利用Yii2微信后台开发的系列总结
2016/07/18 PHP
Laravel中的Blade模板引擎示例详解
2017/10/10 PHP
javascript 兼容鼠标滚轮事件
2009/04/07 Javascript
JQuery 网站换肤功能实现代码
2009/11/02 Javascript
深入理解javascript学习笔记(一) 编写高质量代码
2012/08/09 Javascript
web的各种前端打印方法之jquery打印插件jqprint实现网页打印
2013/01/09 Javascript
javascript中取前n天日期的两种方法分享
2014/01/26 Javascript
Ajax局部更新导致JS事件重复触发问题的解决方法
2014/10/14 Javascript
JavaScript DOM元素尺寸和位置
2015/04/13 Javascript
纯JavaScript代码实现移动设备绘图解锁
2015/10/16 Javascript
信息页文内画中画广告js实现代码(文中加载广告方式)
2016/01/03 Javascript
非常酷炫的Bootstrap图片轮播动画
2016/05/27 Javascript
jquery根据一个值来选中select下的option实例代码
2016/08/29 Javascript
Javascript 动态改变imput type属性
2016/11/01 Javascript
JS设置时间无效问题的解决办法
2017/02/18 Javascript
js登录滑动验证的实现(不滑动无法登陆)
2018/01/03 Javascript
深入理解Antd-Select组件的用法
2020/02/25 Javascript
Vue初始化中的选项合并之initInternalComponent详解
2020/06/11 Javascript
如何基于viewport vm适配移动端页面
2020/11/13 Javascript
[08:54]《一刀刀一天》之DOTA全时刻18:十九支奔赴西雅图队伍全部出炉
2014/06/04 DOTA
[01:21:58]守擂赛DOTA2第一周决赛
2020/04/22 DOTA
python解决网站的反爬虫策略总结
2016/10/26 Python
Python实现两个list对应元素相减操作示例
2017/06/09 Python
python 用下标截取字符串的实例
2018/12/25 Python
python 多线程重启方法
2019/02/18 Python
python+Django实现防止SQL注入的办法
2019/10/31 Python
Python任务调度利器之APScheduler详解
2020/04/02 Python
Python 添加文件注释和函数注释操作
2020/08/09 Python
超级实用的8个Python列表技巧
2020/08/24 Python
实列教程 一款基于jquery和css3的响应式二级导航菜单
2014/11/13 HTML / CSS
HTML5 自动聚焦(autofocus)属性使用介绍
2013/08/07 HTML / CSS
营销总经理岗位职责
2014/02/02 职场文书
工作能力自我评价2015
2015/03/05 职场文书
2015年学校政教处工作总结
2015/05/26 职场文书
2019年世界儿童日宣传标语
2019/11/22 职场文书
Python Parser的用法
2021/05/12 Python