解决pytorch 损失函数中输入输出不匹配的问题


Posted in Python onJune 05, 2021

一、pytorch 损失函数中输入输出不匹配问题

File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__  result = self.forward(*input, **kwargs)

File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\modules\loss.py", line 500, in forward reduce=self.reduce)
 
File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\functional.py", line 1514, in binary_cross_entropy_with_logits
 
raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
 
ValueError: Target size (torch.Size([32])) must be the same as input size (torch.Size([32,2]))

原因

input 和 target 尺寸不匹配

解决方案:

将target转为onehot

例如:

one_hot = torch.nn.functional.one_hot(masks, num_classes=args.num_classes)

二、Pytorch遇到权重不匹配的问题

最近,楼主在pytorch微调模型时遇到

size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([2, 2048]).

size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).

这个是因为楼主下载的预训练模型中的全连接层是1000类别的,而楼主本人的类别只有2类,所以会报不匹配的错误

解决方案:

从报错信息可以看出,是fc层的权重参数不匹配,那我们只要不load 这一层的参数就可以了。

net = se_resnet50(num_classes=2)
pretrained_dict = torch.load("./senet/seresnet50-60a8950a85b2b.pkl")
model_dict = net.state_dict()
# 重新制作预训练的权重,主要是减去参数不匹配的层,楼主这边层名为“fc”
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)}
# 更新权重
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 用户登录验证的小例子
Mar 06 Python
python3使用urllib示例取googletranslate(谷歌翻译)
Jan 23 Python
Python中编写ORM框架的入门指引
Apr 29 Python
详解Python中的strftime()方法的使用
May 22 Python
利用python微信库itchat实现微信自动回复功能
May 18 Python
Python使用arrow库优雅地处理时间数据详解
Oct 10 Python
python下载文件记录黑名单的实现代码
Oct 24 Python
使用 Python 实现简单的 switch/case 语句的方法
Sep 17 Python
python开发准备工作之配置虚拟环境(非常重要)
Feb 11 Python
Python利用sqlacodegen自动生成ORM实体类示例
Jun 04 Python
python或C++读取指定文件夹下的所有图片
Aug 31 Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
You might like
PHP代码优化之成员变量获取速度对比
2014/02/28 PHP
php取出数组单个值的方法
2018/03/12 PHP
Thinkphp整合阿里云OSS图片上传实例代码
2019/04/28 PHP
javascript基于jQuery的表格悬停变色/恢复,表格点击变色/恢复,点击行选Checkbox
2008/08/05 Javascript
利用谷歌地图API获取点与点的距离的js代码
2012/10/11 Javascript
js事件冒泡实例分享(已测试)
2013/04/23 Javascript
Jquery 复选框取值兼容FF和IE8(测试有效)
2013/10/29 Javascript
基于jQuery Circlr插件实现产品图片360度旋转
2015/09/20 Javascript
jQuery实现智能判断固定导航条或侧边栏的方法
2016/09/04 Javascript
canvas实现钟表效果
2017/02/13 Javascript
Angular中实现树形结构视图实例代码
2017/05/05 Javascript
jQuery.Sumoselect插件实现下拉复选框效果
2017/11/09 jQuery
vue-cli项目中使用Mockjs详解
2018/05/14 Javascript
使用VueRouter的addRoutes方法实现动态添加用户的权限路由
2019/06/03 Javascript
layui多图上传实现删除功能的例子
2019/09/23 Javascript
小程序api实现promise封装过程解析
2019/11/21 Javascript
基于JavaScript实现十五拼图代码实例
2020/04/26 Javascript
Python 爬虫学习笔记之单线程爬虫
2016/09/21 Python
python一键升级所有pip package的方法
2017/01/16 Python
Python 3.x读写csv文件中数字的方法示例
2017/08/29 Python
浅谈python迭代器
2017/11/08 Python
python实现数独游戏 java简单实现数独游戏
2018/03/30 Python
如何优雅地改进Django中的模板碎片缓存详解
2018/07/04 Python
Python中几种属性访问的区别与用法详解
2018/10/10 Python
python Jupyter运行时间实例过程解析
2019/12/13 Python
美国祛痘、抗衰老药妆品牌:Murad
2016/08/27 全球购物
文秘专业应届生求职信
2014/05/26 职场文书
2014教师教育实践活动对照检查材料思想汇报
2014/09/21 职场文书
专题民主生活会对照检查材料思想汇报
2014/09/29 职场文书
2014保险公司个人工作总结
2014/12/09 职场文书
小学二年级数学教学计划
2015/01/20 职场文书
教师廉洁自律个人总结
2015/02/10 职场文书
毕业生对母校寄语
2015/02/26 职场文书
2015毕业寄语大全
2015/02/26 职场文书
女性健康知识讲座主持词
2015/07/04 职场文书
新郎新娘致辞
2015/07/31 职场文书