PyTorch CNN实战之MNIST手写数字识别示例


Posted in Python onMay 29, 2018

简介

卷积神经网络(Convolutional Neural Network, CNN)是深度学习技术中极具代表的网络结构之一,在图像处理领域取得了很大的成功,在国际标准的ImageNet数据集上,许多成功的模型都是基于CNN的。

卷积神经网络CNN的结构一般包含这几个层:

  1. 输入层:用于数据的输入
  2. 卷积层:使用卷积核进行特征提取和特征映射
  3. 激励层:由于卷积也是一种线性运算,因此需要增加非线性映射
  4. 池化层:进行下采样,对特征图稀疏处理,减少数据运算量。
  5. 全连接层:通常在CNN的尾部进行重新拟合,减少特征信息的损失
  6. 输出层:用于输出结果

PyTorch CNN实战之MNIST手写数字识别示例

PyTorch实战

本文选用上篇的数据集MNIST手写数字识别实践CNN。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)


class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    # 输入1通道,输出10通道,kernel 5*5
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.mp = nn.MaxPool2d(2)
    # fully connect
    self.fc = nn.Linear(320, 10)

  def forward(self, x):
    # in_size = 64
    in_size = x.size(0) # one batch
    # x: 64*10*12*12
    x = F.relu(self.mp(self.conv1(x)))
    # x: 64*20*4*4
    x = F.relu(self.mp(self.conv2(x)))
    # x: 64*320
    x = x.view(in_size, -1) # flatten the tensor
    # x: 64*10
    x = self.fc(x)
    return F.log_softmax(x)


model = Net()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def train(epoch):
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))


def test():
  test_loss = 0
  correct = 0
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target, size_average=False).data[0]
    # get the index of the max log-probability
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))


for epoch in range(1, 10):
  train(epoch)
  test()

输出结果:

Train Epoch: 1 [0/60000 (0%)]   Loss: 2.315724
Train Epoch: 1 [12800/60000 (21%)]  Loss: 1.931551
Train Epoch: 1 [25600/60000 (43%)]  Loss: 0.733935
Train Epoch: 1 [38400/60000 (64%)]  Loss: 0.165043
Train Epoch: 1 [51200/60000 (85%)]  Loss: 0.235188

Test set: Average loss: 0.1935, Accuracy: 9421/10000 (94%)

Train Epoch: 2 [0/60000 (0%)]   Loss: 0.333513
Train Epoch: 2 [12800/60000 (21%)]  Loss: 0.163156
Train Epoch: 2 [25600/60000 (43%)]  Loss: 0.213840
Train Epoch: 2 [38400/60000 (64%)]  Loss: 0.141114
Train Epoch: 2 [51200/60000 (85%)]  Loss: 0.128191

Test set: Average loss: 0.1180, Accuracy: 9645/10000 (96%)

Train Epoch: 3 [0/60000 (0%)]   Loss: 0.206469
Train Epoch: 3 [12800/60000 (21%)]  Loss: 0.234443
Train Epoch: 3 [25600/60000 (43%)]  Loss: 0.061048
Train Epoch: 3 [38400/60000 (64%)]  Loss: 0.192217
Train Epoch: 3 [51200/60000 (85%)]  Loss: 0.089190

Test set: Average loss: 0.0938, Accuracy: 9723/10000 (97%)

Train Epoch: 4 [0/60000 (0%)]   Loss: 0.086325
Train Epoch: 4 [12800/60000 (21%)]  Loss: 0.117741
Train Epoch: 4 [25600/60000 (43%)]  Loss: 0.188178
Train Epoch: 4 [38400/60000 (64%)]  Loss: 0.049807
Train Epoch: 4 [51200/60000 (85%)]  Loss: 0.174097

Test set: Average loss: 0.0743, Accuracy: 9767/10000 (98%)

Train Epoch: 5 [0/60000 (0%)]   Loss: 0.063171
Train Epoch: 5 [12800/60000 (21%)]  Loss: 0.061265
Train Epoch: 5 [25600/60000 (43%)]  Loss: 0.103549
Train Epoch: 5 [38400/60000 (64%)]  Loss: 0.019137
Train Epoch: 5 [51200/60000 (85%)]  Loss: 0.067103

Test set: Average loss: 0.0720, Accuracy: 9781/10000 (98%)

Train Epoch: 6 [0/60000 (0%)]   Loss: 0.069251
Train Epoch: 6 [12800/60000 (21%)]  Loss: 0.075502
Train Epoch: 6 [25600/60000 (43%)]  Loss: 0.052337
Train Epoch: 6 [38400/60000 (64%)]  Loss: 0.015375
Train Epoch: 6 [51200/60000 (85%)]  Loss: 0.028996

Test set: Average loss: 0.0694, Accuracy: 9783/10000 (98%)

Train Epoch: 7 [0/60000 (0%)]   Loss: 0.171613
Train Epoch: 7 [12800/60000 (21%)]  Loss: 0.078520
Train Epoch: 7 [25600/60000 (43%)]  Loss: 0.149186
Train Epoch: 7 [38400/60000 (64%)]  Loss: 0.026692
Train Epoch: 7 [51200/60000 (85%)]  Loss: 0.108824

