Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
用Python输出一个杨辉三角的例子
Jun 13 Python
Python 高级专用类方法的实例详解
Sep 11 Python
python机器人行走步数问题的解决
Jan 29 Python
Python利用openpyxl库遍历Sheet的实例
May 03 Python
python opencv 图像拼接的实现方法
Jun 27 Python
python输出电脑上所有的串口名的方法
Jul 02 Python
Python实现打印实心和空心菱形
Nov 23 Python
Python使用psutil获取进程信息的例子
Dec 17 Python
Python底层封装实现方法详解
Jan 22 Python
python中with用法讲解
Feb 07 Python
python with语句的原理与用法详解
Mar 30 Python
python 模块导入问题汇总
Feb 01 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
重置版战役片段
2020/04/09 魔兽争霸
虹吸式咖啡壶操作
2021/03/03 冲泡冲煮
解析用PHP读写音频文件信息的详解(支持WMA和MP3)
2013/05/10 PHP
php cookie使用方法学习笔记分享
2013/11/07 PHP
php导出csv数据在浏览器中输出提供下载或保存到文件的示例
2014/04/24 PHP
php单态设计模式(单例模式)实例
2014/11/18 PHP
php获取QQ头像并显示的方法
2014/12/23 PHP
yii2框架中使用下拉菜单的自动搜索yii-widget-select2实例分析
2016/01/09 PHP
PHP+JQUERY操作JSON实例
2017/03/23 PHP
Laravel Validator自定义错误返回提示消息并在前端展示
2019/05/09 PHP
键盘控制事件应用教程大全
2006/11/24 Javascript
JQuery 返回布尔值Is()条件判断方法代码
2012/05/14 Javascript
JQuery处理json与ajax返回JSON实例代码
2014/01/03 Javascript
AngularJS 中的事件详解
2016/07/28 Javascript
Vue.js每天必学之表单控件绑定
2016/09/05 Javascript
javascript的document中的动态添加标签实现方法
2016/10/24 Javascript
jQuery扩展+xml实现表单验证功能的方法
2016/12/25 Javascript
从零开始学习Node.js系列教程之基于connect和express框架的多页面实现数学运算示例
2017/04/13 Javascript
前端构建工具之gulp的语法教程
2017/06/12 Javascript
详解webpack模块化管理和打包工具
2018/04/21 Javascript
了解javascript中let和var及const关键字的区别
2019/05/24 Javascript
[02:04]2014DOTA2国际邀请赛 DK一个时代的落幕
2014/07/21 DOTA
Python实现的简单发送邮件脚本分享
2014/11/07 Python
将Django框架和遗留的Web应用集成的方法
2015/07/24 Python
pytorch 可视化feature map的示例代码
2019/08/20 Python
python 获取当前目录下的文件目录和文件名实例代码详解
2020/03/10 Python
Python调用系统命令os.system()和os.popen()的实现
2020/12/31 Python
Python自动化测试基础必备知识点总结
2021/02/07 Python
CSS3 渐变(Gradients)之CSS3 径向渐变
2016/07/08 HTML / CSS
Levi’s美国官网:美国著名的牛仔裤品牌
2016/08/19 全球购物
领导视察欢迎词
2014/01/15 职场文书
孝老爱亲模范事迹
2014/01/24 职场文书
书法比赛获奖感言
2014/02/10 职场文书
2014年教师政治学习材料
2014/06/02 职场文书
学困生帮扶工作总结
2015/08/13 职场文书
建议书的格式及范文
2015/09/14 职场文书