Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python编程语言的35个与众不同之处(语言特征和使用技巧)
Jul 07 Python
Python 使用SMTP发送邮件的代码小结
Sep 21 Python
用Pygal绘制直方图代码示例
Dec 07 Python
使用django-crontab实现定时任务的示例
Feb 26 Python
python简易实现任意位数的水仙花实例
Nov 13 Python
python打包exe开机自动启动的实例(windows)
Jun 28 Python
Mac在python3环境下安装virtualwrapper遇到的问题及解决方法
Jul 09 Python
Django中提示消息messages的设置方式
Nov 15 Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 Python
python编写softmax函数、交叉熵函数实例
Jun 11 Python
python 实现控制鼠标键盘
Nov 27 Python
Selenium关闭INFO:CONSOLE提示的解决
Dec 07 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
评分9.0以上的动画电影,剧情除了经典还很燃
2020/03/04 日漫
phpMyAdmin 链接表的附加功能尚未激活问题的解决方法(已测)
2012/03/27 PHP
基于magic_quotes_gpc与magic_quotes_runtime的区别与使用介绍
2013/04/22 PHP
php数组删除元素示例
2014/03/21 PHP
php中error与exception的区别及应用
2014/07/28 PHP
PHP实现的只保留字符串首尾字符功能示例【隐藏部分字符串】
2019/03/11 PHP
jQuery 页面载入进度条实现代码
2009/02/08 Javascript
js中document.getElementByid、document.all和document.layers区分介绍
2011/12/08 Javascript
JavaScript 数组some()和filter()的用法及区别
2016/05/20 Javascript
JavaScript 自定义事件之我见
2017/09/25 Javascript
swiper在vue项目中loop循环轮播失效的解决方法
2018/09/15 Javascript
[48:48]VGJ.T vs Liquid 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
Python常用的内置序列结构(列表、元组、字典)学习笔记
2016/07/08 Python
Python多线程经典问题之乘客做公交车算法实例
2017/03/22 Python
TensorFlow实现MLP多层感知机模型
2018/03/09 Python
PyTorch上实现卷积神经网络CNN的方法
2018/04/28 Python
Python 可变类型和不可变类型及引用过程解析
2019/09/27 Python
Python如何实现强制数据类型转换
2019/11/22 Python
pycharm如何使用anaconda中的各种包(操作步骤)
2020/07/31 Python
python 爬虫基本使用——统计杭电oj题目正确率并排序
2020/10/26 Python
python+selenium爬取微博热搜存入Mysql的实现方法
2021/01/27 Python
Pytorch 图像变换函数集合小结
2021/02/01 Python
canvas环形倒计时组件的示例代码
2018/06/14 HTML / CSS
canvas实现烟花的示例代码
2020/01/16 HTML / CSS
Elemis美国官网:英国的第一豪华护肤品牌
2018/03/15 全球购物
美国女士内衣在线折扣商店:One Hanes Place
2019/03/24 全球购物
大学生毕业自我评价范文分享
2013/11/07 职场文书
酒店总经理助理岗位职责
2014/02/01 职场文书
大学生个人求职口试自我评价
2014/02/16 职场文书
停车位租赁协议书
2014/09/24 职场文书
社区法制宣传月活动总结
2015/05/07 职场文书
小学中队长竞选稿
2015/11/20 职场文书
护士岗前培训心得体会
2016/01/08 职场文书
pytorch锁死在dataloader(训练时卡死)
2021/05/28 Python
vue实现Toast组件轻提示
2022/04/10 Vue.js
CSS使用Flex和Grid布局实现3D骰子
2022/08/05 HTML / CSS