PyTorch CNN实战之MNIST手写数字识别示例


Posted in Python onMay 29, 2018

简介

卷积神经网络(Convolutional Neural Network, CNN)是深度学习技术中极具代表的网络结构之一,在图像处理领域取得了很大的成功,在国际标准的ImageNet数据集上,许多成功的模型都是基于CNN的。

卷积神经网络CNN的结构一般包含这几个层:

  1. 输入层:用于数据的输入
  2. 卷积层:使用卷积核进行特征提取和特征映射
  3. 激励层:由于卷积也是一种线性运算,因此需要增加非线性映射
  4. 池化层:进行下采样,对特征图稀疏处理,减少数据运算量。
  5. 全连接层:通常在CNN的尾部进行重新拟合,减少特征信息的损失
  6. 输出层:用于输出结果

PyTorch CNN实战之MNIST手写数字识别示例

PyTorch实战

本文选用上篇的数据集MNIST手写数字识别实践CNN。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)


class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    # 输入1通道,输出10通道,kernel 5*5
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.mp = nn.MaxPool2d(2)
    # fully connect
    self.fc = nn.Linear(320, 10)

  def forward(self, x):
    # in_size = 64
    in_size = x.size(0) # one batch
    # x: 64*10*12*12
    x = F.relu(self.mp(self.conv1(x)))
    # x: 64*20*4*4
    x = F.relu(self.mp(self.conv2(x)))
    # x: 64*320
    x = x.view(in_size, -1) # flatten the tensor
    # x: 64*10
    x = self.fc(x)
    return F.log_softmax(x)


model = Net()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def train(epoch):
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))


def test():
  test_loss = 0
  correct = 0
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target, size_average=False).data[0]
    # get the index of the max log-probability
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))


for epoch in range(1, 10):
  train(epoch)
  test()

输出结果:

Train Epoch: 1 [0/60000 (0%)]   Loss: 2.315724
Train Epoch: 1 [12800/60000 (21%)]  Loss: 1.931551
Train Epoch: 1 [25600/60000 (43%)]  Loss: 0.733935
Train Epoch: 1 [38400/60000 (64%)]  Loss: 0.165043
Train Epoch: 1 [51200/60000 (85%)]  Loss: 0.235188

Test set: Average loss: 0.1935, Accuracy: 9421/10000 (94%)

Train Epoch: 2 [0/60000 (0%)]   Loss: 0.333513
Train Epoch: 2 [12800/60000 (21%)]  Loss: 0.163156
Train Epoch: 2 [25600/60000 (43%)]  Loss: 0.213840
Train Epoch: 2 [38400/60000 (64%)]  Loss: 0.141114
Train Epoch: 2 [51200/60000 (85%)]  Loss: 0.128191

Test set: Average loss: 0.1180, Accuracy: 9645/10000 (96%)

Train Epoch: 3 [0/60000 (0%)]   Loss: 0.206469
Train Epoch: 3 [12800/60000 (21%)]  Loss: 0.234443
Train Epoch: 3 [25600/60000 (43%)]  Loss: 0.061048
Train Epoch: 3 [38400/60000 (64%)]  Loss: 0.192217
Train Epoch: 3 [51200/60000 (85%)]  Loss: 0.089190

Test set: Average loss: 0.0938, Accuracy: 9723/10000 (97%)

Train Epoch: 4 [0/60000 (0%)]   Loss: 0.086325
Train Epoch: 4 [12800/60000 (21%)]  Loss: 0.117741
Train Epoch: 4 [25600/60000 (43%)]  Loss: 0.188178
Train Epoch: 4 [38400/60000 (64%)]  Loss: 0.049807
Train Epoch: 4 [51200/60000 (85%)]  Loss: 0.174097

Test set: Average loss: 0.0743, Accuracy: 9767/10000 (98%)

Train Epoch: 5 [0/60000 (0%)]   Loss: 0.063171
Train Epoch: 5 [12800/60000 (21%)]  Loss: 0.061265
Train Epoch: 5 [25600/60000 (43%)]  Loss: 0.103549
Train Epoch: 5 [38400/60000 (64%)]  Loss: 0.019137
Train Epoch: 5 [51200/60000 (85%)]  Loss: 0.067103

Test set: Average loss: 0.0720, Accuracy: 9781/10000 (98%)

Train Epoch: 6 [0/60000 (0%)]   Loss: 0.069251
Train Epoch: 6 [12800/60000 (21%)]  Loss: 0.075502
Train Epoch: 6 [25600/60000 (43%)]  Loss: 0.052337
Train Epoch: 6 [38400/60000 (64%)]  Loss: 0.015375
Train Epoch: 6 [51200/60000 (85%)]  Loss: 0.028996

Test set: Average loss: 0.0694, Accuracy: 9783/10000 (98%)

Train Epoch: 7 [0/60000 (0%)]   Loss: 0.171613
Train Epoch: 7 [12800/60000 (21%)]  Loss: 0.078520
Train Epoch: 7 [25600/60000 (43%)]  Loss: 0.149186
Train Epoch: 7 [38400/60000 (64%)]  Loss: 0.026692
Train Epoch: 7 [51200/60000 (85%)]  Loss: 0.108824

