Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
对Python定时任务的启动和停止方法详解
Feb 19 Python
Python3.5装饰器原理及应用实例详解
Apr 30 Python
Python实现Linux监控的方法
May 16 Python
Python操作SQLite数据库过程解析
Sep 02 Python
Django之使用celery和NGINX生成静态页面实现性能优化
Oct 08 Python
Python request操作步骤及代码实例
Apr 13 Python
VScode连接远程服务器上的jupyter notebook的实现
Apr 23 Python
使用已经得到的keras模型识别自己手写的数字方式
Jun 29 Python
利于python脚本编写可视化nmap和masscan的方法
Dec 29 Python
python 视频下载神器(you-get)的具体使用
Jan 06 Python
pytorch损失反向传播后梯度为none的问题
May 12 Python
Pytorch 实现变量类型转换
May 17 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
判断PHP数组是否为空的代码
2011/09/08 PHP
基于php使用memcache存储session的详解
2013/06/25 PHP
一个简单的PHP验证码实现代码
2014/05/10 PHP
PHP处理CSV表格文件的常用操作方法总结
2016/07/01 PHP
PHP7创建销毁session的实例方法
2020/02/03 PHP
JavaScript如何从listbox里同时删除多个项目
2013/10/12 Javascript
jquery中map函数与each函数的区别实例介绍
2014/06/23 Javascript
jQuery CSS()方法改变现有的CSS样式
2014/08/20 Javascript
php,js,css字符串截取的办法集锦
2014/09/26 Javascript
浅谈window对象的scrollBy()方法
2015/07/15 Javascript
jQuery横向擦除焦点图特效代码分享
2015/09/06 Javascript
微信小程序 后台https域名绑定和免费的https证书申请详解
2016/11/10 Javascript
ajax的分页查询示例(不刷新页面)
2017/01/11 Javascript
javascript完美实现给定日期返回上月日期的方法
2017/06/15 Javascript
基于JavaScript实现多级菜单效果
2017/07/25 Javascript
Angular4开发解决跨域问题详解
2017/08/28 Javascript
node.js微信小程序配置消息推送的实现
2019/02/13 Javascript
jQuery AJAX与jQuery事件的分析讲解
2019/02/18 jQuery
微信小程序 wx.getUserInfo引导用户授权问题实例分析
2020/03/09 Javascript
微信小程序input抖动问题的修复方法
2021/03/03 Javascript
python计算两个矩形框重合百分比的实例
2018/11/07 Python
python利用插值法对折线进行平滑曲线处理
2018/12/25 Python
Pyqt5自适应布局实例
2019/12/13 Python
Django Serializer HiddenField隐藏字段实例
2020/03/31 Python
社区版pycharm创建django项目的方法(pycharm的newproject左侧没有项目选项)
2020/09/23 Python
canvas像素画板的实现代码
2018/11/21 HTML / CSS
计算机本科生自荐信
2013/10/15 职场文书
婚礼答谢宴主持词
2014/03/14 职场文书
租车协议书范本2014
2014/11/17 职场文书
解除同居协议书
2015/01/29 职场文书
技术员岗位职责
2015/02/04 职场文书
创业计划书之香辣虾火锅
2019/09/23 职场文书
如何判断微信付款码和支付宝付款码
2021/04/01 PHP
go 原生http web 服务跨域restful api的写法介绍
2021/04/27 Golang
Python基本数据类型之字符串str
2021/07/21 Python
Golang 1.18 多模块Multi-Module工作区模式的新特性
2022/04/11 Golang