Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python获取文件后缀名及批量更新目录下文件后缀名的方法
Nov 11 Python
Python缩进和冒号详解
Jun 01 Python
在Python运行时动态查看进程内部信息的方法
Feb 22 Python
Python 网络编程之TCP客户端/服务端功能示例【基于socket套接字】
Oct 12 Python
python实现通过队列完成进程间的多任务功能示例
Oct 28 Python
使用python实现多维数据降维操作
Feb 24 Python
使用Python操作MySQL的小技巧
Sep 10 Python
python实现图像随机裁剪的示例代码
Dec 10 Python
详解python的变量缓存机制
Jan 24 Python
Python实现随机爬山算法
Jan 29 Python
用python画城市轮播地图
May 28 Python
python用tkinter开发的扫雷游戏
Jun 01 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
解析php中heredoc的使用方法
2013/06/17 PHP
PHP全局使用Laravel辅助函数dd
2019/12/26 PHP
把JS与CSS写在同一个文件里的书写方法
2007/06/02 Javascript
JavaScript 滚轮事件使用说明
2010/03/07 Javascript
JavaScript:Div层拖动效果实例代码
2013/08/06 Javascript
JS实现点击链接取消跳转效果的方法
2014/01/24 Javascript
轻松创建nodejs服务器(1):一个简单nodejs服务器例子
2014/12/18 NodeJs
利用JavaScript的AngularJS库制作电子名片的方法
2015/06/18 Javascript
JavaScript面向对象分层思维全面解析
2016/11/22 Javascript
javascript+html5+css3自定义提示窗口
2017/06/21 Javascript
详解Node中导入模块require和import的区别
2017/08/11 Javascript
基于element-ui的rules中正则表达式
2018/09/04 Javascript
百度小程序自定义通用toast组件
2019/07/17 Javascript
layer的prompt弹出框,点击回车,触发确定事件的方法
2019/09/06 Javascript
关于ckeditor在bootstrap中modal中弹框无法输入的解决方法
2019/09/11 Javascript
在博客园博文中添加自定义右键菜单的方法详解
2020/02/05 Javascript
实例分析javascript中的异步
2020/06/02 Javascript
python中关于时间和日期函数的常用计算总结(time和datatime)
2013/03/08 Python
Python中pygame安装方法图文详解
2015/11/11 Python
总结网络IO模型与select模型的Python实例讲解
2016/06/27 Python
详解python基础之while循环及if判断
2017/08/24 Python
Python subprocess模块功能与常见用法实例详解
2018/06/28 Python
Python获取命令实时输出-原样彩色输出并返回输出结果的示例
2019/07/11 Python
在python image 中安装中文字体的实现方法
2019/08/22 Python
python定间隔取点(np.linspace)的实现
2019/11/27 Python
Python之Class&Object用法详解
2019/12/25 Python
wxPython修改文本框颜色过程解析
2020/02/14 Python
opencv python 图片读取与显示图片窗口未响应问题的解决
2020/04/24 Python
浅谈HTML5 Web Worker的使用
2018/01/05 HTML / CSS
Otticanet美国:最顶尖的世界名牌眼镜, 能得到打折季的价格
2019/03/10 全球购物
双立人美国官方商店:ZWILLING集团餐具和炊具
2020/05/07 全球购物
结婚喜宴家长答谢词
2014/01/15 职场文书
学习十八大坚定理想信念心得体会
2014/03/11 职场文书
水污染治理工程专业求职信
2014/06/14 职场文书
中国古代史学名著《战国策》概述
2019/08/09 职场文书
一文搞懂Redis中String数据类型
2022/04/03 Redis