Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python字符串和文件操作常用函数分析
Apr 08 Python
python2.7的编码问题与解决方法
Oct 04 Python
python 按照固定长度分割字符串的方法小结
Apr 30 Python
在Python中获取两数相除的商和余数方法
Nov 10 Python
详解配置Django的Celery异步之路踩坑
Nov 25 Python
Python sklearn KFold 生成交叉验证数据集的方法
Dec 11 Python
Django框架创建mysql连接与使用示例
Jul 29 Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 Python
django 多数据库及分库实现方式
Apr 01 Python
python3+selenium获取页面加载的所有静态资源文件链接操作
May 04 Python
什么是Python包的循环导入
Sep 08 Python
python实现控制台输出颜色
Mar 02 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
php模板中出现空行解决方法
2011/03/08 PHP
PHP加密函数 Javascript/Js 解密函数
2013/09/23 PHP
JQuery循环滚动图片代码
2011/12/08 Javascript
js控制frameSet示例
2013/09/10 Javascript
JavaScript给input的value赋值引发的关于基本类型值和引用类型值问题
2015/12/07 Javascript
原生JavaScript制作微博发布面板效果
2016/03/11 Javascript
jQuery 3.0中存在问题及解决办法
2016/07/15 Javascript
详解给Vue2路由导航钩子和axios拦截器做个封装
2018/04/10 Javascript
30分钟搭建Python的Flask框架并在上面编写第一个应用
2015/03/30 Python
Python切片知识解析
2016/03/06 Python
详解使用python crontab设置linux定时任务
2016/12/08 Python
Python操作使用MySQL数据库的实例代码
2017/05/25 Python
Python3 socket同步通信简单示例
2017/06/07 Python
python 重定向获取真实url的方法
2018/05/11 Python
Python(TensorFlow框架)实现手写数字识别系统的方法
2018/05/29 Python
python实现列表的排序方法分享
2019/07/01 Python
python二进制读写及特殊码同步实现详解
2019/10/11 Python
通过python扫描二维码/条形码并打印数据
2019/11/14 Python
Jupyter notebook设置背景主题,字体大小及自动补全代码的操作
2020/04/13 Python
Java ExcutorService优雅关闭方式解析
2020/05/30 Python
python numpy实现rolling滚动案例
2020/06/08 Python
HTML5中使用json对象的实例代码
2018/09/10 HTML / CSS
一家专门经营包包的英国网站:MyBag
2019/09/08 全球购物
我的求职计划书
2014/01/10 职场文书
带薪年假请假条
2014/02/04 职场文书
20年同学聚会邀请函
2014/02/04 职场文书
暑假社会实践心得体会
2014/09/02 职场文书
收款委托书范本
2014/09/11 职场文书
机电专业毕业生自我鉴定2014
2014/10/04 职场文书
2014年安全保卫工作总结
2014/11/13 职场文书
2015年七一建党节活动总结
2015/03/20 职场文书
nginx里的rewrite跳转的实现
2021/03/31 Servers
CSS 还能这样玩?奇思妙想渐变的艺术
2021/04/27 HTML / CSS
Springboot/Springcloud项目集成redis进行存取的过程解析
2021/12/04 Redis
《吸血鬼:避世 血猎》官宣4.27发售 系列首款大逃杀
2022/04/03 其他游戏
使用CSS定位HTML元素的实现方法
2022/07/07 HTML / CSS