pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python判断给定的字符串是否是有效日期的方法
May 13 Python
在Python中处理字符串之ljust()方法的使用简介
May 19 Python
Python 中urls.py:URL dispatcher(路由配置文件)详解
Mar 24 Python
基于Python socket的端口扫描程序实例代码
Feb 09 Python
python使用代理ip访问网站的实例
May 07 Python
Python常见数据结构之栈与队列用法示例
Jan 14 Python
python爬虫 正则表达式解析
Sep 28 Python
Django学习之文件上传与下载
Oct 06 Python
Pytorch.nn.conv2d 过程验证方式(单,多通道卷积过程)
Jan 03 Python
tensorflow 获取所有variable或tensor的name示例
Jan 04 Python
使用Pytorch实现two-head(多输出)模型的操作
May 28 Python
Pycharm连接远程服务器并远程调试的全过程
Jun 24 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
php实现网站插件机制的方法
2009/11/10 PHP
php开发微信支付获取用户地址
2015/10/04 PHP
学习php设计模式 php实现原型模式(prototype)
2015/12/07 PHP
PHP Class SoapClient not found解决方法
2018/01/20 PHP
PHP实现从PostgreSQL数据库检索数据分页显示及根据条件查找数据示例
2018/06/09 PHP
加载 Javascript 最佳实践
2011/10/30 Javascript
node.js正则表达式获取网页中所有链接的代码实例
2014/06/03 Javascript
推荐一个封装好的getElementsByClassName方法
2014/12/02 Javascript
jquery.form.js实现将form提交转为ajax方式提交的方法
2015/04/07 Javascript
jquery表单插件Autotab使用方法详解
2016/06/24 Javascript
Jq通过td获取同行其它列td的方法
2016/10/05 Javascript
JS中页面与页面之间超链接跳转中文乱码问题的解决办法
2016/12/15 Javascript
vue中实现methods一个方法调用另外一个方法
2018/02/08 Javascript
Vue+iview+webpack ie浏览器兼容简单处理
2019/09/20 Javascript
node静态服务器实现静态读取文件或文件夹
2019/12/03 Javascript
JS几个常用的函数和对象定义与用法示例
2020/01/15 Javascript
Python字符转换
2008/09/06 Python
使用Python的Dataframe取两列时间值相差一年的所有行方法
2018/07/10 Python
Python设计模式之组合模式原理与用法实例分析
2019/01/11 Python
对python判断ip是否可达的实例详解
2019/01/31 Python
Python字符串的常见操作实例小结
2019/04/08 Python
Pandas库之DataFrame使用的学习笔记
2019/06/21 Python
python requests抓取one推送文字和图片代码实例
2019/11/04 Python
Python opencv相机标定实现原理及步骤详解
2020/04/09 Python
Python3基于plotly模块保存图片表格
2020/08/03 Python
安装并免费使用Pycharm专业版(学生/教师)
2020/09/24 Python
吉列剃须刀美国官网:Gillette美国
2018/07/13 全球购物
超市后勤自我鉴定
2014/01/17 职场文书
手机银行营销方案
2014/03/14 职场文书
体育课课后反思
2014/04/24 职场文书
中学生国旗下讲话稿
2014/04/26 职场文书
住房租房协议书
2014/08/20 职场文书
涉外离婚协议书怎么写
2014/11/20 职场文书
导游词400字
2015/02/13 职场文书
财务工作个人总结
2015/02/27 职场文书
利用Python实现翻译HTML中的文本字符串
2022/06/21 Python