pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
qpython3 读取安卓lastpass Cookies
Jun 19 Python
socket + select 完成伪并发操作的实例
Aug 15 Python
Python实现常见的回文字符串算法
Nov 14 Python
Python2与Python3的区别实例总结
Apr 17 Python
python实现Dijkstra算法的最短路径问题
Jun 21 Python
Python 从subprocess运行的子进程中实时获取输出的例子
Aug 14 Python
python 字符串常用方法汇总详解
Sep 16 Python
TFRecord格式存储数据与队列读取实例
Jan 21 Python
解决python-docx打包之后找不到default.docx的问题
Feb 13 Python
Python面向对象魔法方法和单例模块代码实例
Mar 25 Python
Python GUI编程之tkinter 关于 ttkbootstrap 的使用详解
Mar 03 Python
python使用opencv对图像添加噪声(高斯/椒盐/泊松/斑点)
Apr 06 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
浅析Apache中RewriteCond规则参数的详细介绍
2013/06/30 PHP
PHP生成等比缩略图类和自定义函数分享
2014/06/25 PHP
php禁止浏览器使用缓存页面的方法
2014/11/07 PHP
yii,CI,yaf框架+smarty模板使用方法
2015/12/29 PHP
PHP计算当前坐标3公里内4个角落的最大最小经纬度实例
2016/02/26 PHP
iOS自定义提示弹出框实现类似UIAlertView的效果
2016/11/16 PHP
php使用flock阻塞写入文件和非阻塞写入文件的实例讲解
2017/07/10 PHP
PHP后门隐藏的一些技巧总结
2020/11/04 PHP
一个JS翻页效果
2007/07/23 Javascript
jquery ui对话框实例代码
2013/05/10 Javascript
我的Node.js学习之路(四)--单元测试
2014/07/06 Javascript
js父页面与子页面不同时显示的方法
2014/10/16 Javascript
JQuery中使文本框获得焦点的方法实例分析
2015/02/28 Javascript
js使用split函数按照多个字符对字符串进行分割的方法
2015/03/20 Javascript
javascript三元运算符用法实例
2015/04/16 Javascript
简单谈谈node.js 版本控制 nvm和 n
2015/10/15 Javascript
Javascript中获取浏览器类型和操作系统版本等客户端信息常用代码
2016/06/28 Javascript
原生js获取iframe中dom元素--父子页面相互获取对方dom元素的方法
2016/08/05 Javascript
原生js实现可拖动的登录框效果
2017/01/21 Javascript
数组Array的排序sort方法
2017/02/17 Javascript
vue 2.0项目中如何引入element-ui详解
2017/09/06 Javascript
微信小程序获取用户信息并保存登录状态详解
2019/05/10 Javascript
js判断浏览器的环境(pc端,移动端,还是微信浏览器)
2020/12/24 Javascript
vue 数据双向绑定的实现方法
2021/03/04 Vue.js
python实现的jpg格式图片修复代码
2015/04/21 Python
Python实现简单截取中文字符串的方法
2015/06/15 Python
Python之str操作方法(详解)
2017/06/19 Python
python实现Excel文件转换为TXT文件
2019/04/28 Python
用python生成(动态彩色)二维码的方法(使用myqr库实现)
2019/06/24 Python
python飞机大战pygame游戏框架搭建操作详解
2019/12/17 Python
Python Selenium XPath根据文本内容查找元素的方法
2020/12/07 Python
决定成败的关键——创业计划书
2014/01/24 职场文书
护士自我鉴定怎么写
2014/02/07 职场文书
表扬通报怎么写
2015/01/16 职场文书
python 实现mysql自动增删分区的方法
2021/04/01 Python
html中显示特殊符号(附带特殊字符对应表)
2021/06/21 HTML / CSS