Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python输出PowerPoint(ppt)文件中全部文字信息的方法
Apr 28 Python
深入讲解Python中的迭代器和生成器
Oct 26 Python
Python模拟百度登录实例详解
Jan 20 Python
python 匹配url中是否存在IP地址的方法
Jun 04 Python
详解Django 时间与时区设置问题
Jul 23 Python
python re模块匹配贪婪和非贪婪模式详解
Feb 11 Python
Python cookie的保存与读取、SSL讲解
Feb 17 Python
OpenCV+python实现实时目标检测功能
Jun 24 Python
Python持续监听文件变化代码实例
Jul 22 Python
python 线程的五个状态
Sep 22 Python
Anaconda使用IDLE的实现示例
Sep 23 Python
pycharm 如何查看某一函数源码的快捷键
May 12 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
PHP常用代码大全(新手入门必备)
2010/06/29 PHP
php上传图片到指定位置路径保存到数据库的具体实现
2013/12/30 PHP
zf框架的校验器InArray使用示例
2014/03/13 PHP
兼容ie6浏览器的php下载文件代码分享
2014/07/14 PHP
php使用wordwrap格式化文本段落的方法
2015/03/17 PHP
PHP实现冒泡排序的简单实例
2016/05/26 PHP
Yii2中使用asset压缩js,css文件的方法
2016/11/24 PHP
PHP预定义接口――Iterator用法示例
2020/06/05 PHP
JQuery获取元素文档大小、偏移和位置和滚动条位置的方法集合
2010/01/12 Javascript
前端开发部分总结[兼容性、DOM操作、跨域等](持续更新)
2010/03/04 Javascript
JavaScript判断DOM何时加载完毕的技巧
2012/11/11 Javascript
javascript的document.referrer浏览器支持、失效情况总结
2014/07/18 Javascript
jquery 取子节点及当前节点属性值的方法
2014/08/24 Javascript
javascript HTML+CSS实现经典橙色导航菜单
2016/02/16 Javascript
浅析vue数据绑定
2017/01/17 Javascript
jQuery实用密码强度检测
2017/03/02 Javascript
protractor的安装与基本使用教程
2017/07/07 Javascript
vue一个页面实现音乐播放器的示例
2018/02/06 Javascript
BootStrap中的模态框(modal,弹出层)功能示例代码
2018/11/02 Javascript
微信小程序云开发之使用云存储
2019/05/17 Javascript
js实现掷骰子小游戏
2019/10/24 Javascript
[46:57]EG vs Winstrike 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
Python中列表(list)操作方法汇总
2014/08/18 Python
wxpython实现图书管理系统
2018/03/12 Python
python 内置函数汇总详解
2019/09/16 Python
浅谈Python type的使用
2019/11/19 Python
面试求职的个人自我评价
2013/11/16 职场文书
25岁生日感言
2014/01/13 职场文书
简历上的自我评价怎么写
2014/01/28 职场文书
小学校园活动策划
2014/01/30 职场文书
2014年元旦感言
2014/03/06 职场文书
服务行业口号
2014/06/11 职场文书
人代会标语
2014/06/30 职场文书
交通事故和解协议书
2014/09/25 职场文书
幸福终点站观后感
2015/06/04 职场文书
MySQL学习之基础命令实操总结
2022/03/19 MySQL