PyTorch上实现卷积神经网络CNN的方法


Posted in Python onApril 28, 2018

一、卷积神经网络

卷积神经网络(ConvolutionalNeuralNetwork,CNN)最初是为解决图像识别等问题设计的,CNN现在的应用已经不限于图像和视频,也可用于时间序列信号,比如音频信号和文本数据等。CNN作为一个深度学习架构被提出的最初诉求是降低对图像数据预处理的要求,避免复杂的特征工程。在卷积神经网络中,第一个卷积层会直接接受图像像素级的输入,每一层卷积(滤波器)都会提取数据中最有效的特征,这种方法可以提取到图像中最基础的特征,而后再进行组合和抽象形成更高阶的特征,因此CNN在理论上具有对图像缩放、平移和旋转的不变性。

卷积神经网络CNN的要点就是局部连接(LocalConnection)、权值共享(WeightsSharing)和池化层(Pooling)中的降采样(Down-Sampling)。其中,局部连接和权值共享降低了参数量,使训练复杂度大大下降并减轻了过拟合。同时权值共享还赋予了卷积网络对平移的容忍性,池化层降采样则进一步降低了输出参数量并赋予模型对轻度形变的容忍性,提高了模型的泛化能力。可以把卷积层卷积操作理解为用少量参数在图像的多个位置上提取相似特征的过程。

二、代码实现

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = True 
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
plt.title('%i' % training_data.train_labels[0]) 
plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) 
# 取前2000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), 
         volatile=True).type(torch.FloatTensor)[:2000]/255 
# (2000, 28, 28) to (2000, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels[:2000] 
 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
print(cnn) 
''''' 
CNN ( 
 (conv1): Sequential ( 
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (conv2): Sequential ( 
  (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (out): Linear (1568 -> 10) 
) 
''' 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 
for epoch in range(EPOCH): 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x) 
    b_y = Variable(y) 
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      test_output = cnn(test_x) 
      pred_y = torch.max(test_output, 1)[1].data.squeeze() 
      accuracy = sum(pred_y == test_y) / test_y.size(0) 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0], '|test accuracy:%.4f'%accuracy) 
 
test_output = cnn(test_x[:10]) 
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() 
print(pred_y, 'prediction number') 
print(test_y[:10].numpy(), 'real number') 
''''' 
Epoch: 0 |Step: 0 |train loss:2.3145 |test accuracy:0.1040 
Epoch: 0 |Step: 100 |train loss:0.5857 |test accuracy:0.8865 
Epoch: 0 |Step: 200 |train loss:0.0600 |test accuracy:0.9380 
Epoch: 0 |Step: 300 |train loss:0.0996 |test accuracy:0.9345 
Epoch: 0 |Step: 400 |train loss:0.0381 |test accuracy:0.9645 
Epoch: 0 |Step: 500 |train loss:0.0266 |test accuracy:0.9620 
Epoch: 0 |Step: 600 |train loss:0.0973 |test accuracy:0.9685 
Epoch: 0 |Step: 700 |train loss:0.0421 |test accuracy:0.9725 
Epoch: 0 |Step: 800 |train loss:0.0654 |test accuracy:0.9710 
Epoch: 0 |Step: 900 |train loss:0.1333 |test accuracy:0.9740 
Epoch: 0 |Step: 1000 |train loss:0.0289 |test accuracy:0.9720 
Epoch: 0 |Step: 1100 |train loss:0.0429 |test accuracy:0.9770 
[7 2 1 0 4 1 4 9 5 9] prediction number 
[7 2 1 0 4 1 4 9 5 9] real number 
'''

 三、分析解读

通过利用torchvision.datasets可以快速获取可以直接置于DataLoader中的dataset格式的数据,通过train参数控制是获取训练数据集还是测试数据集,也可以在获取的时候便直接转换成训练所需的数据格式。

卷积神经网络的搭建通过定义一个CNN类来实现,卷积层conv1,conv2及out层以类属性的形式定义,各层之间的衔接信息在forward中定义,定义的时候要留意各层的神经元数量。

CNN的网络结构如下:

CNN (

 (conv1): Sequential (

  (0): Conv2d(1, 16,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))

  (1): ReLU ()

  (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1))

 )

 (conv2): Sequential (

  (0): Conv2d(16, 32,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))

  (1): ReLU ()

  (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1))

 )

 (out): Linear (1568 ->10)

)

