Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python基础教程之面向对象的一些概念
Aug 29 Python
CentOS安装pillow报错的解决方法
Jan 27 Python
python进阶_浅谈面向对象进阶
Aug 17 Python
对python3 urllib包与http包的使用详解
May 10 Python
Python实现动态添加属性和方法操作示例
Jul 25 Python
python 实现敏感词过滤的方法
Jan 21 Python
Python生成MD5值的两种方法实例分析
Apr 26 Python
Django框架表单操作实例分析
Nov 04 Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 Python
opencv+pyQt5实现图片阈值编辑器/寻色块阈值利器
Nov 13 Python
使用Python爬虫爬取小红书完完整整的全过程
Jan 19 Python
python 提取html文本的方法
May 20 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
phpmyadmin提示The mbstring extension is missing的解决方法
2014/12/17 PHP
Smarty foreach控制循环次数的一些方法
2015/07/01 PHP
php实现的错误处理封装类实例
2017/06/20 PHP
基于JQuery的一句话搞定手风琴菜单
2012/09/14 Javascript
禁止iframe脚本弹出的窗口覆盖了父窗口的方法
2014/09/06 Javascript
js仿黑客帝国字母掉落效果代码分享
2020/11/08 Javascript
解决js图片加载时出现404的问题
2020/11/30 Javascript
javascript:void(0)点击登录没反应怎么解决
2015/11/13 Javascript
JS中闭包的经典用法小结(2则示例)
2016/12/28 Javascript
ajax图片上传,图片异步上传,更新实例
2016/12/30 Javascript
js实现百度登录框鼠标拖拽效果
2017/03/07 Javascript
js编写选项卡效果
2017/05/23 Javascript
Vue实战之vue登录验证的实现代码
2017/10/31 Javascript
vue打包后显示空白正确处理方法
2017/11/01 Javascript
bootstrap treeview 扩展addNode方法动态添加子节点的方法
2017/11/21 Javascript
JS运动改变单物体透明度的方法分析
2018/01/23 Javascript
jQuery实现的简单对话框拖动功能示例
2018/06/05 jQuery
iconfont的三种使用方式详解
2018/08/05 Javascript
详解React 条件渲染
2020/07/08 Javascript
python解析模块(ConfigParser)使用方法
2013/12/10 Python
python求解水仙花数的方法
2015/05/11 Python
python实现识别相似图片小结
2016/02/22 Python
python微信跳一跳游戏辅助代码解析
2018/01/29 Python
python2 与 python3 实现共存的方法
2018/07/12 Python
Python最小二乘法矩阵
2019/01/02 Python
python实现多层感知器MLP(基于双月数据集)
2019/01/18 Python
python实现列表中最大最小值输出的示例
2019/07/09 Python
django多对多表的创建,级联删除及手动创建第三张表
2019/07/25 Python
tensorflow 重置/清除计算图的实现
2020/01/19 Python
CSS3实现大小不一的粒子旋转加载动画
2016/04/21 HTML / CSS
html5嵌入内容_动力节点Java学院整理
2017/07/07 HTML / CSS
HTML5在线预览PDF的示例代码
2017/09/14 HTML / CSS
世界上最大的在线旅行社新加坡网站:Expedia新加坡
2016/08/25 全球购物
全球性的在线购物网站:Zapals
2017/03/22 全球购物
工作睡觉检讨书
2014/02/25 职场文书
Mysql 如何实现多张无关联表查询数据并分页
2021/06/05 MySQL