Pytorch中accuracy和loss的计算知识点总结


Posted in Python onSeptember 10, 2019

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。

那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;

而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的

label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是

minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。

以上就是本次介绍的全部知识点内容,感谢大家对三水点靠木的支持。

Python 相关文章推荐
Python中列表和元组的相关语句和方法讲解
Aug 20 Python
详解Python中的array数组模块相关使用
Jul 05 Python
在python3中pyqt5和mayavi不兼容问题的解决方法
Jan 08 Python
Python中类的创建和实例化操作示例
Feb 27 Python
详解python中list的使用
Mar 15 Python
详解python解压压缩包的五种方法
Jul 05 Python
pip 安装库比较慢的解决方法(国内镜像)
Oct 06 Python
Python中base64与xml取值结合问题
Dec 22 Python
nginx搭建基于python的web环境的实现步骤
Jan 03 Python
Django实现将一个字典传到前端显示出来
Apr 03 Python
Python 如何利用ffmpeg 处理视频素材
Nov 27 Python
python3.7环境下安装Anaconda的教程图解
Sep 10 #Python
Windows10下 python3.7 安装 facenet的教程
Sep 10 #Python
python 图像处理画一个正弦函数代码实例
Sep 10 #Python
Python操作Mongodb数据库的方法小结
Sep 10 #Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 #Python
Python matplotlib绘制饼状图功能示例
Sep 10 #Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 #Python
You might like
PHP采集腾讯微博的实现代码
2012/01/19 PHP
利用PHP脚本在Linux下用md5函数加密字符串的方法
2015/06/29 PHP
thinkPHP实现多字段模糊匹配查询的方法
2016/12/01 PHP
php生成0~1随机小数的方法(必看)
2017/04/05 PHP
使用jQuery和PHP实现类似360功能开关效果
2014/02/12 Javascript
javascript回车完美实现tab切换功能
2014/03/13 Javascript
jquery通过closest选择器修改上级元素的方法
2015/03/17 Javascript
使用JavaScript脚本无法直接改变Asp.net中Checkbox控件的Enable属性的解决方法
2015/09/16 Javascript
详解jQuery向动态生成的内容添加事件响应jQuery live()方法
2015/11/02 Javascript
JavaScript函数学习总结以及相关的编程习惯指南
2015/11/16 Javascript
JS获取url参数、主域名的方法实例分析
2016/08/03 Javascript
JS原型链 详解及示例代码
2016/09/06 Javascript
基于jQuery封装的分页组件
2017/06/26 jQuery
详解weex默认webpack.config.js改造
2018/01/08 Javascript
Vuejs 单文件组件实例详解
2018/02/09 Javascript
angular4 JavaScript内存溢出问题
2018/03/06 Javascript
常用的 JS 排序算法 整理版
2018/04/05 Javascript
Vue.js实现的表格增加删除demo示例
2018/05/22 Javascript
使用json-server简单完成CRUD模拟后台数据的方法
2018/07/12 Javascript
bootstrap-treeview实现多级树形菜单 后台JSON格式如何组织?
2019/07/26 Javascript
VUE:vuex 用户登录信息的数据写入与获取方式
2019/11/11 Javascript
详解Vue template 如何支持多个根结点
2020/02/10 Javascript
Python迭代用法实例教程
2014/09/08 Python
Python实现爬取需要登录的网站完整示例
2017/08/19 Python
JSONLINT:python的json数据验证库实例解析
2017/11/28 Python
用Python将mysql数据导出成json的方法
2018/08/21 Python
Python在OpenCV里实现极坐标变换功能
2019/09/02 Python
使用python代码进行身份证号校验的实现示例
2019/11/21 Python
python实现随机加减法生成器
2020/02/24 Python
python 一维二维插值实例
2020/04/22 Python
Python中的None与 NULL(即空字符)的区别详解
2020/09/24 Python
Html5实现二维码扫描并解析
2016/01/20 HTML / CSS
基于html5绘制圆形多角图案
2016/04/21 HTML / CSS
黄河象教学反思
2014/02/10 职场文书
《东方明珠》教学反思
2014/04/20 职场文书
mysql中DCL常用的用户和权限控制
2022/03/31 MySQL