解决pytorch 损失函数中输入输出不匹配的问题


Posted in Python onJune 05, 2021

一、pytorch 损失函数中输入输出不匹配问题

File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__  result = self.forward(*input, **kwargs)

File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\modules\loss.py", line 500, in forward reduce=self.reduce)
 
File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\functional.py", line 1514, in binary_cross_entropy_with_logits
 
raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
 
ValueError: Target size (torch.Size([32])) must be the same as input size (torch.Size([32,2]))

原因

input 和 target 尺寸不匹配

解决方案:

将target转为onehot

例如:

one_hot = torch.nn.functional.one_hot(masks, num_classes=args.num_classes)

二、Pytorch遇到权重不匹配的问题

最近,楼主在pytorch微调模型时遇到

size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([2, 2048]).

size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).

这个是因为楼主下载的预训练模型中的全连接层是1000类别的,而楼主本人的类别只有2类,所以会报不匹配的错误

解决方案:

从报错信息可以看出,是fc层的权重参数不匹配,那我们只要不load 这一层的参数就可以了。

net = se_resnet50(num_classes=2)
pretrained_dict = torch.load("./senet/seresnet50-60a8950a85b2b.pkl")
model_dict = net.state_dict()
# 重新制作预训练的权重,主要是减去参数不匹配的层,楼主这边层名为“fc”
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)}
# 更新权重
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置函数Type()函数一个有趣的用法
Feb 18 Python
Python with用法实例
Apr 14 Python
Python中使用hashlib模块处理算法的教程
Apr 28 Python
Python使用jsonpath-rw模块处理Json对象操作示例
Jul 31 Python
解决python线程卡死的问题
Feb 18 Python
python与字符编码问题
May 24 Python
Python小程序 控制鼠标循环点击代码实例
Oct 08 Python
python3连接kafka模块pykafka生产者简单封装代码
Dec 23 Python
基于Numba提高python运行效率过程解析
Mar 02 Python
Python Sqlalchemy如何实现select for update
Oct 12 Python
Python使用Opencv实现边缘检测以及轮廓检测的实现
Dec 31 Python
Python一些基本的图像操作和处理总结
Jun 23 Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
You might like
php 随机数的产生、页面跳转、件读写、文件重命名、switch语句
2009/08/07 PHP
php流量统计功能的实现代码
2012/09/29 PHP
PHP的基本常识小结
2013/07/05 PHP
PHPCrawl爬虫库实现抓取酷狗歌单的方法示例
2017/12/21 PHP
实例讲解php实现多线程
2019/01/27 PHP
mapper--图片热点区域高亮组件官方站点
2007/12/22 Javascript
js多级树形弹出一个小窗口层(非常好用)实例代码
2013/03/19 Javascript
使用Node.js实现一个简单的FastCGI服务器实例
2014/06/09 Javascript
js实现宇宙星空背景效果的方法
2015/03/03 Javascript
JS实现网页上随机产生超链接地址的方法
2015/11/09 Javascript
js图片跟随鼠标移动代码
2015/11/26 Javascript
浅谈javascript:两种注释,声明变量,定义函数
2016/10/05 Javascript
AngularJS指令中的绑定策略实例分析
2016/12/14 Javascript
jquery实现拖动效果(代码分享)
2017/01/25 Javascript
JavaScript学习总结之正则的元字符和一些简单的应用
2017/06/30 Javascript
详解JS数组Reduce()方法详解及高级技巧
2017/08/18 Javascript
vue数字类型过滤器的示例代码
2017/09/07 Javascript
vue移动端轻量级的轮播组件实现代码
2018/07/12 Javascript
vue页面加载时的进度条功能(实例代码)
2020/01/13 Javascript
vue项目打包为APP,静态资源正常显示,但API请求不到数据的操作
2020/09/12 Javascript
js实现简易点击切换显示或隐藏
2020/11/29 Javascript
[02:08]2018年度CS GO枪械皮肤设计大赛优秀作者-完美盛典
2018/12/16 DOTA
[01:27:30]LGD vs Newbee 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/19 DOTA
python生成器generator用法实例分析
2015/06/04 Python
浅谈python内置变量-reversed(seq)
2017/06/21 Python
基于python和flask实现http接口过程解析
2020/06/15 Python
Timberland美国官网:全球领先的户外品牌
2016/08/15 全球购物
Parfumdreams英国:香水和化妆品
2019/05/10 全球购物
一加手机美国官方网站:OnePlus美国
2019/09/19 全球购物
类、抽象类、接口的差异
2016/06/13 面试题
销售总经理岗位职责
2014/03/15 职场文书
应届生求职自荐信范文
2014/04/07 职场文书
学习演讲稿范文
2014/05/10 职场文书
人力资源管理专业毕业生自荐书
2014/05/25 职场文书
最美家庭活动方案
2014/08/31 职场文书
民事诉讼代理授权委托书范本
2014/10/08 职场文书