PyTorch CNN实战之MNIST手写数字识别示例


Posted in Python onMay 29, 2018

简介

卷积神经网络(Convolutional Neural Network, CNN)是深度学习技术中极具代表的网络结构之一,在图像处理领域取得了很大的成功,在国际标准的ImageNet数据集上,许多成功的模型都是基于CNN的。

卷积神经网络CNN的结构一般包含这几个层:

  1. 输入层:用于数据的输入
  2. 卷积层:使用卷积核进行特征提取和特征映射
  3. 激励层:由于卷积也是一种线性运算,因此需要增加非线性映射
  4. 池化层:进行下采样,对特征图稀疏处理,减少数据运算量。
  5. 全连接层:通常在CNN的尾部进行重新拟合,减少特征信息的损失
  6. 输出层:用于输出结果

PyTorch CNN实战之MNIST手写数字识别示例

PyTorch实战

本文选用上篇的数据集MNIST手写数字识别实践CNN。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)


class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    # 输入1通道,输出10通道,kernel 5*5
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.mp = nn.MaxPool2d(2)
    # fully connect
    self.fc = nn.Linear(320, 10)

  def forward(self, x):
    # in_size = 64
    in_size = x.size(0) # one batch
    # x: 64*10*12*12
    x = F.relu(self.mp(self.conv1(x)))
    # x: 64*20*4*4
    x = F.relu(self.mp(self.conv2(x)))
    # x: 64*320
    x = x.view(in_size, -1) # flatten the tensor
    # x: 64*10
    x = self.fc(x)
    return F.log_softmax(x)


model = Net()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def train(epoch):
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))


def test():
  test_loss = 0
  correct = 0
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target, size_average=False).data[0]
    # get the index of the max log-probability
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))


for epoch in range(1, 10):
  train(epoch)
  test()

输出结果:

Train Epoch: 1 [0/60000 (0%)]   Loss: 2.315724
Train Epoch: 1 [12800/60000 (21%)]  Loss: 1.931551
Train Epoch: 1 [25600/60000 (43%)]  Loss: 0.733935
Train Epoch: 1 [38400/60000 (64%)]  Loss: 0.165043
Train Epoch: 1 [51200/60000 (85%)]  Loss: 0.235188

Test set: Average loss: 0.1935, Accuracy: 9421/10000 (94%)

Train Epoch: 2 [0/60000 (0%)]   Loss: 0.333513
Train Epoch: 2 [12800/60000 (21%)]  Loss: 0.163156
Train Epoch: 2 [25600/60000 (43%)]  Loss: 0.213840
Train Epoch: 2 [38400/60000 (64%)]  Loss: 0.141114
Train Epoch: 2 [51200/60000 (85%)]  Loss: 0.128191

Test set: Average loss: 0.1180, Accuracy: 9645/10000 (96%)

Train Epoch: 3 [0/60000 (0%)]   Loss: 0.206469
Train Epoch: 3 [12800/60000 (21%)]  Loss: 0.234443
Train Epoch: 3 [25600/60000 (43%)]  Loss: 0.061048
Train Epoch: 3 [38400/60000 (64%)]  Loss: 0.192217
Train Epoch: 3 [51200/60000 (85%)]  Loss: 0.089190

Test set: Average loss: 0.0938, Accuracy: 9723/10000 (97%)

Train Epoch: 4 [0/60000 (0%)]   Loss: 0.086325
Train Epoch: 4 [12800/60000 (21%)]  Loss: 0.117741
Train Epoch: 4 [25600/60000 (43%)]  Loss: 0.188178
Train Epoch: 4 [38400/60000 (64%)]  Loss: 0.049807
Train Epoch: 4 [51200/60000 (85%)]  Loss: 0.174097

Test set: Average loss: 0.0743, Accuracy: 9767/10000 (98%)

Train Epoch: 5 [0/60000 (0%)]   Loss: 0.063171
Train Epoch: 5 [12800/60000 (21%)]  Loss: 0.061265
Train Epoch: 5 [25600/60000 (43%)]  Loss: 0.103549
Train Epoch: 5 [38400/60000 (64%)]  Loss: 0.019137
Train Epoch: 5 [51200/60000 (85%)]  Loss: 0.067103

Test set: Average loss: 0.0720, Accuracy: 9781/10000 (98%)

Train Epoch: 6 [0/60000 (0%)]   Loss: 0.069251
Train Epoch: 6 [12800/60000 (21%)]  Loss: 0.075502
Train Epoch: 6 [25600/60000 (43%)]  Loss: 0.052337
Train Epoch: 6 [38400/60000 (64%)]  Loss: 0.015375
Train Epoch: 6 [51200/60000 (85%)]  Loss: 0.028996

Test set: Average loss: 0.0694, Accuracy: 9783/10000 (98%)

Train Epoch: 7 [0/60000 (0%)]   Loss: 0.171613
Train Epoch: 7 [12800/60000 (21%)]  Loss: 0.078520
Train Epoch: 7 [25600/60000 (43%)]  Loss: 0.149186
Train Epoch: 7 [38400/60000 (64%)]  Loss: 0.026692
Train Epoch: 7 [51200/60000 (85%)]  Loss: 0.108824

Test set: Average loss: 0.0672, Accuracy: 9793/10000 (98%)

