解决pytorch 损失函数中输入输出不匹配的问题


Posted in Python onJune 05, 2021

一、pytorch 损失函数中输入输出不匹配问题

File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\modules\module.py", line 491, in __call__  result = self.forward(*input, **kwargs)

File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\modules\loss.py", line 500, in forward reduce=self.reduce)
 
File "C:\Users\Rain\AppData\Local\Programs\Python\Anaconda.3.5.1\envs\python35\python35\lib\site-packages\torch\nn\functional.py", line 1514, in binary_cross_entropy_with_logits
 
raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
 
ValueError: Target size (torch.Size([32])) must be the same as input size (torch.Size([32,2]))

原因

input 和 target 尺寸不匹配

解决方案:

将target转为onehot

例如:

one_hot = torch.nn.functional.one_hot(masks, num_classes=args.num_classes)

二、Pytorch遇到权重不匹配的问题

最近,楼主在pytorch微调模型时遇到

size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([2, 2048]).

size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).

这个是因为楼主下载的预训练模型中的全连接层是1000类别的,而楼主本人的类别只有2类,所以会报不匹配的错误

解决方案:

从报错信息可以看出,是fc层的权重参数不匹配,那我们只要不load 这一层的参数就可以了。

net = se_resnet50(num_classes=2)
pretrained_dict = torch.load("./senet/seresnet50-60a8950a85b2b.pkl")
model_dict = net.state_dict()
# 重新制作预训练的权重,主要是减去参数不匹配的层,楼主这边层名为“fc”
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)}
# 更新权重
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的map()函数和reduce()函数的用法
Apr 27 Python
利用Django内置的认证视图实现用户密码重置功能详解
Nov 24 Python
python利用smtplib实现QQ邮箱发送邮件
May 20 Python
pycham查看程序执行的时间方法
Nov 29 Python
python读出当前时间精度到秒的代码
Jul 05 Python
python 批量添加的button 使用同一点击事件的方法
Jul 17 Python
解决在pycharm运行代码,调用CMD窗口的命令运行显示乱码问题
Aug 23 Python
Python英文文章词频统计(14份剑桥真题词频统计)
Oct 13 Python
python自动提取文本中的时间(包含中文日期)
Aug 31 Python
15个应该掌握的Jupyter Notebook使用技巧(小结)
Sep 23 Python
Python3接口性能测试实例代码
Jun 20 Python
详细介绍python操作RabbitMq
Apr 12 Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
You might like
轻松入门: 煮好咖啡的七个诀窍
2021/03/03 冲泡冲煮
二十行语句实现从Excel到mysql的转化
2006/10/09 PHP
php完全过滤HTML,JS,CSS等标签
2009/01/16 PHP
php 函数中使用static的说明
2012/06/01 PHP
php读取文件内容的方法汇总
2015/01/24 PHP
php反射类ReflectionClass用法分析
2016/05/12 PHP
php+html5+ajax实现上传图片的方法
2016/05/14 PHP
jQuery get和post 方法传值注意事项
2009/11/03 Javascript
JS的replace方法介绍
2012/10/20 Javascript
Javascript图像处理思路及实现代码
2012/12/25 Javascript
JQuery EasyUI 加载两次url的原因分析及解决方案
2014/08/18 Javascript
JavaScript中的异常捕捉介绍
2014/12/31 Javascript
JavaScript实现网站访问次数统计代码
2015/08/12 Javascript
javascript判断图片是否加载完成的方法推荐
2016/05/13 Javascript
JavaScript数据结构之二叉树的计数算法示例
2017/04/13 Javascript
react-native DatePicker日期选择组件的实现代码
2017/09/12 Javascript
JS设计模式之状态模式概念与用法分析
2018/02/05 Javascript
JavaScript设计模式---单例模式详解【四种基本形式】
2020/05/16 Javascript
基于JavaScript或jQuery实现网站夜间/高亮模式
2020/05/30 jQuery
详解Python核心编程中的浅拷贝与深拷贝
2018/01/07 Python
Python实现自动上京东抢手机
2018/02/06 Python
Python绘制KS曲线的实现方法
2018/08/13 Python
python 计算两个列表的相关系数的实现
2019/08/29 Python
python利用dlib获取人脸的68个landmark
2019/11/27 Python
python标识符命名规范原理解析
2020/01/10 Python
css3弹性盒模型实例介绍
2013/05/27 HTML / CSS
丝芙兰巴西官方商城:SEPHORA巴西
2016/10/31 全球购物
中国跨境电商:Tomtop
2017/03/16 全球购物
俄罗斯运动鞋商店:Sneakerhead
2018/05/10 全球购物
Linden Leaves官网:新西兰纯净护肤品
2020/12/20 全球购物
技术总监的工作职责
2013/11/13 职场文书
企业委托书范本
2014/09/13 职场文书
MySQL 8.0 之不可见列的基本操作
2021/05/20 MySQL
基于angular实现树形二级表格
2021/10/16 Javascript
postman中form-data、x-www-form-urlencoded、raw、binary的区别介绍
2022/01/18 HTML / CSS
全网非常详细的pytest配置文件
2022/07/15 Python