Pytorch中accuracy和loss的计算知识点总结


Posted in Python onSeptember 10, 2019

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。

那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;

而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的

label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是

minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。

以上就是本次介绍的全部知识点内容,感谢大家对三水点靠木的支持。

Python 相关文章推荐
Python实现复杂对象转JSON的方法示例
Jun 22 Python
Python实现简单的语音识别系统
Dec 13 Python
Python实现控制台中的进度条功能代码
Dec 22 Python
python 平衡二叉树实现代码示例
Jul 07 Python
Python通用循环的构造方法实例分析
Dec 19 Python
基于python if 判断选择结构的实例详解
May 06 Python
Flask框架路由和视图用法实例分析
Nov 07 Python
Python使用plt.boxplot() 参数绘制箱线图
Jun 04 Python
如何利用python进行时间序列分析
Aug 04 Python
python如何运行js语句
Sep 09 Python
使用python画出逻辑斯蒂映射(logistic map)中的分叉图案例
Dec 11 Python
Python上下文管理器Content Manager
Jun 26 Python
python3.7环境下安装Anaconda的教程图解
Sep 10 #Python
Windows10下 python3.7 安装 facenet的教程
Sep 10 #Python
python 图像处理画一个正弦函数代码实例
Sep 10 #Python
Python操作Mongodb数据库的方法小结
Sep 10 #Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 #Python
Python matplotlib绘制饼状图功能示例
Sep 10 #Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 #Python
You might like
非常不错的MySQL优化的8条经验
2008/03/24 PHP
PHP下打开URL地址的几种方法小结
2010/05/16 PHP
ThinkPHP基于PHPExcel导入Excel文件的方法
2014/10/15 PHP
PHP读取大文件的多种方法介绍
2016/04/04 PHP
PHP 中使用explode()函数切割字符串为数组的示例
2017/05/06 PHP
Javascript里使用Dom操作Xml
2006/09/20 Javascript
Javascript实例教程(19) 使用HoTMetal(4)
2006/12/23 Javascript
js资料prototype 属性
2007/03/13 Javascript
基于jquery的页面划词搜索JS
2010/09/14 Javascript
js 点击页面其他地方关闭弹出层(示例代码)
2013/12/24 Javascript
Node.js和PHP根据ip获取地理位置的方法
2014/03/14 Javascript
jquery通过closest选择器修改上级元素的方法
2015/03/17 Javascript
初步认识JavaScript函数库jQuery
2015/06/18 Javascript
vue组件实现可搜索下拉框扩展
2020/10/23 Javascript
ES6的Fetch异步请求的实现方法
2018/12/07 Javascript
微信小程序如何修改本地缓存key中单个数据的详解
2019/04/26 Javascript
JS实现吸顶特效
2020/01/08 Javascript
Python自定义函数的创建、调用和函数的参数详解
2014/03/11 Python
使用python读取csv文件快速插入数据库的实例
2018/06/21 Python
python 生成器和迭代器的原理解析
2019/10/12 Python
Python中内建模块collections如何使用
2020/05/27 Python
PyTorch-GPU加速实例
2020/06/23 Python
Python中的With语句的使用及原理
2020/07/29 Python
python pip如何手动安装二进制包
2020/09/30 Python
探讨HTML5移动开发的几大特性(必看)
2015/12/30 HTML / CSS
TOWER London官网:鞋子、靴子、运动鞋等
2019/07/14 全球购物
金额转换,阿拉伯数字的金额转换成中国传统的形式如:(¥1011)-> (一千零一拾一元整)输出
2015/05/29 面试题
保密工作实施方案
2014/02/24 职场文书
2014年学校国庆主题活动方案
2014/09/16 职场文书
2016年国庆节67周年活动总结
2016/04/01 职场文书
2019商业计划书格式、范文
2019/04/24 职场文书
nginx里的rewrite跳转的实现
2021/03/31 Servers
PyQt5 显示超清高分辨率图片的方法
2021/04/11 Python
在Python中如何使用yield
2021/06/07 Python
【2·13】一图读懂中国无线电发展
2022/02/18 无线电
nginx常用配置conf的示例代码详解
2022/03/21 Servers