Train Epoch: 8 [0/60000 (0%)]   Loss: 0.029188
Train Epoch: 8 [12800/60000 (21%)]  Loss: 0.031202
Train Epoch: 8 [25600/60000 (43%)]  Loss: 0.194858
Train Epoch: 8 [38400/60000 (64%)]  Loss: 0.051497
Train Epoch: 8 [51200/60000 (85%)]  Loss: 0.024832

Test set: Average loss: 0.0535, Accuracy: 9837/10000 (98%)

Train Epoch: 9 [0/60000 (0%)]   Loss: 0.026706
Train Epoch: 9 [12800/60000 (21%)]  Loss: 0.057807
Train Epoch: 9 [25600/60000 (43%)]  Loss: 0.065225
Train Epoch: 9 [38400/60000 (64%)]  Loss: 0.037004
Train Epoch: 9 [51200/60000 (85%)]  Loss: 0.057822

Test set: Average loss: 0.0538, Accuracy: 9829/10000 (98%)

Process finished with exit code 0

参考:https://github.com/hunkim/PyTorchZeroToAll

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python发送邮件附件以定时备份MySQL的教程
Apr 25 Python
Python编程中的文件读写及相关的文件对象方法讲解
Jan 19 Python
Python合并字典键值并去除重复元素的实例
Dec 18 Python
Python实现二维数组按照某行或列排序的方法【numpy lexsort】
Sep 22 Python
pycharm下打开、执行并调试scrapy爬虫程序的方法
Nov 29 Python
Python地图绘制实操详解
Mar 04 Python
python实现飞机大战小游戏
Nov 08 Python
浅谈python 中的 type(), dtype(), astype()的区别
Apr 09 Python
浅谈numpy中np.array()与np.asarray的区别以及.tolist
Jun 03 Python
python实现批量转换图片为黑白
Jun 16 Python
Python使用shutil模块实现文件拷贝
Jul 31 Python
利用python为PostgreSQL的表自动添加分区
Jan 18 Python
Python根据指定日期计算后n天,前n天是哪一天的方法
May 29 #Python
python 将md5转为16字节的方法
May 29 #Python
python 利用栈和队列模拟递归的过程
May 29 #Python
查看django执行的sql语句及消耗时间的两种方法
May 29 #Python
让Django支持Sql Server作后端数据库的方法
May 29 #Python
Django 浅谈根据配置生成SQL语句的问题
May 29 #Python
django表单实现下拉框的示例讲解
May 29 #Python
You might like
PHP mkdir()定义和用法
2009/01/14 PHP
无法载入 mcrypt 扩展,请检查 PHP 配置终极解决方案
2011/07/18 PHP
在PHP中利用wsdl创建标准webservice的实现代码
2011/12/07 PHP
PHPMailer邮件发送的实现代码
2013/05/04 PHP
PHP后台备份MySQL数据库的源码实例
2019/03/18 PHP
php设计模式之适配器模式原理、用法及注意事项详解
2019/09/24 PHP
javascript 操作文件 实现方法小结
2009/07/02 Javascript
JavaScript toFixed() 方法
2010/04/15 Javascript
返回页面顶部top按钮通过锚点实现(自写)
2013/08/30 Javascript
按钮接受回车事件的三种实现方法
2014/06/06 Javascript
JavaScript定时器和优化的取消定时器方法
2015/07/03 Javascript
微信小程序 使用canvas制作K线实例详解
2017/01/12 Javascript
angular2 ng2 @input和@output理解及示例
2017/10/10 Javascript
Vue2.0中集成UEditor富文本编辑器的方法
2018/03/03 Javascript
三种Webpack打包方式(小结)
2018/09/19 Javascript
angularJS1 url中携带参数的获取方法
2018/10/09 Javascript
从组件封装看Vue的作用域插槽的实现
2019/02/12 Javascript
迅速了解一下ES10中Object.fromEntries的用法使用
2019/03/05 Javascript
Node.js API详解之 zlib模块用法分析
2020/05/19 Javascript
uniapp微信小程序:key失效的解决方法
2021/01/20 Javascript
Pthon批量处理将pdb文件生成dssp文件
2015/06/21 Python
Python处理JSON数据并生成条形图
2016/08/05 Python
详解Python装饰器由浅入深
2016/12/09 Python
50行Python代码实现人脸检测功能
2018/01/23 Python
Python实现PS图像抽象画风效果的方法
2018/01/23 Python
python画图时设置分辨率和画布大小的实现(plt.figure())
2021/01/08 Python
将不规则的Python多维数组拉平到一维的方法实现
2021/01/11 Python
利用CSS3的transition属性实现滑动效果
2015/08/05 HTML / CSS
HTML5录音实践总结(Preact)
2020/05/07 HTML / CSS
施华洛世奇波兰官网:SWAROVSKI波兰
2019/06/18 全球购物
公司新员工的演讲稿注意事项
2014/01/01 职场文书
素质拓展感言
2014/01/29 职场文书
网络编辑求职信
2014/04/30 职场文书
软件工程毕业生自荐信
2014/07/04 职场文书
县人大领导班子四风对照检查材料思想汇报
2014/10/09 职场文书
招标保密承诺书
2015/01/20 职场文书