PyTorch CNN in Practice: An MNIST Handwritten Digit Recognition Example


Posted in Python on May 29, 2018

Introduction

The convolutional neural network (CNN) is one of the most representative architectures in deep learning. It has been very successful in image processing, and many of the winning models on the standard ImageNet benchmark are CNN-based.

A CNN generally contains the following layers:

  1. Input layer: feeds the data into the network
  2. Convolutional layer: extracts and maps features using convolution kernels
  3. Activation layer: since convolution is a linear operation, a non-linear mapping must be added
  4. Pooling layer: downsamples the feature maps, sparsifying them and reducing the amount of computation
  5. Fully connected layer: usually refits the features at the tail of the CNN to reduce the loss of feature information
  6. Output layer: produces the final result
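As a quick aside (not part of the original article), the spatial size after a convolution or pooling layer follows the usual formula (W - K + 2P) / S + 1. A minimal pure-Python sketch traces the shapes through the network built below:

```python
def out_size(w, k, s=1, p=0):
    # output width/height of a conv or pool layer:
    # (input - kernel + 2 * padding) // stride + 1
    return (w - k + 2 * p) // s + 1

size = out_size(28, 5)         # conv1, 5x5 kernel: 24
size = out_size(size, 2, s=2)  # 2x2 max-pool: 12
size = out_size(size, 5)       # conv2, 5x5 kernel: 8
size = out_size(size, 2, s=2)  # 2x2 max-pool: 4
print(size)  # 4 -> flattened features: 20 channels * 4 * 4 = 320
```

This is where the `320` in the fully connected layer of the network below comes from.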


PyTorch in Practice

This article applies a CNN to the MNIST handwritten-digit dataset used in the previous post.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)


class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    # input: 1 channel, output: 10 channels, 5x5 kernel
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.mp = nn.MaxPool2d(2)
    # fully connect
    self.fc = nn.Linear(320, 10)

  def forward(self, x):
    # in_size = 64
    in_size = x.size(0) # one batch
    # x: 64*10*12*12
    x = F.relu(self.mp(self.conv1(x)))
    # x: 64*20*4*4
    x = F.relu(self.mp(self.conv2(x)))
    # x: 64*320
    x = x.view(in_size, -1) # flatten the tensor
    # x: 64*10
    x = self.fc(x)
    return F.log_softmax(x, dim=1)


model = Net()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def train(epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))


def test():
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = model(data)
      # sum up batch loss
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      # get the index of the max log-probability
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))


for epoch in range(1, 10):
  train(epoch)
  test()
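A side note on the loss pairing used above: `F.nll_loss` applied to `log_softmax` output is mathematically the same as `F.cross_entropy` on raw scores, so the latter is an equivalent and common alternative. A small sketch with made-up logits and labels:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)          # fake scores: batch of 4, 10 classes
target = torch.tensor([3, 0, 9, 1])  # fake labels

# nll_loss over log_softmax equals cross_entropy over raw logits
a = F.nll_loss(F.log_softmax(logits, dim=1), target)
b = F.cross_entropy(logits, target)
print(torch.allclose(a, b))  # True
```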

Output:

Train Epoch: 1 [0/60000 (0%)]   Loss: 2.315724
Train Epoch: 1 [12800/60000 (21%)]  Loss: 1.931551
Train Epoch: 1 [25600/60000 (43%)]  Loss: 0.733935
Train Epoch: 1 [38400/60000 (64%)]  Loss: 0.165043
Train Epoch: 1 [51200/60000 (85%)]  Loss: 0.235188

Test set: Average loss: 0.1935, Accuracy: 9421/10000 (94%)

Train Epoch: 2 [0/60000 (0%)]   Loss: 0.333513
Train Epoch: 2 [12800/60000 (21%)]  Loss: 0.163156
Train Epoch: 2 [25600/60000 (43%)]  Loss: 0.213840
Train Epoch: 2 [38400/60000 (64%)]  Loss: 0.141114
Train Epoch: 2 [51200/60000 (85%)]  Loss: 0.128191

Test set: Average loss: 0.1180, Accuracy: 9645/10000 (96%)

Train Epoch: 3 [0/60000 (0%)]   Loss: 0.206469
Train Epoch: 3 [12800/60000 (21%)]  Loss: 0.234443
Train Epoch: 3 [25600/60000 (43%)]  Loss: 0.061048
Train Epoch: 3 [38400/60000 (64%)]  Loss: 0.192217
Train Epoch: 3 [51200/60000 (85%)]  Loss: 0.089190

Test set: Average loss: 0.0938, Accuracy: 9723/10000 (97%)

Train Epoch: 4 [0/60000 (0%)]   Loss: 0.086325
Train Epoch: 4 [12800/60000 (21%)]  Loss: 0.117741
Train Epoch: 4 [25600/60000 (43%)]  Loss: 0.188178
Train Epoch: 4 [38400/60000 (64%)]  Loss: 0.049807
Train Epoch: 4 [51200/60000 (85%)]  Loss: 0.174097

Test set: Average loss: 0.0743, Accuracy: 9767/10000 (98%)

Train Epoch: 5 [0/60000 (0%)]   Loss: 0.063171
Train Epoch: 5 [12800/60000 (21%)]  Loss: 0.061265
Train Epoch: 5 [25600/60000 (43%)]  Loss: 0.103549
Train Epoch: 5 [38400/60000 (64%)]  Loss: 0.019137
Train Epoch: 5 [51200/60000 (85%)]  Loss: 0.067103

Test set: Average loss: 0.0720, Accuracy: 9781/10000 (98%)

Train Epoch: 6 [0/60000 (0%)]   Loss: 0.069251
Train Epoch: 6 [12800/60000 (21%)]  Loss: 0.075502
Train Epoch: 6 [25600/60000 (43%)]  Loss: 0.052337
Train Epoch: 6 [38400/60000 (64%)]  Loss: 0.015375
Train Epoch: 6 [51200/60000 (85%)]  Loss: 0.028996

Test set: Average loss: 0.0694, Accuracy: 9783/10000 (98%)

Train Epoch: 7 [0/60000 (0%)]   Loss: 0.171613
Train Epoch: 7 [12800/60000 (21%)]  Loss: 0.078520
Train Epoch: 7 [25600/60000 (43%)]  Loss: 0.149186
Train Epoch: 7 [38400/60000 (64%)]  Loss: 0.026692
Train Epoch: 7 [51200/60000 (85%)]  Loss: 0.108824

Test set: Average loss: 0.0672, Accuracy: 9793/10000 (98%)

Train Epoch: 8 [0/60000 (0%)]   Loss: 0.029188
Train Epoch: 8 [12800/60000 (21%)]  Loss: 0.031202
Train Epoch: 8 [25600/60000 (43%)]  Loss: 0.194858
Train Epoch: 8 [38400/60000 (64%)]  Loss: 0.051497
Train Epoch: 8 [51200/60000 (85%)]  Loss: 0.024832

Test set: Average loss: 0.0535, Accuracy: 9837/10000 (98%)

Train Epoch: 9 [0/60000 (0%)]   Loss: 0.026706
Train Epoch: 9 [12800/60000 (21%)]  Loss: 0.057807
Train Epoch: 9 [25600/60000 (43%)]  Loss: 0.065225
Train Epoch: 9 [38400/60000 (64%)]  Loss: 0.037004
Train Epoch: 9 [51200/60000 (85%)]  Loss: 0.057822

Test set: Average loss: 0.0538, Accuracy: 9829/10000 (98%)

Process finished with exit code 0
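The accuracy bookkeeping in test() is easy to check on a toy tensor (the numbers here are made up for illustration):

```python
import torch

output = torch.tensor([[0.1, 0.9],
                       [0.8, 0.2],
                       [0.3, 0.7]])   # scores for 3 samples, 2 classes
target = torch.tensor([1, 0, 0])

pred = output.max(1, keepdim=True)[1]                 # predicted class per row
correct = pred.eq(target.view_as(pred)).sum().item()  # count matches
print(correct)  # 2 of 3 predictions match the targets
```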

Reference: https://github.com/hunkim/PyTorchZeroToAll

That's all for this article. I hope it is helpful for your learning, and thank you for supporting 三水点靠木.
