Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
用Python抢过年的火车票附源码
Dec 07 Python
python 读取excel文件生成sql文件实例详解
May 12 Python
Python实现分段线性插值
Dec 17 Python
利用Python查看微信共同好友功能的实现代码
Apr 24 Python
让你Python到很爽的加速递归函数的装饰器
May 26 Python
在Python中COM口的调用方法
Jul 03 Python
python实现本地批量ping多个IP的方法示例
Aug 07 Python
Python装饰器使用你可能不知道的几种姿势
Oct 25 Python
pycharm下配置pyqt5的教程(anaconda虚拟环境下+tensorflow)
Mar 25 Python
使用keras2.0 将Merge层改为函数式
May 23 Python
Python3爬虫中关于Ajax分析方法的总结
Jul 10 Python
Pandas加速代码之避免使用for循环
May 30 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
PHP开发实现微信退款功能示例
2017/11/25 PHP
Mootools 1.2教程 Tooltips
2009/09/15 Javascript
JS无限树状列表实现代码
2011/01/11 Javascript
网站内容禁止复制和粘贴、另存为的js代码
2014/02/26 Javascript
window.open()实现post传递参数
2015/03/12 Javascript
JS组件Bootstrap dropdown组件扩展hover事件
2016/04/17 Javascript
jQuery模拟select实现下拉菜单功能
2016/06/20 Javascript
JS动态加载脚本并执行回调操作
2016/08/24 Javascript
巧用canvas
2017/01/21 Javascript
基于es6三点运算符的使用方法(实例讲解)
2017/10/12 Javascript
vue input 输入校验字母数字组合且长度小于30的实现代码
2018/05/16 Javascript
详解vue2.0+axios+mock+axios-mock+adapter实现登陆
2018/07/19 Javascript
浅谈开发eslint规则
2018/10/01 Javascript
Vue插值、表达式、分隔符、指令知识小结
2018/10/12 Javascript
vue 动态表单开发方法案例详解
2019/12/02 Javascript
小程序实现录音功能
2020/09/22 Javascript
Python中计算三角函数之cos()方法的使用简介
2015/05/15 Python
详解Python的Django框架中Manager方法的使用
2015/07/21 Python
Python简单实现socket信息发送与监听功能示例
2018/01/03 Python
Pipenv一键搭建python虚拟环境的方法
2018/05/22 Python
基于scrapy的redis安装和配置方法
2018/06/13 Python
PyTorch 1.0 正式版已经发布了
2018/12/13 Python
python 字典操作提取key,value的方法
2019/06/26 Python
python实现宿舍管理系统
2019/11/22 Python
详解Pycharm出现out of memory的终极解决方法
2020/03/03 Python
Python参数传递实现过程及原理详解
2020/05/14 Python
详解Python中第三方库Faker
2020/09/25 Python
Python机器学习工具scikit-learn的使用笔记
2021/01/28 Python
澳大利亚个性化儿童礼品网站:Bright Star Kids
2019/06/14 全球购物
乌克兰鞋类购物网站:Eobuv.com.ua
2020/11/28 全球购物
最新英语专业学生求职信范文
2013/09/21 职场文书
校园公益广告语
2014/03/13 职场文书
民主评议政风行风整改方案
2014/09/17 职场文书
赵乐秦在党的群众路线教育实践活动总结大会上的讲话稿
2014/10/25 职场文书
Python实现简单的猜单词
2021/06/15 Python
SpringBoot快速入门详解
2021/07/21 Java/Android