Pytorch中accuracy和loss的计算知识点总结


Posted in Python onSeptember 10, 2019

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。

那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;

而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的

label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是

minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。

以上就是本次介绍的全部知识点内容,感谢大家对三水点靠木的支持。

Python 相关文章推荐
python实现将内容分行输出
Nov 05 Python
让python在hadoop上跑起来
Jan 27 Python
Python合并字典键值并去除重复元素的实例
Dec 18 Python
不要用强制方法杀掉python线程
Feb 26 Python
python logging日志模块以及多进程日志详解
Apr 18 Python
对python操作kafka写入json数据的简单demo分享
Dec 27 Python
numpy linalg模块的具体使用方法
May 26 Python
使用python打印十行杨辉三角过程详解
Jul 10 Python
pandas如何处理缺失值
Jul 31 Python
关于tf.reverse_sequence()简述
Jan 20 Python
pandas 数据类型转换的实现
Dec 29 Python
jupyter 添加不同内核的操作
Feb 06 Python
python3.7环境下安装Anaconda的教程图解
Sep 10 #Python
Windows10下 python3.7 安装 facenet的教程
Sep 10 #Python
python 图像处理画一个正弦函数代码实例
Sep 10 #Python
Python操作Mongodb数据库的方法小结
Sep 10 #Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 #Python
Python matplotlib绘制饼状图功能示例
Sep 10 #Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 #Python
You might like
PHP判断json格式是否正确的实现代码
2017/09/20 PHP
javascript setTimeout()传递函数参数(包括传递对象参数)
2010/04/07 Javascript
JQuery 文本框使用小结
2010/05/22 Javascript
JavaScript 注册事件代码
2011/01/27 Javascript
jquery+ajax+C#实现无刷新操作数据库数据的简单实例
2014/02/08 Javascript
浅谈JavaScript Array对象
2014/12/29 Javascript
JavaScript获取页面中表单(form)数量的方法
2015/04/03 Javascript
关于延迟加载JavaScript
2015/05/05 Javascript
深入理解MVC中的时间js格式化
2016/05/19 Javascript
浅谈JS中的bind方法与函数柯里化
2016/08/10 Javascript
微信小程序  audio音频播放详解及实例
2016/11/02 Javascript
Bootstrap table使用方法详细介绍
2016/12/09 Javascript
浅谈jQuery中的$.extend方法来扩展JSON对象
2017/02/12 Javascript
JS正则获取HTML元素的方法
2017/03/31 Javascript
Vue生命周期示例详解
2017/04/12 Javascript
基于JavaScript实现的希尔排序算法分析
2017/04/14 Javascript
js限制输入框只能输入数字(onkeyup触发)
2018/09/28 Javascript
JS基于ES6新特性async await进行异步处理操作示例
2019/02/02 Javascript
Vue+ElementUI项目使用webpack输出MPA的方法
2019/08/27 Javascript
实用的 vue tags 创建缓存导航的过程实现
2020/12/03 Vue.js
绘制微信小程序验证码功能的实例代码
2021/01/05 Javascript
Python3 max()函数基础用法
2019/02/19 Python
python实现远程控制电脑
2019/05/23 Python
keras 自定义loss model.add_loss的使用详解
2020/06/22 Python
Python requests接口测试实现代码
2020/09/08 Python
Python通用唯一标识符uuid模块使用案例
2020/09/10 Python
为世界各地的女性设计和生产时尚服装:ROMWE
2016/09/17 全球购物
服装设计专业毕业生推荐信
2013/11/09 职场文书
公务员总结性个人自我评价
2013/12/05 职场文书
研究生导师评语
2014/12/31 职场文书
教师工作能力自我评价
2015/03/04 职场文书
社区禁毒宣传活动总结
2015/05/07 职场文书
导游词之泉州崇武古城
2019/12/20 职场文书
Unity连接MySQL并读取表格数据的实现代码
2021/06/20 MySQL
vue postcss-px2rem 自适应布局
2022/05/15 Vue.js
Python可视化神器pyecharts绘制地理图表
2022/07/07 Python