Pytorch中accuracy和loss的计算知识点总结


Posted in Python onSeptember 10, 2019

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。

那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;

而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的

label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是

minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。

以上就是本次介绍的全部知识点内容,感谢大家对三水点靠木的支持。

Python 相关文章推荐
ubuntu系统下 python链接mysql数据库的方法
Jan 09 Python
利用python将xml文件解析成html文件的实现方法
Dec 22 Python
python语言中with as的用法使用详解
Feb 23 Python
python去掉空白行的多种实现代码
Mar 19 Python
对python中的logger模块全面讲解
Apr 28 Python
python修改txt文件中的某一项方法
Dec 29 Python
在pandas中遍历DataFrame行的实现方法
Oct 23 Python
python中rb含义理解
Jun 18 Python
读取nii或nii.gz文件中的信息即输出图像操作
Jul 01 Python
python matlab库简单用法讲解
Dec 31 Python
python 实现IP子网计算
Feb 18 Python
Python中OpenCV实现简单车牌字符切割
Jun 11 Python
python3.7环境下安装Anaconda的教程图解
Sep 10 #Python
Windows10下 python3.7 安装 facenet的教程
Sep 10 #Python
python 图像处理画一个正弦函数代码实例
Sep 10 #Python
Python操作Mongodb数据库的方法小结
Sep 10 #Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 #Python
Python matplotlib绘制饼状图功能示例
Sep 10 #Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 #Python
You might like
php解析xml提示Invalid byte 1 of 1-byte UTF-8 sequence错误的处理方法
2013/11/14 PHP
php简单定时执行任务的实现方法
2015/02/23 PHP
php版微信公众账号第三方管理工具开发简明教程
2016/09/23 PHP
php实现不通过扩展名准确判断文件类型的方法【finfo_file方法与二进制流】
2017/04/18 PHP
PHP操作Postgresql封装类与应用完整实例
2018/04/24 PHP
ThinkPHP5 框架引入 Go AOP,PHP AOP编程项目详解
2020/05/12 PHP
javascript setTimeout和setInterval计时的区别详解
2013/06/21 Javascript
使用js在页面中绘制表格核心代码
2013/09/16 Javascript
Javascript添加监听与删除监听用法详解
2014/12/19 Javascript
JSON取值前判断
2014/12/23 Javascript
javascript通过获取html标签属性class实现多选项卡的方法
2015/07/27 Javascript
用js实现每隔一秒刷新时间的实例(含年月日时分秒)
2017/10/25 Javascript
浅谈Node 调试工具入门教程
2018/03/20 Javascript
vuex中的 mapState,mapGetters,mapActions,mapMutations 的使用
2018/04/13 Javascript
JavaScript设计模式之观察者模式(发布订阅模式)原理与实现方法示例
2018/07/27 Javascript
js删除数组中某几项的方法总结
2019/01/16 Javascript
简单了解JavaScript中的执行上下文和堆栈
2019/06/24 Javascript
react MPA 多页配置详解
2019/10/18 Javascript
vue 项目软键盘回车触发搜索事件
2020/09/09 Javascript
vue 自定指令生成uuid滚动监听达到tab表格吸顶效果的代码
2020/09/16 Javascript
微信小程序将页面按钮悬浮固定在底部的实现代码
2020/10/29 Javascript
微信小程序中target和currentTarget的区别小结
2020/11/06 Javascript
Python自动化运维之IP地址处理模块详解
2017/12/10 Python
django反向解析URL和URL命名空间的方法
2018/06/05 Python
Django框架多表查询实例分析
2018/07/04 Python
详解Django中间件的5种自定义方法
2018/07/26 Python
我们为什么要减少Python中循环的使用
2019/07/10 Python
python+Django实现防止SQL注入的办法
2019/10/31 Python
美国著名手表网站:Timepiece
2017/11/15 全球购物
Travelstart沙特阿拉伯:廉价航班、豪华酒店和实惠的汽车租赁优惠
2019/04/06 全球购物
意大利在线高尔夫商店:Online Golf
2021/03/09 全球购物
应届生简历中的自我评价
2014/01/13 职场文书
《难忘的泼水节》教学反思
2014/02/27 职场文书
教师党员先进性教育自我剖析材料思想汇报
2014/09/24 职场文书
2019森林防火宣传标语大全!
2019/07/03 职场文书
python自动化之如何利用allure生成测试报告
2021/05/02 Python