pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django中几种重定向方法
Apr 28 Python
python控制台中实现进度条功能
Nov 10 Python
Python中序列的修改、散列与切片详解
Aug 27 Python
详解python里使用正则表达式的全匹配功能
Oct 19 Python
Python判断中文字符串是否相等的实例
Jul 06 Python
python 实现对文件夹中的图像连续重命名方法
Oct 25 Python
Python 数据库操作 SQLAlchemy的示例代码
Feb 18 Python
python中sort和sorted排序的实例方法
Aug 26 Python
Python threading.local代码实例及原理解析
Mar 16 Python
Python telnet登陆功能实现代码
Apr 16 Python
python统计mysql数据量变化并调用接口告警的示例代码
Sep 21 Python
Python基于unittest实现测试用例执行
Nov 25 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
一个简单的PHP&MYSQL留言板源码
2020/07/19 PHP
PHP中for循环语句的几种变型
2007/03/16 PHP
php 中的str_replace 函数总结
2007/04/27 PHP
为IP查询添加GOOGLE地图功能的代码
2010/08/08 PHP
PHP+Ajax异步通讯实现用户名邮箱验证是否已注册( 2种方法实现)
2011/12/28 PHP
教你在PHPStorm中配置Xdebug
2015/07/27 PHP
Laravel框架执行原生SQL语句及使用paginate分页的方法
2018/08/17 PHP
JQuery优缺点分析说明
2011/04/10 Javascript
说明你的Javascript技术很烂的五个原因
2011/04/26 Javascript
关于js new Date() 出现NaN 的分析
2012/10/23 Javascript
JavaScript实现在数组中查找不同顺序排列的字符串
2014/09/26 Javascript
javascript实现根据iphone屏幕方向调用不同样式表的方法
2015/07/13 Javascript
AngularJS中监视Scope变量以及外部调用Scope方法
2016/01/23 Javascript
JQuery 两种方法解决刚创建的元素遍历不到的问题
2016/04/13 Javascript
jquery文字填写自动高度的实现方法
2016/11/07 Javascript
从零学习node.js之express入门(六)
2017/02/25 Javascript
swiper动态改变滑动内容的实现方法
2018/01/17 Javascript
vue 中动态绑定class 和 style的方法代码详解
2018/06/01 Javascript
Node.js中的child_process模块详解
2018/06/08 Javascript
vue使用ajax获取后台数据进行显示的示例
2018/08/09 Javascript
微信小程序实现多选框全选与取消全选功能示例
2019/05/14 Javascript
老生常谈Python序列化和反序列化
2017/06/28 Python
Python中property函数用法实例分析
2018/06/04 Python
Sanic框架安装与简单入门示例
2018/07/16 Python
Python Django框架单元测试之文件上传测试示例
2019/05/17 Python
nginx+uwsgi+django环境搭建的方法步骤
2019/11/25 Python
CSS3截取字符串实例代码【推荐】
2018/06/07 HTML / CSS
canvas 如何绘制线段的实现方法
2018/07/12 HTML / CSS
岗位竞聘书范文
2014/03/31 职场文书
幼儿生日活动方案
2014/08/27 职场文书
置业顾问岗位职责
2015/02/09 职场文书
初中毕业生自我评价
2015/03/02 职场文书
升学宴学生致辞
2015/09/29 职场文书
2019年最新证婚词精选集!
2019/06/28 职场文书
2019毕业典礼主持词!
2019/07/05 职场文书
CSS list-style-type属性使用方法
2023/05/21 HTML / CSS