Pytorch中accuracy和loss的计算知识点总结


Posted in Python onSeptember 10, 2019

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。

那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;

而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的

label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是

minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。

以上就是本次介绍的全部知识点内容,感谢大家对三水点靠木的支持。

Python 相关文章推荐
Python中列表和元组的相关语句和方法讲解
Aug 20 Python
使用Python的Twisted框架编写非阻塞程序的代码示例
May 25 Python
解决PyCharm中光标变粗的问题
Aug 05 Python
python实现自动网页截图并裁剪图片
Jul 30 Python
Python Flask前后端Ajax交互的方法示例
Jul 31 Python
Python面向对象程序设计OOP深入分析【构造函数,组合类,工具类等】
Jan 05 Python
利用python实现对web服务器的目录探测的方法
Feb 26 Python
深入学习python多线程与GIL
Aug 26 Python
python deque模块简单使用代码实例
Mar 12 Python
Python如何转换字符串大小写
Jun 04 Python
关于python3.9安装wordcloud出错的问题及解决办法
Nov 02 Python
使用Python爬取Json数据的示例代码
Dec 07 Python
python3.7环境下安装Anaconda的教程图解
Sep 10 #Python
Windows10下 python3.7 安装 facenet的教程
Sep 10 #Python
python 图像处理画一个正弦函数代码实例
Sep 10 #Python
Python操作Mongodb数据库的方法小结
Sep 10 #Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 #Python
Python matplotlib绘制饼状图功能示例
Sep 10 #Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 #Python
You might like
php实现mysql同步的实现方法
2009/10/21 PHP
PHP XML数据解析代码
2010/05/26 PHP
IIS6.0 开启Gzip方法及PHP Gzip函数分享
2014/06/08 PHP
Ajax PHP JavaScript MySQL实现简易无刷新在线聊天室
2016/08/17 PHP
PHP中利用sleep函数实现定时执行功能实现代码
2016/08/25 PHP
php对接java现实加签验签的实例
2016/11/25 PHP
PHP在同一域名下两个不同的项目做独立登录机制详解
2017/09/22 PHP
Mootools 1.2教程(21)——类(二)
2009/09/15 Javascript
js对象之JS入门之Array对象操作小结
2011/01/09 Javascript
Jquery+CSS3实现一款简洁大气带滑动效果的弹出层
2013/05/15 Javascript
高效的获取当前元素是父元素的第几个子元素
2013/10/15 Javascript
类似php的js数组的in_array函数自定义方法
2013/12/27 Javascript
js实现DOM走马灯特效的方法
2015/01/21 Javascript
小心!AngularJS结合RequireJS做文件合并压缩的那些坑
2016/01/09 Javascript
jQuery on()绑定动态元素出现的问题小结
2016/02/19 Javascript
实例讲解DataTables固定表格宽度(设置横向滚动条)
2017/07/11 Javascript
详解nodeJs文件系统(fs)与流(stream)
2018/01/24 NodeJs
JavaScript面向对象程序设计中对象的定义和继承详解
2019/07/29 Javascript
three.js显示中文字体与tween应用详析
2021/01/04 Javascript
jQuery实现购物车全功能
2021/01/11 jQuery
[01:03:59]2018DOTA2亚洲邀请赛3月30日 小组赛B组VGJ.T VS Secret
2018/03/31 DOTA
[01:16:12]完美世界DOTA2联赛PWL S2 FTD vs Inki 第一场 11.21
2020/11/23 DOTA
python中去空格函数的用法
2014/08/21 Python
pandas DataFrame数据转为list的方法
2018/04/11 Python
在Mac上删除自己安装的Python方法
2018/10/29 Python
在Python中增加和插入元素的示例
2018/11/01 Python
基于python的BP神经网络及异或实现过程解析
2019/09/30 Python
tensorflow 获取checkpoint中的变量列表实例
2020/02/11 Python
详解移动端h5页面根据屏幕适配的四种方案
2020/04/15 HTML / CSS
AmazeUI 导航条的实现示例
2020/08/14 HTML / CSS
AmazeUI框架搭建的方法步骤(图文)
2020/08/17 HTML / CSS
安全伴我行演讲稿
2014/09/04 职场文书
孩子满月酒答谢词
2015/09/30 职场文书
2016护理专业求职自荐书
2016/01/28 职场文书
你会写请假条吗?
2019/06/26 职场文书
MySQL解决Navicat设置默认字符串时的报错问题
2022/06/16 MySQL