Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
PHP网页抓取之抓取百度贴吧邮箱数据代码分享
Apr 13 Python
Python使用pyh生成HTML文档的方法示例
Mar 10 Python
python读取文本中数据并转化为DataFrame的实例
Apr 10 Python
利用Python将数值型特征进行离散化操作的方法
Nov 06 Python
Python3.6.x中内置函数总结及讲解
Feb 22 Python
pandas.cut具体使用总结
Jun 24 Python
docker django无法访问redis容器的解决方法
Aug 21 Python
python求平均数、方差、中位数的例子
Aug 22 Python
解决Django migrate不能发现app.models的表问题
Aug 31 Python
python实现可下载音乐的音乐播放器
Feb 25 Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 Python
python如何实现读取并显示图片(不需要图形界面)
Jul 08 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
用PHP实现递归循环每一个目录
2010/08/08 PHP
php中将地址生成迅雷快车旋风链接的代码[测试通过]
2011/04/20 PHP
php统计文件大小,以GB、MB、KB、B输出
2011/05/29 PHP
为百度UE编辑器上传图片添加水印功能
2015/04/16 PHP
php获取从百度、谷歌等搜索引擎进入网站关键词的方法
2015/07/08 PHP
php生成唯一数字id的方法汇总
2015/11/18 PHP
ecshop适应在PHP7的修改方法解决报错的实现
2016/11/01 PHP
PHP编程获取各个时间段具体时间的方法
2017/05/26 PHP
php变量与字符串的增删改查操作示例
2020/05/07 PHP
对JavaScript的eval()中使用函数的进一步讨论
2008/07/26 Javascript
JavaScript中使用document.write向页面输出内容实例
2014/10/16 Javascript
JQuery简单实现锚点链接的平滑滚动
2015/05/03 Javascript
prototype.js常用函数详解
2016/06/18 Javascript
JS实现随机颜色的3种方法与颜色格式的转化
2017/01/05 Javascript
浅析bootstrap原理及优缺点
2017/03/19 Javascript
基于JavaScript实现瀑布流布局
2018/08/15 Javascript
解决vuecli3.0热更新失效的问题
2018/09/19 Javascript
vue调试工具vue-devtools安装及使用方法
2018/11/07 Javascript
Vue快速实现通用表单验证的方法
2020/02/24 Javascript
[01:03:38]2014 DOTA2国际邀请赛中国区预选赛5.21 CNB VS CIS
2014/05/22 DOTA
Python实现的微信支付方式总结【三种方式】
2019/04/13 Python
Python实现实时数据采集新型冠状病毒数据实例
2020/02/04 Python
使用CSS3的appearance属性改变任何元素的浏览器默认风格
2012/12/24 HTML / CSS
英国设计师泳装、沙滩装和比基尼在线精品店:Beach Cafe
2019/08/28 全球购物
俄罗斯电动工具和设备购物网站:Vseinstrumenti.ru
2020/11/12 全球购物
安全检查与奖惩制度
2014/01/23 职场文书
小学爱国卫生月活动总结
2014/06/30 职场文书
软件研发工程师岗位职责
2014/09/30 职场文书
2014年教务工作总结
2014/12/03 职场文书
推广普通话的宣传语
2015/07/13 职场文书
公安干警正风肃纪心得体会
2016/01/15 职场文书
Python基础之赋值,浅拷贝,深拷贝的区别
2021/04/30 Python
关于python类SortedList详解
2021/09/04 Python
分享7个 Python 实战项目练习
2022/03/03 Python
使用Redis实现分布式锁的方法
2022/06/16 Redis
使用CSS实现音波加载效果
2023/05/07 HTML / CSS