pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入解析Python中的集合类型操作符
Aug 19 Python
利用Python学习RabbitMQ消息队列
Nov 30 Python
利用python程序生成word和PDF文档的方法
Feb 14 Python
Python可变参数用法实例分析
Apr 02 Python
深入理解python中函数传递参数是值传递还是引用传递
Nov 07 Python
python监控进程脚本
Apr 12 Python
python顺序的读取文件夹下名称有序的文件方法
Jul 11 Python
linux下python中文乱码解决方案详解
Aug 28 Python
python中property属性的介绍及其应用详解
Aug 29 Python
Anaconda使用IDLE的实现示例
Sep 23 Python
python中slice参数过长的处理方法及实例
Dec 15 Python
Python中tkinter的用户登录管理的实现
Apr 22 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
一周学会PHP(视频)Http下载
2006/12/12 PHP
php操作mongodb封装类与用法实例
2018/09/01 PHP
Javascript String对象扩展HTML编码和解码的方法
2009/06/02 Javascript
javascript 面向对象编程 function也是类
2009/09/17 Javascript
JavaScript 入门基础知识 想学习js的朋友可以参考下
2009/12/26 Javascript
理解Javascript_03_javascript全局观
2010/10/11 Javascript
Dom操作之兼容技巧分享
2011/09/20 Javascript
分享一个自定义的console类 让你不再纠结JS中的调试代码的兼容
2012/04/20 Javascript
JavaScript的History API使搜索引擎抓取AJAX内容
2015/12/07 Javascript
Bootstrap Paginator分页插件使用方法详解
2016/05/30 Javascript
全面解析DOM操作和jQuery实现选项移动操作代码分享
2016/06/07 Javascript
AngularJs的UI组件ui-Bootstrap之Tooltip和Popover
2018/07/13 Javascript
基于Vue实现关键词实时搜索高亮显示关键词
2018/07/21 Javascript
JS浅拷贝和深拷贝原理与实现方法分析
2019/02/28 Javascript
Layui Table js 模拟选中checkbox的例子
2019/09/03 Javascript
layui问题之自动滚动二级iframe页面到指定位置的方法
2019/09/18 Javascript
[01:03]悬念揭晓 11月26日DOTA2完美盛典不见不散
2017/11/23 DOTA
python的urllib模块显示下载进度示例
2014/01/17 Python
python传递参数方式小结
2015/04/17 Python
Python实现优先级队列结构的方法详解
2016/06/02 Python
Python中关键字global和nonlocal的区别详解
2018/09/03 Python
DataFrame:通过SparkSql将scala类转为DataFrame的方法
2019/01/29 Python
Python爬虫库BeautifulSoup获取对象(标签)名,属性,内容,注释
2020/01/25 Python
keras和tensorflow使用fit_generator 批次训练操作
2020/07/03 Python
澳大利亚优惠网站:Deals.com.au
2019/07/02 全球购物
北京华建集团SQL面试题
2014/06/03 面试题
建筑施工员岗位职责
2013/11/26 职场文书
工程造价专业大学生职业规划范文
2014/03/09 职场文书
承诺书的格式范文
2014/03/28 职场文书
高中校园广播稿3篇
2014/09/29 职场文书
2015年反腐倡廉工作总结
2015/05/14 职场文书
2019财务毕业实习报告
2019/06/27 职场文书
Python Numpy之linspace用法说明
2021/04/17 Python
解决python绘图使用subplots出现标题重叠的问题
2021/04/30 Python
详解pytorch创建tensor函数
2022/03/22 Python
js前端面试常见浏览器缓存强缓存及协商缓存实例
2022/06/21 Javascript