Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
浅谈Python中数据解析
May 05 Python
python计算对角线有理函数插值的方法
May 07 Python
浅谈python中scipy.misc.logsumexp函数的运用场景
Jun 23 Python
Python使用smtp和pop简单收发邮件完整实例
Jan 09 Python
用xpath获取指定标签下的所有text的实例
Jan 02 Python
python如何爬取网站数据并进行数据可视化
Jul 08 Python
python被修饰的函数消失问题解决(基于wraps函数)
Nov 04 Python
Python多重继承之菱形继承的实例详解
Feb 12 Python
Python关键字及可变参数*args,**kw原理解析
Apr 04 Python
python实现简单遗传算法
Sep 18 Python
python tkinter Entry控件的焦点移动操作
May 22 Python
Anaconda安装pytorch及配置PyCharm 2021环境
Jun 04 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
Wordpress php 分页代码
2009/10/21 PHP
php数组函数序列之prev() - 移动数组内部指针到上一个元素的位置,并返回该元素值
2011/10/31 PHP
PHP文件注释标记及规范小结
2012/04/01 PHP
使用Sphinx对索引进行搜索
2013/06/25 PHP
php统计文章排行示例
2014/03/04 PHP
PHP session 会话处理函数
2016/06/06 PHP
设置下载不需要倒计时cookie(倒计时代码)
2008/11/19 Javascript
javascript hashtable实现代码
2009/10/13 Javascript
基于Jquery实现的一个图片滚动切换
2012/06/21 Javascript
javascript获取url上某个参数的方法
2013/11/08 Javascript
Javascript变量作用域详解
2013/12/06 Javascript
node.js require() 源码解读
2015/12/13 Javascript
详解微信小程序开发之城市选择器 城市切换
2017/01/17 Javascript
Vue2.x中的Render函数详解
2017/05/30 Javascript
关于webpack2和模块打包的新手指南(小结)
2017/08/07 Javascript
JS浅拷贝和深拷贝原理与实现方法分析
2019/02/28 Javascript
Vue 中可以定义组件模版的几种方式
2019/08/06 Javascript
24个ES6方法解决JS实际开发问题(小结)
2020/05/31 Javascript
vue在App.vue文件中监听路由变化刷新页面操作
2020/08/14 Javascript
[07:09]DOTA2-DPC中国联赛 正赛 Ehome vs Elephant 选手采访
2021/03/11 DOTA
Python2.7简单连接与操作MySQL的方法
2016/04/27 Python
python+django加载静态网页模板解析
2017/12/12 Python
Python实现KNN邻近算法
2021/01/28 Python
python最长回文串算法
2018/06/04 Python
Django 日志配置按日期滚动的方法
2019/01/31 Python
Python读取xlsx文件的实现方法
2019/07/04 Python
浅谈Python_Openpyxl使用(最全总结)
2019/09/05 Python
让Django的BooleanField支持字符串形式的输入方式
2020/05/20 Python
查找适用于matplotlib的中文字体名称与实际文件名对应关系的方法
2021/01/05 Python
UNIX命令速查表
2012/03/10 面试题
乡镇干部先进事迹材料
2014/02/03 职场文书
民警个人对照检查剖析材料
2014/09/17 职场文书
务虚会发言材料
2014/12/25 职场文书
2015年电话销售工作总结范文
2015/04/20 职场文书
Golang中异常处理机制详解
2021/06/08 Golang
mysql性能优化以及配置连接参数设置
2022/05/06 MySQL