Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
centos下更新Python版本的步骤
Feb 12 Python
Python实现扫描指定目录下的子目录及文件的方法
Jul 16 Python
Python中Collections模块的Counter容器类使用教程
May 31 Python
Python使用回溯法子集树模板获取最长公共子序列(LCS)的方法
Sep 08 Python
[原创]教女朋友学Python(一)运行环境搭建
Nov 29 Python
安装python时MySQLdb报错的问题描述及解决方法
Mar 20 Python
django 开发忘记密码通过邮箱找回功能示例
Apr 17 Python
python实现将读入的多维list转为一维list的方法
Jun 28 Python
对python中两种列表元素去重函数性能的比较方法
Jun 29 Python
python使用opencv驱动摄像头的方法
Aug 03 Python
python中栈的原理及实现方法示例
Nov 27 Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
通过PHP的内置函数,通过DES算法对数据加密和解密
2012/06/21 PHP
PHP类的反射用法实例
2014/11/03 PHP
一个js拖拽的效果类和dom-drag.js浅析
2010/07/17 Javascript
Javascript动态绑定事件的简单实现代码
2010/12/25 Javascript
基于jquery的button默认enter事件(回车事件)。
2011/05/18 Javascript
javascript实现鼠标拖动改变层大小的方法
2015/04/30 Javascript
JavaScript取得键盘按下方向键是哪个的方法
2015/08/04 Javascript
学习JavaScript设计模式之中介者模式
2016/01/14 Javascript
教你如何终止JQUERY的$.AJAX请求
2016/02/23 Javascript
jQuery控制div实现随滚动条滚动效果
2016/06/07 Javascript
jQuery实现可拖拽3D万花筒旋转特效
2017/01/03 Javascript
localStorage的黑科技-js和css缓存机制
2017/02/06 Javascript
详解Angular-cli生成组件修改css成less或sass的实例
2017/07/27 Javascript
webpack开发跨域问题解决办法
2017/08/03 Javascript
swiper插件自定义切换箭头按钮
2017/12/28 Javascript
JS 音频可视化插件Wavesurfer.js的使用教程
2018/10/31 Javascript
简单实现vue中的依赖收集与响应的方法
2019/02/18 Javascript
json数据格式常见操作示例
2019/06/13 Javascript
js实现鼠标拖拽div左右滑动
2020/01/15 Javascript
vue开发简单上传图片功能
2020/06/30 Javascript
[06:49]2018DOTA2国际邀请赛寻真——VirtusPro傲视群雄
2018/08/12 DOTA
python之import机制详解
2014/07/03 Python
Python中生成器和迭代器的区别详解
2018/02/10 Python
Python常用库大全及简要说明
2020/01/17 Python
浅谈Tensorflow 动态双向RNN的输出问题
2020/01/20 Python
Python字符串格式化常用手段及注意事项
2020/06/17 Python
python 实现表情识别
2020/11/21 Python
俄罗斯GamePark游戏商店网站:购买游戏、游戏机和配件
2020/03/13 全球购物
接口中的方法可以是abstract的吗
2015/07/23 面试题
办公室文秘自我鉴定
2013/09/21 职场文书
计算机毕业大学生推荐信
2013/12/01 职场文书
领导干部学习“三严三实”思想汇报
2014/09/15 职场文书
2014年维修工作总结
2014/11/22 职场文书
男方婚礼答谢词
2015/01/20 职场文书
Spring Boot项目传参校验的最佳实践指南
2022/04/05 Java/Android
MySQL的表级锁,行级锁,排它锁和共享锁
2022/07/15 MySQL