pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
centos系统升级python 2.7.3
Jul 03 Python
python嵌套字典比较值与取值的实现示例
Nov 03 Python
Python实现定期检查源目录与备份目录的差异并进行备份功能示例
Feb 27 Python
Python实现字典按key或者value进行排序操作示例【sorted】
May 03 Python
详解numpy的argmax的具体使用
May 27 Python
浅析Python与Mongodb数据库之间的操作方法
Jul 01 Python
Django中的模型类设计及展示示例详解
May 29 Python
matplotlib教程——强大的python作图工具库
Oct 15 Python
快速创建python 虚拟环境
Nov 28 Python
python实现KNN近邻算法
Dec 30 Python
python Autopep8实现按PEP8风格自动排版Python代码
Mar 02 Python
Python办公自动化之教你用Python批量识别发票并录入到Excel表格中
Jun 26 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
php 文章采集正则代码
2009/12/28 PHP
php将url地址转化为完整的a标签链接代码(php为url地址添加a标签)
2014/01/17 PHP
php实现天干地支计算器示例
2014/03/14 PHP
PHP批量检测并去除文件BOM头代码实例
2014/05/08 PHP
PHP版本的选择5.2.17 5.3.27 5.3.28 5.4 5.5兼容性问题分析
2016/04/04 PHP
PHP基本语法实例总结
2016/09/09 PHP
CI框架无限级分类+递归的实现代码
2016/11/01 PHP
由php中字符offset特征造成的绕过漏洞详解
2017/07/07 PHP
PHP单例模式数据库连接类与页面静态化实现方法
2019/03/20 PHP
菜鸟javascript基础整理1
2010/12/06 Javascript
javascript获取隐藏dom的宽高 具体实现
2013/07/14 Javascript
ionic由于使用了header和subheader导致被遮挡的问题的两种解决方法
2016/09/22 Javascript
JS中常用的正则表达式
2016/09/29 Javascript
浅谈webpack打包过程中因为图片的路径导致的问题
2018/02/21 Javascript
详解Angular6.0使用路由步骤(共7步)
2018/06/29 Javascript
Node.js+Express+Mysql 实现增删改查
2019/04/03 Javascript
vue-test-utils初使用详解
2019/05/23 Javascript
vue+koa2实现session、token登陆状态验证的示例
2019/08/30 Javascript
JS如何实现在弹出窗口中加载页面
2020/12/03 Javascript
基于Vue3.0开发轻量级手机端弹框组件V3Popup的场景分析
2020/12/30 Vue.js
[36:45]TNC vs VGJ.S 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
python使用socket连接远程服务器的方法
2015/04/29 Python
Python脚本实时处理log文件的方法
2016/11/21 Python
python遍历文件夹下所有excel文件
2018/01/03 Python
Django后台获取前端post上传的文件方法
2018/05/28 Python
python pptx复制指定页的ppt教程
2020/02/14 Python
python删除文件、清空目录的实现方法
2020/09/23 Python
css3 线性渐变和径向渐变示例附图
2014/04/08 HTML / CSS
CSS3实现闪烁动画效果的方法
2015/02/09 HTML / CSS
用CSS3的box-reflect来制作倒影效果
2016/11/15 HTML / CSS
自学考试自我鉴定范文
2013/09/26 职场文书
市级文明单位申报材料
2014/05/07 职场文书
对外汉语专业大学生职业生涯规划书
2014/10/11 职场文书
2014个人年度工作总结范文
2014/12/24 职场文书
大三学生英语考试作弊检讨书
2015/01/01 职场文书
莫言诺贝尔获奖感言(全文)
2015/07/31 职场文书