Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python操作Mysql实例代码教程在线版(查询手册)
Feb 18 Python
详解Python中映射类型(字典)操作符的概念和使用
Aug 19 Python
CentOS 7下安装Python 3.5并与Python2.7兼容并存详解
Jul 07 Python
matlab中实现矩阵删除一行或一列的方法
Apr 04 Python
对python csv模块配置分隔符和引用符详解
Dec 12 Python
pandas删除指定行详解
Apr 04 Python
Python2与Python3的区别实例总结
Apr 17 Python
Python 解决OPEN读文件报错 ,路径以及r的问题
Dec 19 Python
python基于celery实现异步任务周期任务定时任务
Dec 30 Python
python3.9实现pyinstaller打包python文件成exe
Dec 13 Python
python+opencv3.4.0 实现HOG+SVM行人检测的示例代码
Jan 28 Python
python使用pymysql模块操作MySQL
Jun 16 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
如何使用php输出时间格式
2013/08/31 PHP
php查询whois信息的方法
2015/06/08 PHP
php文件上传类的分享
2017/07/06 PHP
[原创]php token使用与验证示例【测试可用】
2017/08/30 PHP
xml 与javascript结合的问题解决方法
2007/03/24 Javascript
JavaScript在IE和Firefox浏览器下的7个差异兼容写法小结
2010/06/18 Javascript
jQuery操作select下拉框的text值和value值的方法
2014/05/31 Javascript
详解JavaScript中的forEach()方法的使用
2015/06/08 Javascript
JS实现点击按钮控制Div变宽、增高及调整背景色的方法
2015/08/05 Javascript
Javascript验证方法大全
2015/09/21 Javascript
浅谈JavaScript函数的四种存在形态
2016/06/08 Javascript
Vue.js开发环境搭建
2016/11/10 Javascript
vue中添加mp3音频文件的方法
2018/03/02 Javascript
vue中关闭eslint的方法分析
2018/08/04 Javascript
Vue.js 事件修饰符的使用教程
2018/11/01 Javascript
微信小程序bindtap事件与冒泡阻止详解
2019/08/08 Javascript
如何基于layui的laytpl实现数据绑定的示例代码
2020/04/10 Javascript
详解ES6 扩展运算符的使用与注意事项
2020/11/12 Javascript
iview实现动态表单和自定义验证时间段重叠
2021/01/10 Javascript
[01:02:30]Mineski vs Secret 2019国际邀请赛淘汰赛 败者组 BO3 第三场 8.22
2019/09/05 DOTA
如何解决django配置settings时遇到Could not import settings 'conf.local'
2014/11/18 Python
Python 遍历列表里面序号和值的方法(三种)
2017/02/17 Python
python实现批量图片格式转换
2020/06/16 Python
python+opencv打开摄像头,保存视频、拍照功能的实现方法
2019/01/08 Python
python实现烟花小程序
2019/01/30 Python
python实现随机漫步方法和原理
2019/06/10 Python
使用Python和OpenCV检测图像中的物体并将物体裁剪下来
2019/10/30 Python
Flask框架 CSRF 保护实现方法详解
2019/10/30 Python
Python @property原理解析和用法实例
2020/02/11 Python
python判断元素是否存在的实例方法
2020/09/24 Python
导师评语大全
2014/04/26 职场文书
公司授权委托书格式范文
2014/10/02 职场文书
2015年教师党员承诺书
2015/04/27 职场文书
2015年度女工工作总结
2015/10/22 职场文书
人民币使用说明书
2019/04/17 职场文书
2019学校运动会开幕词
2019/05/13 职场文书