pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的is和id用法分析
Jan 26 Python
python模拟Django框架实例
May 17 Python
利用Python中SocketServer 实现客户端与服务器间非阻塞通信
Dec 15 Python
Python3 replace()函数使用方法
Mar 19 Python
攻击者是如何将PHP Phar包伪装成图像以绕过文件类型检测的(推荐)
Oct 11 Python
Django开发的简易留言板案例详解
Dec 04 Python
在python中将list分段并保存为array类型的方法
Jul 15 Python
Python使用pymysql模块操作mysql增删改查实例分析
Dec 19 Python
Pycharm最新激活码2019(推荐)
Dec 31 Python
python实现图像全景拼接
Mar 27 Python
20行Python代码实现视频字符化功能
Apr 13 Python
python怎么自定义捕获错误
Jun 29 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
php输出xml格式字符串(用的这个)
2012/07/12 PHP
CURL状态码列表(详细)
2013/06/27 PHP
推荐一款MAC OS X 下php集成开发环境mamp
2014/11/08 PHP
PHP编写简单的App接口
2016/08/28 PHP
PHP文件与目录操作示例
2016/12/24 PHP
javascript动画浅析
2012/08/30 Javascript
Jquery判断radio、selelct、checkbox是否选中及获取选中值方法总结
2015/04/15 Javascript
asp.net中oracle 存储过程(图文)
2015/08/12 Javascript
基于jQuery下拉选择框插件支持单选多选功能代码
2016/06/07 Javascript
详解JavaScript中的六种错误类型
2017/09/21 Javascript
node 利用进程通信实现Cluster共享内存
2017/10/27 Javascript
JavaScript面向对象程序设计创建对象的方法分析
2018/08/13 Javascript
AngularJS $http post 传递参数数据的方法
2018/10/09 Javascript
elementUI select组件默认选中效果实现的方法
2019/03/25 Javascript
微信小程序云开发之数据库操作
2019/05/18 Javascript
微信小程序实现圆形进度条动画
2020/11/18 Javascript
JS实现点击发送验证码 xx秒后重新发送功能
2019/07/30 Javascript
优化Vue中date format的性能详解
2020/01/13 Javascript
详解vue或uni-app的跨域问题解决方案
2020/02/21 Javascript
JavaScript设计模式--桥梁模式引入操作实例分析
2020/05/23 Javascript
python数字图像处理之骨架提取与分水岭算法
2018/04/27 Python
pandas实现将dataframe满足某一条件的值选出
2019/06/12 Python
Python 可变类型和不可变类型及引用过程解析
2019/09/27 Python
pyhton中__pycache__文件夹的产生与作用详解
2019/11/24 Python
如何在mac环境中用python处理protobuf
2019/12/25 Python
Django admin 实现search_fields精确查询实例
2020/03/30 Python
使用python处理题库表格并转化为word形式的实现
2020/04/14 Python
Python通过getattr函数获取对象的属性值
2020/10/16 Python
详解python os.path.exists判断文件或文件夹是否存在
2020/11/16 Python
彻底解决pip下载pytorch慢的问题方法
2021/03/01 Python
酒店管理专业毕业生推荐信
2013/11/10 职场文书
毕业生自荐信如何写
2014/03/24 职场文书
2014年医学生毕业自我鉴定
2014/03/26 职场文书
个人诉讼委托书范本
2014/10/17 职场文书
2015年基建工作总结范文
2015/05/23 职场文书
运输公司工作总结
2015/08/11 职场文书