Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python实现Tab自动补全和历史命令管理的方法
Mar 12 Python
利用Opencv中Houghline方法实现直线检测
Feb 11 Python
python中类的属性和方法介绍
Nov 27 Python
Python3爬虫学习之爬虫利器Beautiful Soup用法分析
Dec 12 Python
python requests post多层字典的方法
Dec 27 Python
Python装饰器用法实例分析
Jan 14 Python
python面试题小结附答案实例代码
Apr 11 Python
如何在VSCode上轻松舒适的配置Python的方法步骤
Oct 28 Python
python数值基础知识浅析
Nov 19 Python
Pytest框架之fixture的详细使用教程
Apr 07 Python
python 如何设置守护进程
Oct 29 Python
Python获取指定网段正在使用的IP
Dec 14 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
《PHP边学边教》(04.编写简易的通讯录――视频教程1)
2006/12/13 PHP
PHP保存session到memcache服务器的方法
2016/01/19 PHP
用PHP的反射实现委托模式的讲解
2019/03/22 PHP
php 命名空间(namespace)原理与用法实例小结
2019/11/13 PHP
JS是否可以跨文件同时控制多个iframe页面的应用技巧
2007/12/16 Javascript
javascript 浏览器检测代码精简版
2010/03/04 Javascript
jquery插件制作 图片走廊 gallery
2012/08/17 Javascript
JavaScript高级程序设计(第3版)学习笔记13 ECMAScript5新特性
2012/10/11 Javascript
详解Javascript继承的实现
2016/03/25 Javascript
js手动播放图片实现图片轮播效果
2016/09/17 Javascript
浅析JavaScript中var that=this
2017/02/17 Javascript
Web开发中客户端的跳转与服务器端的跳转的区别
2017/03/05 Javascript
electron实现qq快捷登录的方法示例
2018/10/22 Javascript
vuex 实现getter值赋值给vue组件里的data示例
2019/11/05 Javascript
详解vite2.0配置学习(typescript版本)
2021/02/25 Javascript
Vue SPA 首屏优化方案
2021/02/26 Vue.js
[01:57]DOTA2上海特锦赛小组赛解说单车采访花絮
2016/02/27 DOTA
Python的Flask框架中实现分页功能的教程
2015/04/20 Python
Python实现截屏的函数
2015/07/25 Python
python语言中with as的用法使用详解
2018/02/23 Python
Python cookbook(数据结构与算法)从字典中提取子集的方法示例
2018/03/22 Python
django 使用 request 获取浏览器发送的参数示例代码
2018/06/11 Python
python实现KNN分类算法
2019/10/16 Python
pandas 空数据处理方法详解
2019/11/02 Python
Python hashlib常见摘要算法详解
2020/01/13 Python
PyQt5+Pycharm安装和配置图文教程详解
2020/03/24 Python
HTML5中的新元素介绍
2008/10/17 HTML / CSS
使用HTML5加载音频和视频的实现代码
2020/11/30 HTML / CSS
大学新生军训个人的自我评价
2013/10/03 职场文书
数控专业应届生求职信
2013/11/27 职场文书
工会换届选举方案
2014/05/21 职场文书
人事主管岗位职责说明书
2014/07/30 职场文书
学校组织向国旗敬礼活动方案(中小学适用)
2014/09/27 职场文书
2015年城市管理工作总结
2015/05/23 职场文书
小学教师教育随笔
2015/08/14 职场文书
新手必备Python开发环境搭建教程
2021/05/28 Python