Test set: Average loss: 0.0672, Accuracy: 9793/10000 (98%)

Train Epoch: 8 [0/60000 (0%)]   Loss: 0.029188
Train Epoch: 8 [12800/60000 (21%)]  Loss: 0.031202
Train Epoch: 8 [25600/60000 (43%)]  Loss: 0.194858
Train Epoch: 8 [38400/60000 (64%)]  Loss: 0.051497
Train Epoch: 8 [51200/60000 (85%)]  Loss: 0.024832

Test set: Average loss: 0.0535, Accuracy: 9837/10000 (98%)

Train Epoch: 9 [0/60000 (0%)]   Loss: 0.026706
Train Epoch: 9 [12800/60000 (21%)]  Loss: 0.057807
Train Epoch: 9 [25600/60000 (43%)]  Loss: 0.065225
Train Epoch: 9 [38400/60000 (64%)]  Loss: 0.037004
Train Epoch: 9 [51200/60000 (85%)]  Loss: 0.057822

Test set: Average loss: 0.0538, Accuracy: 9829/10000 (98%)

Process finished with exit code 0

参考:https://github.com/hunkim/PyTorchZeroToAll

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Python中的算数运算符的用法
May 13 Python
浅谈Python接口对json串的处理方法
Dec 19 Python
详解Python3序列赋值、序列解包
May 14 Python
计算机二级python学习教程(3) python语言基本数据类型
May 16 Python
微信小程序python用户认证的实现
Jul 29 Python
Python学习笔记之Break和Continue用法分析
Aug 14 Python
python 3.7.4 安装 opencv的教程
Oct 10 Python
django框架forms组件用法实例详解
Dec 10 Python
Python实现仿射密码的思路详解
Apr 23 Python
如何搭建pytorch环境的方法步骤
May 06 Python
实例讲解Python 迭代器与生成器
Jul 08 Python
Appium+Python实现简单的自动化登录测试的实现
Jan 26 Python
Python根据指定日期计算后n天,前n天是哪一天的方法
May 29 #Python
python 将md5转为16字节的方法
May 29 #Python
python 利用栈和队列模拟递归的过程
May 29 #Python
查看django执行的sql语句及消耗时间的两种方法
May 29 #Python
让Django支持Sql Server作后端数据库的方法
May 29 #Python
Django 浅谈根据配置生成SQL语句的问题
May 29 #Python
django表单实现下拉框的示例讲解
May 29 #Python
You might like
php控制linux服务器常用功能 关机 重启 开新站点等
2012/09/05 PHP
ecshop后台编辑器替换成ueditor编辑器
2015/03/03 PHP
Symfony2联合查询实现方法
2016/03/18 PHP
PHP最常用的正则表达式
2017/02/13 PHP
脚本合并提升javascript性能示例
2014/02/24 Javascript
javascript中的Base64、UTF8编码与解码详解
2015/03/18 Javascript
JavaScript显示表单内元素数量的方法
2015/04/02 Javascript
每天一篇javascript学习小结(面向对象编程)
2015/11/20 Javascript
AngularJS使用angular-formly进行表单验证
2015/12/27 Javascript
JS 滚动事件window.onscroll与position:fixed写兼容IE6的回到顶部组件
2016/10/10 Javascript
express文件上传中间件Multer详解
2016/10/24 Javascript
bootstrap-table实现表头固定以及列固定的方法示例
2019/03/07 Javascript
微信小程序基于movable-view实现滑动删除效果
2020/01/08 Javascript
[01:07:02]DOTA2-DPC中国联赛 正赛 iG vs PSG.LGD BO3 第三场 2月26日
2021/03/11 DOTA
使用Python求解最大公约数的实现方法
2015/08/20 Python
Tensorflow实现卷积神经网络的详细代码
2018/05/24 Python
Python Django框架单元测试之文件上传测试示例
2019/05/17 Python
python 判断字符串中是否含有汉字或非汉字的实例
2019/07/15 Python
在自动化中用python实现键盘操作的方法详解
2019/07/19 Python
python代码 FTP备份交换机配置脚本实例解析
2019/08/01 Python
python3爬虫中异步协程的用法
2020/07/10 Python
Sneaker Studio匈牙利:购买运动鞋
2018/03/26 全球购物
SISLEY希思黎官方旗舰店:享誉全球的奢华植物美容品牌
2018/04/25 全球购物
英国旅行箱包和行李箱购物网站:Travel Luggage & Cabin Bags
2019/08/26 全球购物
Ruby如何创建一个线程
2013/03/10 面试题
顺丰快递Java软件工程师面试题
2015/07/31 面试题
项目考察欢迎辞
2014/01/17 职场文书
财务部总监岗位职责
2014/03/12 职场文书
七匹狼男装广告词
2014/03/21 职场文书
加油口号大全
2014/06/13 职场文书
青年文明号口号
2014/06/17 职场文书
2015年大学学生会工作总结
2015/05/13 职场文书
英语演讲开场白
2015/05/29 职场文书
十二生肖观后感
2015/06/12 职场文书
该怎么书写道歉信?
2019/07/03 职场文书
Redis遍历所有key的两个命令(KEYS 和 SCAN)
2021/04/12 Redis