pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python类的基础入门知识
Nov 24 Python
Python Paramiko模块的安装与使用详解
Nov 18 Python
Python中sort和sorted函数代码解析
Jan 25 Python
python实现下载pop3邮件保存到本地
Jun 19 Python
Flask-Mail用法实例分析
Jul 21 Python
在win10和linux上分别安装Python虚拟环境的方法步骤
May 09 Python
计算机二级python学习教程(1) 教大家如何学习python
May 16 Python
Python 计算任意两向量之间的夹角方法
Jul 05 Python
Python使用正则表达式分割字符串的实现方法
Jul 16 Python
15个Pythonic的代码示例(值得收藏)
Oct 29 Python
scrapy redis配置文件setting参数详解
Nov 18 Python
详解修改Anaconda中的Jupyter Notebook默认工作路径的三种方式
Jan 24 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
教你如何使用php session
2013/10/28 PHP
php中error与exception的区别及应用
2014/07/28 PHP
浅析PHP文件下载原理
2014/12/25 PHP
PHP中Static(静态)关键字功能与用法实例分析
2019/04/05 PHP
PHP7新特性
2021/03/09 PHP
juqery 学习之三 选择器 简单 内容
2010/11/25 Javascript
nodejs入门详解(多篇文章结合)
2012/03/07 NodeJs
5个最佳的Javascript日期处理类库分享
2012/04/15 Javascript
JQuery设置文本框和密码框得到焦点时的样式
2013/08/30 Javascript
Jquery倒计时源码分享
2014/05/16 Javascript
JQuery实现带排序功能的权限选择实例
2015/05/18 Javascript
使用jQuery+EasyUI实现CheckBoxTree的级联选中特效
2015/12/06 Javascript
JavaScript实现身份证验证代码
2016/02/17 Javascript
jQuery webuploader分片上传大文件
2016/11/07 Javascript
React Js 微信禁止复制链接分享禁止隐藏右上角菜单功能
2017/05/26 Javascript
es6学习之解构时应该注意的点
2017/08/29 Javascript
jQuery选择器之属性过滤选择器详解
2017/09/28 jQuery
vue 的keep-alive缓存功能的实现
2018/03/22 Javascript
Vue开发实现吸顶效果的示例代码
2018/08/21 Javascript
vue2实现搜索结果中的搜索关键字高亮的代码
2018/08/29 Javascript
js监听html页面的上下滚动事件方法
2018/09/11 Javascript
基于element-ui对话框el-dialog初始化的校验问题解决
2020/09/11 Javascript
小程序实现上下切换位置
2020/11/16 Javascript
使用Python控制摄像头拍照并发邮件
2019/04/23 Python
python实现两个经纬度点之间的距离和方位角的方法
2019/07/05 Python
python根据多个文件名批量查找文件
2019/08/13 Python
pyinstaller打包opencv和numpy程序运行错误解决
2019/08/16 Python
python 三元运算符使用解析
2019/09/16 Python
Python控制台实现交互式环境执行
2020/06/09 Python
香港太阳眼镜网上商店:SmartBuyGlasses香港
2016/07/22 全球购物
旅行社各个岗位职责
2014/03/15 职场文书
五月的鲜花活动方案
2014/08/21 职场文书
四风问题个人对照检查材料
2014/09/26 职场文书
团党委领导干部党的群众路线教育实践活动个人对照检查材料思想汇
2014/10/05 职场文书
单位接收函范文
2015/01/30 职场文书
穷人该怎么创业?谨记以下几点
2019/07/11 职场文书