Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现ipsec开权限实例
Nov 11 Python
Python实现文件按照日期命名的方法
Jul 09 Python
使用PyV8在Python爬虫中执行js代码
Feb 16 Python
win8下python3.4安装和环境配置图文教程
Jul 31 Python
python实现随机漫步方法和原理
Jun 10 Python
Django生成PDF文档显示在网页上以及解决PDF中文显示乱码的问题
Jul 04 Python
Django框架HttpRequest对象用法实例分析
Nov 01 Python
基于TensorFlow常量、序列以及随机值生成实例
Jan 04 Python
jupyter notebook 实现matplotlib图动态刷新
Apr 22 Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
Oct 15 Python
Python QT组件库qtwidgets的使用
Nov 02 Python
仅用几行Python代码就能复制她的U盘文件?
Jun 26 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
PHP的引用详解
2015/02/22 PHP
php按单词截取字符串的方法
2015/04/07 PHP
Laravel 实现数据软删除功能
2019/08/21 PHP
javascript encodeURI和encodeURIComponent的比较
2010/04/03 Javascript
多引号嵌套的变量命名的问题
2014/05/09 Javascript
Javascript中call与apply的学习笔记
2014/09/22 Javascript
JavaScript中点击事件的写法
2016/06/28 Javascript
Three.js学习之正交投影照相机
2016/08/01 Javascript
node.js实现博客小爬虫的实例代码
2016/10/08 Javascript
JavaScript实现窗口抖动效果
2016/10/19 Javascript
jQuery序列化后的表单值转换成Json
2017/06/16 jQuery
angular4中关于表单的校验示例
2017/10/16 Javascript
Moment.js实现多个同时倒计时
2019/08/26 Javascript
javascript中的相等操作符(==与===区别)
2019/12/21 Javascript
JS实现滑动导航效果
2020/01/14 Javascript
JS操作Fckeditor的一些常用方法(获取、插入等)
2020/02/19 Javascript
如何在 Vue 中使用 JSX
2021/02/14 Vue.js
[53:23]Secret vs Liquid 2018国际邀请赛淘汰赛BO3 第二场 8.25
2018/08/29 DOTA
Windows下的Python 3.6.1的下载与安装图文详解(适合32位和64位)
2018/02/21 Python
解决python中使用PYQT时中文乱码问题
2019/06/17 Python
pytorch对梯度进行可视化进行梯度检查教程
2020/02/04 Python
在Keras中利用np.random.shuffle()打乱数据集实例
2020/06/15 Python
python中not、and和or的优先级与详细用法介绍
2020/11/03 Python
selenium+超级鹰实现模拟登录12306
2021/01/24 Python
实例教程 一款纯css3实现的数字统计游戏
2014/11/10 HTML / CSS
HTML5新特性之语义化标签
2017/10/31 HTML / CSS
Finishline官网:美国一家领先的运动品牌鞋类、服装零售商
2016/07/20 全球购物
伦敦时尚生活的缩影:LN-CC
2017/01/24 全球购物
String和StringBuffer的区别
2015/08/13 面试题
工作中的自我评价如何写好
2013/10/28 职场文书
决心书范文
2014/03/11 职场文书
暑期政治学习心得体会
2014/09/02 职场文书
领导干部作风建设自查报告
2014/10/23 职场文书
2015年办公室人员工作总结
2015/05/15 职场文书
教师节校长致辞
2015/07/31 职场文书
TV动画《间谍过家家》公开PV
2022/03/20 日漫