解决pytorch 损失函数中输入输出不匹配的问题


Posted in Python onJune 05, 2021

一、pytorch 损失函数中输入输出不匹配问题

File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__  result = self.forward(*input, **kwargs)

File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\modules\loss.py", line 500, in forward reduce=self.reduce)
 
File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\functional.py", line 1514, in binary_cross_entropy_with_logits
 
raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
 
ValueError: Target size (torch.Size([32])) must be the same as input size (torch.Size([32,2]))

原因

input 和 target 尺寸不匹配

解决方案:

将target转为onehot

例如:

one_hot = torch.nn.functional.one_hot(masks, num_classes=args.num_classes)

二、Pytorch遇到权重不匹配的问题

最近,楼主在pytorch微调模型时遇到

size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([2, 2048]).

size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).

这个是因为楼主下载的预训练模型中的全连接层是1000类别的,而楼主本人的类别只有2类,所以会报不匹配的错误

解决方案:

从报错信息可以看出,是fc层的权重参数不匹配,那我们只要不load 这一层的参数就可以了。

net = se_resnet50(num_classes=2)
pretrained_dict = torch.load("./senet/seresnet50-60a8950a85b2b.pkl")
model_dict = net.state_dict()
# 重新制作预训练的权重,主要是减去参数不匹配的层,楼主这边层名为“fc”
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)}
# 更新权重
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中pycurl库的用法实例
Sep 30 Python
Python中一些自然语言工具的使用的入门教程
Apr 13 Python
归纳整理Python中的控制流语句的知识点
Apr 14 Python
python解决pandas处理缺失值为空字符串的问题
Apr 08 Python
python3监控CentOS磁盘空间脚本
Jun 21 Python
python一行sql太长折成多行并且有多个参数的方法
Jul 19 Python
浅谈django orm 优化
Aug 18 Python
python画图把时间作为横坐标的方法
Jul 07 Python
Python时间差中seconds和total_seconds的区别详解
Dec 26 Python
pytorch 实现将自己的图片数据处理成可以训练的图片类型
Jan 08 Python
Tensorflow获取张量Tensor的具体维数实例
Jan 19 Python
分享Python获取本机IP地址的几种方法
Mar 17 Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
You might like
PHP删除HTMl标签的三种解决方法
2013/06/30 PHP
PHP Imagick完美实现图片裁切、生成缩略图、添加水印
2016/02/22 PHP
总结一些js自定义的函数
2006/08/05 Javascript
setInterval 和 setTimeout会产生内存溢出
2008/02/15 Javascript
面向对象的Javascript之三(封装和信息隐藏)
2012/01/27 Javascript
js/jQuery对象互转(快速操作dom元素)
2013/02/04 Javascript
javascript图片相似度算法实现 js实现直方图和向量算法
2014/01/14 Javascript
纯js实现div内图片自适应大小(已测试,兼容火狐)
2014/06/16 Javascript
使用jQuery仿苹果官网焦点图特效
2014/12/23 Javascript
JS实现的系统调色板完整实例
2016/12/21 Javascript
jquery横向纵向鼠标滚轮全屏切换
2017/02/27 Javascript
深入理解vue Render函数
2017/07/19 Javascript
vue利用axios来完成数据的交互
2018/03/23 Javascript
vue里面使用mui的弹出日期选择插件实例
2018/09/16 Javascript
一些可能会用到的Node.js面试题
2019/06/15 Javascript
python使用cookielib库示例分享
2014/03/03 Python
python标准算法实现数组全排列的方法
2015/03/17 Python
Python实现的圆形绘制(画圆)示例
2018/01/31 Python
python 解决flask uwsgi 获取不到全局变量的问题
2019/12/22 Python
python使用dlib进行人脸检测和关键点的示例
2020/12/05 Python
Python爬虫制作翻译程序的示例代码
2021/02/22 Python
html5仿支付宝密码框的实现代码
2017/09/06 HTML / CSS
美国知名的女性服饰品牌:LOFT(洛芙特)
2016/08/05 全球购物
优衣库澳大利亚官网:UNIQLO澳大利亚
2017/01/18 全球购物
新媒传信软件测试面试题
2013/02/24 面试题
四年级科学教学反思
2014/02/10 职场文书
旅游市场营销方案
2014/03/09 职场文书
广告艺术设计专业自荐书
2014/07/08 职场文书
贯彻落实“八项规定”思想汇报
2014/09/13 职场文书
2015年综治维稳工作总结
2015/04/07 职场文书
党员“一帮一”活动总结
2015/05/07 职场文书
妈妈别哭观后感
2015/06/08 职场文书
敬老院活动感想
2015/08/07 职场文书
2016春季幼儿园开学寄语
2015/12/03 职场文书
子女赡养老人协议书
2016/03/23 职场文书
IIS服务器中设置HTTP重定向访问HTTPS
2022/04/29 Servers