Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
整理Python最基本的操作字典的方法
Apr 24 Python
Python实现字典去除重复的方法示例
Jul 31 Python
利用Tkinter(python3.6)实现一个简单计算器
Dec 21 Python
Anaconda 离线安装 python 包的操作方法
Jun 11 Python
Python 给定的经纬度标注在地图上的实现方法
Jul 05 Python
对Python _取log的几种方式小结
Jul 25 Python
wxpython多线程防假死与线程间传递消息实例详解
Dec 13 Python
Pytorch GPU显存充足却显示out of memory的解决方式
Jan 13 Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
Jun 12 Python
python 服务器运行代码报错ModuleNotFoundError的解决办法
Sep 16 Python
scrapy结合selenium解析动态页面的实现
Sep 28 Python
python 基于opencv操作摄像头
Dec 24 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
getimagesize获取图片尺寸实例
2014/11/15 PHP
Zend Framework动作助手Url用法详解
2016/03/05 PHP
PHP简单获取及判断提交来源的方法
2016/04/22 PHP
PHP的PDO错误与错误处理
2019/01/27 PHP
php时间戳转换代码详解
2019/08/04 PHP
Prototype使用指南之selector.js
2007/01/10 Javascript
用jQuery中的ajax分页实现代码
2011/09/20 Javascript
JavaScript初学者应注意的七个细节小结
2012/01/30 Javascript
jquery实现效果比较好的table选中行颜色
2014/03/25 Javascript
《JavaScript DOM 编程艺术》读书笔记之JavaScript 语法
2015/01/09 Javascript
用js编写的简单的计算器代码程序
2015/08/04 Javascript
Javascript实现鼠标框选操作  不是点击选取
2016/04/14 Javascript
JavaScript随机打乱数组顺序之随机洗牌算法
2016/08/02 Javascript
使用bootstrap实现多窗口和拖动效果
2016/09/22 Javascript
jquery自定义插件结合baiduTemplate.js实现异步刷新(附源码)
2016/12/22 Javascript
Vue关于数据绑定出错解决办法
2017/05/15 Javascript
JS 实现分页打印功能
2018/05/16 Javascript
详解vue中的computed的this指向问题
2018/12/05 Javascript
this.$toast() 了解一下?
2019/04/18 Javascript
vue+springboot图片上传和显示的示例代码
2020/02/14 Javascript
微信小程序吸底区域适配iPhoneX的实现
2020/04/09 Javascript
js String.prototype.trim字符去前后空格的扩展
2020/08/23 Javascript
Vue 防止短时间内连续点击后多次触发请求的操作
2020/11/11 Javascript
python基础教程之udp端口扫描
2014/02/10 Python
Python实现带百分比的进度条
2016/06/28 Python
Linux 修改Python命令的方法示例
2018/12/03 Python
利用Python实现Shp格式向GeoJSON的转换方法
2019/07/09 Python
pytorch中如何使用DataLoader对数据集进行批处理的方法
2019/08/06 Python
解决python和pycharm安装gmpy2 出现ERROR的问题
2020/08/28 Python
解决PyCharm IDE环境下,执行unittest不生成测试报告的问题
2020/09/03 Python
美国诺德斯特龙百货官网:Nordstrom
2016/08/23 全球购物
个性与发展自我评价
2014/02/11 职场文书
物业管理工作方案
2014/05/10 职场文书
2014年乡镇党建工作总结
2014/11/11 职场文书
2014年卫生保健工作总结
2014/12/08 职场文书
最新动漫情报:2022年7月新番定档超过30部, OVERLORD骨王第四季也在其中噢
2022/05/04 日漫