Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python 返回汉字的汉语拼音
Feb 27 Python
Python使用urllib2获取网络资源实例讲解
Dec 02 Python
pycharm 使用心得(八)如何调用另一文件中的函数
Jun 06 Python
深入讨论Python函数的参数的默认值所引发的问题的原因
Mar 30 Python
python实现在windows服务中新建进程的方法
Jun 30 Python
利用 Monkey 命令操作屏幕快速滑动
Dec 07 Python
放弃 Python 转向 Go语言有人给出了 9 大理由
Oct 20 Python
Python 字典一个键对应多个值的方法
Sep 29 Python
使用Python获取爱奇艺电视剧弹幕数据的示例代码
Jan 12 Python
详解python的xlwings库读写excel操作总结
Feb 26 Python
Python超简单容易上手的画图工具库推荐
May 10 Python
利用Pycharm连接服务器的全过程记录
Jul 01 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
PHP小技巧之JS和CSS优化工具Minify的使用方法
2014/05/19 PHP
PHP往XML中添加节点的方法
2015/03/12 PHP
php以fastCGI的方式运行时文件系统权限问题及解决方法
2015/05/11 PHP
php实现将base64格式图片保存在指定目录的方法
2016/10/13 PHP
浅谈php://filter的妙用
2019/03/05 PHP
PHP针对redis常用操作实例详解
2019/08/17 PHP
laravel实现图片上传预览,及编辑时可更换图片,并实时变化的例子
2019/11/14 PHP
jquery实现简单的拖拽效果实例兼容所有主流浏览器(优化篇)
2013/06/28 Javascript
JS随机生成不重复数据的实例方法
2013/07/17 Javascript
JS测试显示屏分辨率以及屏幕尺寸的方法
2013/11/22 Javascript
jQuery动态修改超链接地址的方法
2015/02/13 Javascript
javascript模块化简单解析
2016/04/07 Javascript
拥Bootstrap入怀——导航栏篇
2016/05/30 Javascript
JS采用绝对定位实现回到顶部效果完整实例
2016/06/20 Javascript
AngularJS入门教程之静态模板详解
2016/08/18 Javascript
JavaScript编码风格指南(中文版)
2016/08/26 Javascript
浅谈jQuery操作类数组的工具方法
2016/12/23 Javascript
Vue + element 实现多选框组并保存已选id集合的示例代码
2020/06/03 Javascript
[54:43]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第一场 2月22日
2021/03/11 DOTA
跟老齐学Python之深入变量和引用对象
2014/09/24 Python
用Python编写生成树状结构的文件目录的脚本的教程
2015/05/04 Python
对Python中range()函数和list的比较
2018/04/19 Python
Python单向链表和双向链表原理与用法实例详解
2018/08/31 Python
python实现双色球随机选号
2020/01/01 Python
JupyterNotebook 输出窗口的显示效果调整实现
2020/09/22 Python
matplotlib阶梯图的实现(step())
2021/03/02 Python
HTML5 CSS3打造相册效果附源码下载
2014/06/16 HTML / CSS
Kathmandu英国网站:新西兰户外运动品牌
2017/03/27 全球购物
大学活动策划书范文
2014/01/10 职场文书
请假条标准格式规范
2014/04/10 职场文书
《少年王冕》教学反思
2014/04/11 职场文书
成绩单家长评语大全
2014/04/16 职场文书
反腐倡廉剖析材料
2014/09/30 职场文书
检察院起诉书
2015/05/20 职场文书
Python中如何处理常见报错
2022/01/18 Python
pandas中关于apply+lambda的应用
2022/02/28 Python