Pytorch中accuracy和loss的计算知识点总结


Posted in Python onSeptember 10, 2019

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。

那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;

而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的

label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是

minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。

以上就是本次介绍的全部知识点内容,感谢大家对三水点靠木的支持。

Python 相关文章推荐
使用PYTHON创建XML文档
Mar 01 Python
Python在线运行代码助手
Jul 15 Python
Python实现拷贝多个文件到同一目录的方法
Sep 19 Python
使用Python的Django和layim实现即时通讯的方法
May 25 Python
Python3实现腾讯云OCR识别
Nov 27 Python
django 使用全局搜索功能的实例详解
Jul 18 Python
Python多叉树的构造及取出节点数据(treelib)的方法
Aug 09 Python
简单了解为什么python函数后有多个括号
Dec 19 Python
对Pytorch中Tensor的各种池化操作解析
Jan 03 Python
在python3中实现更新界面
Feb 21 Python
Python selenium如何打包静态网页并下载
Aug 12 Python
MAC平台基于Python Appium环境搭建过程图解
Aug 13 Python
python3.7环境下安装Anaconda的教程图解
Sep 10 #Python
Windows10下 python3.7 安装 facenet的教程
Sep 10 #Python
python 图像处理画一个正弦函数代码实例
Sep 10 #Python
Python操作Mongodb数据库的方法小结
Sep 10 #Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 #Python
Python matplotlib绘制饼状图功能示例
Sep 10 #Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 #Python
You might like
解析php中获取系统信息的方法
2013/06/25 PHP
php中使用key,value,current,next和prev函数遍历数组的方法
2015/03/17 PHP
Laravel自定义 封装便捷返回Json数据格式的引用方法
2019/09/29 PHP
ThinkPHP5与单元测试PHPUnit使用详解
2020/02/23 PHP
拖动Html元素集合 Drag and Drop any item
2006/12/22 Javascript
对 lightbox JS 图片控件进行了一下改造, 使其他支持复杂的图片说明
2010/03/20 Javascript
jquery乱码与contentType属性设置问题解决方案
2013/01/07 Javascript
JS案例分享之金额小写转大写
2014/05/15 Javascript
JavaScript字符串对象toLowerCase方法入门实例(用于把字母转换为小写)
2014/10/17 Javascript
JavaScript事件委托技术实例分析
2015/02/06 Javascript
基于Vue实现页面切换左右滑动效果
2020/06/29 Javascript
jQuery动态添加.active 实现导航效果代码思路详解
2017/08/29 jQuery
基于Swiper实现移动端页面图片轮播效果
2017/12/28 Javascript
用vue快速开发app的脚手架工具
2018/06/11 Javascript
详解使用VUE搭建后台管理系统(vue-cli更新至3.0)
2018/08/22 Javascript
微信小程序外卖选购页实现切换分类与数量加减功能案例
2019/01/15 Javascript
详解如何用webpack4从零开始构建react开发环境
2019/01/27 Javascript
es6函数name属性功能与用法实例分析
2020/04/18 Javascript
微信小程序图片右边加两行文字的代码
2020/04/23 Javascript
vue项目中使用多选框的实例代码
2020/07/22 Javascript
[16:27]DOTA2 HEROS教学视频教你分分钟做大人-艾欧
2014/06/11 DOTA
[05:20]2018DOTA2亚洲邀请赛主赛事第三日战况回顾 LGD率先挺进胜者组决赛
2018/04/06 DOTA
Python编程语言的35个与众不同之处(语言特征和使用技巧)
2014/07/07 Python
python递归打印某个目录的内容(实例讲解)
2017/08/30 Python
python实现录音小程序
2020/10/26 Python
pyqt5实现俄罗斯方块游戏
2019/01/11 Python
django的ORM模型的实现原理
2019/03/04 Python
python获取时间戳的实现示例(10位和13位)
2020/09/23 Python
椰子猫砂:CatSpot
2018/08/27 全球购物
Perfume’s Club法国站:购买香水和化妆品
2019/05/02 全球购物
人力资源部经理的岗位职责
2014/03/04 职场文书
小学生作文评语
2014/04/18 职场文书
2014老师三严三实对照检查材料思想汇报
2014/09/18 职场文书
考研复习计划
2015/01/19 职场文书
超强台风观后感
2015/06/09 职场文书
vue项目proxyTable配置和部署服务器
2022/04/14 Vue.js