经过实验可见,在EPOCH=1的训练结果中,测试集准确率可达到97.7%。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的pprint折腾记
Jan 21 Python
黑科技 Python脚本帮你找出微信上删除你好友的人
Jan 07 Python
Django小白教程之Django用户注册与登录
Apr 22 Python
python中json格式数据输出的简单实现方法
Oct 31 Python
Python 在字符串中加入变量的实例讲解
May 02 Python
python3.6.3+opencv3.3.0实现动态人脸捕获
May 25 Python
Pyqt5 关于流式布局和滚动条的综合使用示例代码
Mar 24 Python
Python如何实现爬取B站视频
May 20 Python
python查看矩阵的行列号以及维数方式
May 22 Python
Python unittest生成测试报告过程解析
Sep 08 Python
python 利用openpyxl读取Excel表格中指定的行或列教程
Feb 06 Python
Python如何使用神经网络进行简单文本分类
Feb 25 Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
python实现log日志的示例代码
Apr 28 #Python
Python学习笔记之open()函数打开文件路径报错问题
Apr 28 #Python
Python之读取TXT文件的方法小结
Apr 27 #Python
如何利用python查找电脑文件
Apr 27 #Python
Python3 中把txt数据文件读入到矩阵中的方法
Apr 27 #Python
You might like
PHP中实现图片的锐化
2006/10/09 PHP
使用HMAC-SHA1签名方法详解
2013/06/26 PHP
PHP简单字符串过滤方法示例
2016/09/04 PHP
浅谈PHP错误类型及屏蔽方法
2017/05/27 PHP
js 绑定带参数的事件以及手动触发事件
2010/04/27 Javascript
css样式标签和js语法属性区别
2013/11/06 Javascript
将查询条件的input、select清空
2014/01/14 Javascript
取得元素的左和上偏移量的方法
2014/09/17 Javascript
js实现贪吃蛇小游戏(容易理解)
2017/01/22 Javascript
js数字计算 误差问题的快速解决方法
2017/02/28 Javascript
JavaScript实现微信红包算法及问题解决方法
2018/04/26 Javascript
vue获取元素宽、高、距离左边距离,右,上距离等还有XY坐标轴的方法
2018/09/05 Javascript
微信小程序用户授权,以及判断登录是否过期的方法
2019/05/10 Javascript
vue-cli webpack配置文件分析
2019/05/20 Javascript
微信小程序登录对接Django后端实现JWT方式验证登录详解
2019/07/29 Javascript
[37:02]OG vs INfamous 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
Python实现类继承实例
2014/07/04 Python
python实现旋转和水平翻转的方法
2018/10/25 Python
Python从数据库读取大量数据批量写入文件的方法
2018/12/10 Python
python 计算一个字符串中所有数字的和实例
2019/06/11 Python
python GUI库图形界面开发之PyQt5拖放控件实例详解
2020/02/25 Python
Django User 模块之 AbstractUser 扩展详解
2020/03/11 Python
关于python3.9安装wordcloud出错的问题及解决办法
2020/11/02 Python
python实现xml转json文件的示例代码
2020/12/30 Python
纯css3实现走马灯效果
2014/12/26 HTML / CSS
html5本地存储_动力节点Java学院整理
2017/07/12 HTML / CSS
AVON雅芳官网:世界上最大的美容化妆品公司之一
2016/11/02 全球购物
7 For All Mankind官网:美国加州洛杉矶的高级牛仔服装品牌
2018/12/20 全球购物
国际经济与贸易专业大学生职业规划书
2014/03/01 职场文书
2015元旦标语横幅
2014/12/09 职场文书
办公室规章制度范本
2015/08/04 职场文书
《当代神农氏》教学反思
2016/02/23 职场文书
古诗文之爱国名句(77句)
2019/09/24 职场文书
Ajax常用封装库——Axios的使用
2021/05/08 Javascript
原生JavaScript实现简单五子棋游戏
2021/06/28 Javascript
「魔导具师妲莉亚永不妥协~从今天开始的自由职人生活~」1、2卷发售宣传CM公开
2022/03/21 日漫