Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
linux系统使用python获取cpu信息脚本分享
Jan 15 Python
python机器学习理论与实战(二)决策树
Jan 19 Python
python3库numpy数组属性的查看方法
Apr 17 Python
PyCharm代码整体缩进,反向缩进的方法
Jun 25 Python
浅谈Python中的可迭代对象、迭代器、For循环工作机制、生成器
Mar 11 Python
pyinstaller打包多个py文件和去除cmd黑框的方法
Jun 21 Python
python安装pil库方法及代码
Jun 25 Python
python读取图片的方式,以及将图片以三维数组的形式输出方法
Jul 03 Python
Python Web版语音合成实例详解
Jul 16 Python
PyCharm汉化安装及永久激活详细教程(靠谱)
Jan 16 Python
python列表的逆序遍历实现
Apr 20 Python
Python多线程实用方法以及共享变量资源竞争问题
Apr 12 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
使用php来实现网络服务
2009/09/15 PHP
修改php.ini以达到屏蔽错误信息并记录日志
2013/06/16 PHP
php对数组排序的简单实例
2013/12/25 PHP
php解决抢购秒杀抽奖等大流量并发入库导致的库存负数的问题
2014/06/19 PHP
PHP发送短信代码分享
2015/08/11 PHP
Codeigniter控制器controller继承问题实例分析
2016/01/19 PHP
通用JS事件写法实现代码
2009/01/07 Javascript
扩展Jquery插件处理mouseover时内部有子元素时发生样式闪烁
2011/12/08 Javascript
jQuery图片切换插件jquery.cycle.js使用示例
2014/06/16 Javascript
js匿名函数的调用示例(形式多种多样)
2014/08/20 Javascript
js获取UserControl内容为拼html时提供方便
2014/11/02 Javascript
javascript实现点击后变换按钮显示文字的方法
2015/05/13 Javascript
JavaScript中日期的相关操作方法总结
2015/10/24 Javascript
javascript十六进制数字和ASCII字符之间的转换方法
2016/12/27 Javascript
javascript实现文字无缝滚动
2016/12/27 Javascript
React实现点击删除列表中对应项
2017/01/10 Javascript
创建一般js对象的几种方式
2017/01/19 Javascript
Vue2 配置 Axios api 接口调用文件的方法
2017/11/13 Javascript
vue采用EventBus实现跨组件通信及注意事项小结
2018/06/14 Javascript
Node.js API详解之 assert模块用法实例分析
2020/05/26 Javascript
[48:05]2018DOTA2亚洲邀请赛 3.31 小组赛 B组 VGJ.T vs VP
2018/03/31 DOTA
numpy判断数值类型、过滤出数值型数据的方法
2018/06/09 Python
Python利用递归实现文件的复制方法
2018/10/27 Python
python使用KNN算法识别手写数字
2019/04/25 Python
Python实现Linux监控的方法
2019/05/16 Python
Python使用正则实现计算字符串算式
2019/12/29 Python
Python 随机生成测试数据的模块:faker基本使用方法详解
2020/04/09 Python
Python SMTP配置参数并发送邮件
2020/06/16 Python
澳大利亚优惠网站:Deals.com.au
2019/07/02 全球购物
Java工程师面试集锦之Spring框架
2013/06/16 面试题
银行类自荐信
2014/02/04 职场文书
团员年度个人总结
2015/02/26 职场文书
汽车4S店前台接待岗位职责
2015/04/03 职场文书
离婚起诉状范本
2015/05/19 职场文书
送给小学生的暑假礼物!小学生必背99首古诗
2019/07/02 职场文书
使用python将HTML转换为PDF pdfkit包(wkhtmltopdf) 的使用方法
2022/04/21 Python