Test set: Average loss: 0.0672, Accuracy: 9793/10000 (98%)

Train Epoch: 8 [0/60000 (0%)]   Loss: 0.029188
Train Epoch: 8 [12800/60000 (21%)]  Loss: 0.031202
Train Epoch: 8 [25600/60000 (43%)]  Loss: 0.194858
Train Epoch: 8 [38400/60000 (64%)]  Loss: 0.051497
Train Epoch: 8 [51200/60000 (85%)]  Loss: 0.024832

Test set: Average loss: 0.0535, Accuracy: 9837/10000 (98%)

Train Epoch: 9 [0/60000 (0%)]   Loss: 0.026706
Train Epoch: 9 [12800/60000 (21%)]  Loss: 0.057807
Train Epoch: 9 [25600/60000 (43%)]  Loss: 0.065225
Train Epoch: 9 [38400/60000 (64%)]  Loss: 0.037004
Train Epoch: 9 [51200/60000 (85%)]  Loss: 0.057822

Test set: Average loss: 0.0538, Accuracy: 9829/10000 (98%)

Process finished with exit code 0

参考:https://github.com/hunkim/PyTorchZeroToAll

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python调用windows api锁定计算机示例
Apr 17 Python
python写入中英文字符串到文件的方法
May 06 Python
apache部署python程序出现503错误的解决方法
Jul 24 Python
pandas数据处理基础之筛选指定行或者指定列的数据
May 03 Python
python计算列表内各元素的个数实例
Jun 29 Python
Python实现定时自动关闭的tkinter窗口方法
Feb 16 Python
解决python中用matplotlib画多幅图时出现图形部分重叠的问题
Jul 07 Python
在pycharm下设置自己的个性模版方法
Jul 15 Python
python ftplib模块使用代码实例
Dec 31 Python
Python如何实现的二分查找算法
May 27 Python
Python使用matplotlib绘制圆形代码实例
May 27 Python
Python 恐龙跑跑小游戏实现流程
Feb 15 Python
Python根据指定日期计算后n天,前n天是哪一天的方法
May 29 #Python
python 将md5转为16字节的方法
May 29 #Python
python 利用栈和队列模拟递归的过程
May 29 #Python
查看django执行的sql语句及消耗时间的两种方法
May 29 #Python
让Django支持Sql Server作后端数据库的方法
May 29 #Python
Django 浅谈根据配置生成SQL语句的问题
May 29 #Python
django表单实现下拉框的示例讲解
May 29 #Python
You might like
Could not load type System.ServiceModel.Activation.HttpModule解决办法
2012/12/29 PHP
ThinkPHP连接数据库的方式汇总
2014/12/05 PHP
几个优化WordPress中JavaScript加载体验的插件介绍
2015/12/17 PHP
Laravel5.1 框架控制器基础用法实例分析
2020/01/04 PHP
js中switch case循环实例代码
2013/12/30 Javascript
JS小游戏之象棋暗棋源码详解
2014/09/25 Javascript
javascript实现checkbox复选框实例代码
2016/01/10 Javascript
EasyUI的doCellTip实现鼠标放到单元格上提示单元格内容
2016/08/24 Javascript
React Native实现简单的登录功能(推荐)
2016/09/19 Javascript
使用JavaScript实现表格编辑器(实例讲解)
2017/08/02 Javascript
使用bootstraptable插件实现表格记录的查询、分页、排序操作
2017/08/06 Javascript
vue 中directive功能的简单实现
2018/01/05 Javascript
Vue.js中的computed工作原理
2018/03/22 Javascript
详解小程序如何避免多次点击,重复触发事件
2019/04/08 Javascript
[08:54]DOTA2-DPC中国联赛 正赛 Aster vs LBZS 选手采访
2021/03/11 DOTA
Python实现二叉树结构与进行二叉树遍历的方法详解
2016/05/24 Python
深入理解NumPy简明教程---数组2
2016/12/17 Python
对python过滤器和lambda函数的用法详解
2019/01/21 Python
python_array[0][0]与array[0,0]的区别详解
2020/02/18 Python
Keras 在fit_generator训练方式中加入图像random_crop操作
2020/07/03 Python
html5 web本地存储将取代我们的cookie
2012/12/26 HTML / CSS
HTML5添加禁止缩放功能
2017/11/03 HTML / CSS
HTML5 解决苹果手机不能自动播放音乐问题
2017/12/27 HTML / CSS
美国百货齐全的精品网站,提供美式风格的产品:Overstock.com
2016/07/22 全球购物
应届毕业生就业自荐信
2013/10/26 职场文书
《小石潭记》教学反思
2014/02/13 职场文书
2014年迎新年活动方案
2014/02/19 职场文书
淘宝好评语大全
2014/05/05 职场文书
机关职员工作检讨书
2014/10/23 职场文书
投标承诺函格式
2015/01/21 职场文书
幼师自荐信范文
2015/03/06 职场文书
拔河比赛新闻稿
2015/07/17 职场文书
健康教育主题班会
2015/08/14 职场文书
python 办公自动化——基于pyqt5和openpyxl统计符合要求的名单
2021/05/25 Python
排查并解决MySQL生产库内存使用率高的报警
2022/04/11 MySQL
python模板入门教程之flask Jinja
2022/04/11 Python