Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
pycharm 使用心得(九)解决No Python interpreter selected的问题
Jun 06 Python
用Python设计一个经典小游戏
May 15 Python
Python内置函数reversed()用法分析
Mar 20 Python
Python基于pycrypto实现的AES加密和解密算法示例
Apr 10 Python
python 批量修改/替换数据的实例
Jul 25 Python
python Pandas如何对数据集随机抽样
Jul 29 Python
django框架cookie和session用法实例详解
Dec 10 Python
Python hashlib加密模块常用方法解析
Dec 18 Python
python实现文字版扫雷
Apr 24 Python
Python爬虫之Spider类用法简单介绍
Aug 04 Python
Python 找出英文单词列表(list)中最长单词链
Dec 14 Python
Python 第三方库 openpyxl 的安装过程
Dec 24 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
PHP文件去掉PHP注释空格的函数分析(PHP代码压缩)
2013/07/02 PHP
php教程之phpize使用方法
2014/02/12 PHP
使用PHP接受文件并获得其后缀名的方法
2015/08/05 PHP
利用PHP判断是否是连乘数字串的方法示例
2017/07/03 PHP
ThinkPHP5 的简单搭建和使用详解
2018/11/15 PHP
JavaScript中void(0)的具体含义解释
2007/02/27 Javascript
js原生态函数中使用jQuery中的 $(this)无效的解决方法
2011/05/25 Javascript
一个小例子解释如何来阻止Jquery事件冒泡
2014/07/17 Javascript
基于javascript实现判断移动终端浏览器版本信息
2014/12/09 Javascript
在Javascript操作JSON对象,增加 删除 修改的简单实现
2016/06/02 Javascript
ros::spin() 和 ros::spinOnce()函数的区别及详解
2016/10/01 Javascript
简单实现js点击展开二级菜单功能
2017/05/16 Javascript
es6新特性之 class 基本用法解析
2018/05/05 Javascript
详解.vue文件解析的实现
2018/06/11 Javascript
微信小程序实时聊天WebSocket
2018/07/05 Javascript
ES6 Promise对象的应用实例分析
2019/06/27 Javascript
vue进入页面时不在顶部,检测滚动返回顶部按钮问题及解决方法
2019/10/30 Javascript
Vue实现渲染数据后控制滚动条位置(推荐)
2019/12/09 Javascript
JavaScript实现多个物体同时运动
2020/03/12 Javascript
JavaScript实现简单计算器
2020/03/19 Javascript
sharp.js安装过程中遇到的问题总结
2020/04/02 Javascript
使用beaker让Facebook的Bottle框架支持session功能
2015/04/23 Python
Python 中的 else详解
2016/04/23 Python
pyenv与virtualenv安装实现python多版本多项目管理
2019/08/17 Python
详解Python 字符串相似性的几种度量方法
2019/08/29 Python
解决django model修改添加字段报错的问题
2019/11/18 Python
如何基于Python实现word文档重新排版
2020/09/29 Python
网易微博Web App用HTML5开发的过程介绍
2012/06/13 HTML / CSS
美国百货齐全的精品网站,提供美式风格的产品:Overstock.com
2016/07/22 全球购物
马来西亚排名第一的宠物用品店:Pets Wonderland
2020/04/16 全球购物
远程学习的教学用品和家庭学习资源:Really Good Stuff
2020/04/27 全球购物
农行实习自我鉴定
2013/09/22 职场文书
加强干部作风建设整改方案
2014/10/24 职场文书
2015元旦晚会主持词(开场白+结束语)
2014/12/14 职场文书
2015年财政所工作总结
2015/04/25 职场文书
2019年聘任书的写作格式及范文!
2019/07/03 职场文书