Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
巧用Python装饰器 免去调用父类构造函数的麻烦
May 18 Python
一个小示例告诉你Python语言的优雅之处
Jul 04 Python
用Python编写一个基于终端的实现翻译的脚本
Apr 24 Python
python简单获取本机计算机名和IP地址的方法
Jun 03 Python
python嵌套字典比较值与取值的实现示例
Nov 03 Python
python中的内置函数max()和min()及mas()函数的高级用法
Mar 29 Python
解决Python2.7读写文件中的中文乱码问题
Apr 12 Python
Python使用win32 COM实现Excel的写入与保存功能示例
May 03 Python
使用python绘制二维图形示例
Nov 22 Python
Python批量启动多线程代码实例
Feb 18 Python
如何表示python中的相对路径
Jul 08 Python
Python如何批量生成和调用变量
Nov 21 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
PHP之十六个魔术方法详细介绍
2016/11/01 PHP
JavaScript 检测浏览器和操作系统的脚本
2008/12/26 Javascript
ComboBox 和 DateField 在IE下消失的解决方法
2013/08/30 Javascript
jQuery+canvas实现简单的球体斜抛及颜色动态变换效果
2016/01/28 Javascript
javascript弹出窗口中增加确定取消按钮
2016/06/24 Javascript
AngularJS基础 ng-href 指令用法
2016/08/01 Javascript
JavaScript中Array的实用操作技巧分享
2016/09/11 Javascript
javascript兼容性(实例讲解)
2017/08/15 Javascript
React Native模块之Permissions权限申请的实例相机
2017/09/28 Javascript
vue2.0 循环遍历加载不同图片的方法
2018/03/06 Javascript
jQuery pjax 应用简单示例
2018/09/20 jQuery
JS 音频可视化插件Wavesurfer.js的使用教程
2018/10/31 Javascript
JS 自执行函数原理及用法
2019/08/05 Javascript
详谈vue中router-link和传统a链接的区别
2020/07/22 Javascript
vant 中van-list的用法说明
2020/11/11 Javascript
[01:09]DOTA2次级职业联赛 - 99战队宣传片
2014/12/01 DOTA
python使用Flask框架获取用户IP地址的方法
2015/03/21 Python
Python实现二分查找算法实例
2015/05/26 Python
Python实现TCP协议下的端口映射功能的脚本程序示例
2016/06/14 Python
Python3中使用urllib的方法详解(header,代理,超时,认证,异常处理)
2016/09/21 Python
python实现各进制转换的总结大全
2017/06/18 Python
Python3多线程爬虫实例讲解代码
2018/01/05 Python
Python正则表达式指南 推荐
2018/10/09 Python
python3 实现一行输入,空格隔开的示例
2018/11/14 Python
python3使用flask编写注册post接口的方法
2018/12/28 Python
对numpy下的轴交换transpose和swapaxes的示例解读
2019/06/26 Python
python之openpyxl模块的安装和基本用法(excel管理)
2021/02/03 Python
小学生期末评语
2014/04/21 职场文书
镇副书记专题民主生活会对照检查材料思想汇报
2014/10/02 职场文书
2014年学校法制宣传日活动总结
2014/11/01 职场文书
师范生小学见习总结
2015/06/23 职场文书
分家协议书范本
2016/03/22 职场文书
2019邀请函格式及范文
2019/05/20 职场文书
python3.9之你应该知道的新特性详解
2021/04/29 Python
对Golang中的FORM相关字段理解
2021/05/02 Golang
springboot入门 之profile设置方式
2022/04/04 Java/Android