Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python内置函数bin() oct()等实现进制转换
Dec 30 Python
python通过urllib2爬网页上种子下载示例
Feb 24 Python
Python的__builtin__模块中的一些要点知识
May 02 Python
Windows下Python使用Pandas模块操作Excel文件的教程
May 31 Python
python算法表示概念扫盲教程
Apr 13 Python
Python2.7下安装Scrapy框架步骤教程
Dec 22 Python
Django中的forms组件实例详解
Nov 08 Python
带你认识Django
Jan 15 Python
python解析命令行参数的三种方法详解
Nov 29 Python
Pyspark读取parquet数据过程解析
Mar 27 Python
Anaconda安装pytorch及配置PyCharm 2021环境
Jun 04 Python
详解Python+OpenCV进行基础的图像操作
Feb 15 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
PHP字符串 ==比较运算符的副作用
2009/10/21 PHP
用PHP查询搜索引擎排名位置的代码
2010/01/05 PHP
php中处理mysql_fetch_assoc返回来的数组 不用foreach----echo
2011/05/04 PHP
php学习笔记 面向对象的构造与析构方法
2011/06/13 PHP
PHP中$_SERVER的详细参数与说明介绍
2013/10/26 PHP
微信公众号点击菜单即可打开并登录微站的实现方法
2014/11/14 PHP
PHP中把对象数组转换成普通数组的方法
2015/07/10 PHP
百度留言本js 大家可以参考下
2009/10/13 Javascript
jquery实现的让超出显示范围外的导航自动固定屏幕最顶上
2011/09/22 Javascript
Prototype源码浅析 String部分(二)
2012/01/16 Javascript
通过JavaScript使Div居中并随网页大小改变而改变
2013/06/24 Javascript
jQuery自定义事件的简单实现代码
2014/01/27 Javascript
Javascript中prototype属性实现给内置对象添加新的方法
2015/05/14 Javascript
阿里巴巴技术文章分享 Javascript继承机制的实现
2016/01/14 Javascript
浅谈JS的基础类型与引用类型
2016/09/13 Javascript
jQuery基于ajax方式实现用户名存在性检查功能示例
2017/02/10 Javascript
js获取元素下的第一级子元素的方法(推荐)
2017/03/05 Javascript
Easyui和zTree两种方式分别实现树形下拉框
2017/08/04 Javascript
详解Node中导入模块require和import的区别
2017/08/11 Javascript
基于jsbarcode 生成条形码并将生成的条码保存至本地+源码
2020/04/27 Javascript
[01:21]2018DOTA2亚洲邀请赛4.5采访 打DOTA2也能有女朋友?
2018/04/06 DOTA
python使用PIL缩放网络图片并保存的方法
2015/04/24 Python
Python实现SSH远程登陆,并执行命令的方法(分享)
2017/05/08 Python
python简单线程和协程学习心得(分享)
2017/06/14 Python
python爬虫中多线程的使用详解
2019/09/23 Python
python飞机大战pygame游戏背景设计详解
2019/12/17 Python
详解python itertools功能
2020/02/07 Python
python GUI库图形界面开发之pyinstaller打包python程序为exe安装文件
2020/02/26 Python
在SQL Server中创建数据库主要有那种方式
2013/09/10 面试题
岗位廉洁从业承诺书
2014/03/28 职场文书
市级文明单位申报材料
2014/05/07 职场文书
质量负责人任命书
2014/06/06 职场文书
乡镇领导干部个人对照检查材料思想汇报
2014/09/23 职场文书
2014年图书馆个人工作总结
2014/12/18 职场文书
实践论读书笔记
2015/06/29 职场文书
解决xampp安装后Apache无法启动
2022/03/21 Servers