Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现从一组颜色中找出与给定颜色最接近颜色的方法
Mar 19 Python
python在Windows下安装setuptools(easy_install工具)步骤详解
Jul 01 Python
Python多继承原理与用法示例
Aug 23 Python
在Python dataframe中出生日期转化为年龄的实现方法
Oct 20 Python
Python网页正文转换语音文件的操作方法
Dec 09 Python
Python实现使用request模块下载图片demo示例
May 24 Python
python创建子类的方法分析
Nov 28 Python
使用遗传算法求二元函数的最小值
Feb 11 Python
Python 列表反转显示的四种方法
Nov 16 Python
[原创]赚疯了!转手立赚800+?大佬的python「抢茅台脚本」使用教程
Jan 12 Python
python3中celery异步框架简单使用+守护进程方式启动
Jan 20 Python
python3+PyQt5+Qt Designer实现界面可视化
Jun 10 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
windows下PHP APACHE MYSQ完整配置
2007/01/02 PHP
PHPExcel读取Excel文件的实现代码
2011/12/06 PHP
学习使用curl采集curl使用方法
2012/01/11 PHP
php中str_pad()函数用法分析
2017/03/28 PHP
php设计模式之观察者模式定义与用法经典示例
2019/09/19 PHP
自己动手制作jquery插件之自动添加删除行的实现
2011/10/13 Javascript
javascript 函数及作用域总结介绍
2013/11/12 Javascript
iframe窗口高度自适应的实现方法
2014/01/08 Javascript
JavaScript sub方法入门实例(把字符串显示为下标)
2014/10/17 Javascript
ajax在兼容模式下失效的快速解决方法
2016/03/22 Javascript
BootStrap智能表单实战系列(十一)级联下拉的支持
2016/06/13 Javascript
简单的js计算器实现
2016/10/26 Javascript
Bootstrap表单控件使用方法详解
2017/01/11 Javascript
使用Promise链式调用解决多个异步回调的问题
2017/01/15 Javascript
ES6学习笔记之字符串、数组、对象、函数新增知识点实例分析
2020/01/22 Javascript
VSCode写vue项目一键生成.vue模版,修改定义其他模板的方法
2020/04/17 Javascript
[02:32]DOTA2英雄基础教程 美杜莎
2014/01/07 DOTA
[08:53]DOTA2每周TOP10 精彩击杀集锦vol.9
2014/06/26 DOTA
[37:02]OG vs INfamous 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
深入理解 Python 中的多线程 新手必看
2016/11/20 Python
Python实现简单文本字符串处理的方法
2018/01/22 Python
利用Pandas 创建空的DataFrame方法
2018/04/08 Python
如何基于python实现年会抽奖工具
2020/10/20 Python
Space NK英国站:英国热门美妆网站
2017/12/11 全球购物
美国全球旅游运营商:Pacific Holidays
2018/06/18 全球购物
英国日常交易网站:Wowcher
2018/09/04 全球购物
系统管理员的职责包括那些?管理的对象是什么?
2013/01/18 面试题
自荐信格式的六要素
2013/09/21 职场文书
写给女朋友的道歉信
2014/01/12 职场文书
大学生咖啡店创业计划书
2014/01/21 职场文书
大学竞选班干部演讲稿
2014/08/21 职场文书
音乐教师求职信范文
2015/03/20 职场文书
2016年“六一儿童节”校园广播稿
2015/12/17 职场文书
Vue3.0 手写放大镜效果
2021/07/25 Vue.js
Mysql排序的特性详情
2021/11/01 MySQL
Window server中安装Redis的超详细教程
2021/11/17 Redis