pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现中文输出的两种方法
May 09 Python
详解Python的Django框架中的模版继承
Jul 16 Python
MySQL中表的复制以及大型数据表的备份教程
Nov 25 Python
Python使用dis模块把Python反编译为字节码的用法详解
Jun 14 Python
python 异常处理总结
Oct 18 Python
Python3安装Scrapy的方法步骤
Nov 23 Python
详解Django模版中加载静态文件配置方法
Jul 21 Python
Pycharm+Python+PyQt5使用详解
Sep 25 Python
Python爬虫程序架构和运行流程原理解析
Mar 09 Python
python实现FTP循环上传文件
Mar 20 Python
python实现同一局域网下传输图片
Mar 20 Python
python 利用jieba.analyse进行 关键词提取
Dec 17 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
PHP对表单提交特殊字符的过滤和处理方法汇总
2014/02/18 PHP
浅析get与post的一些特殊情况
2014/07/28 PHP
ExtJs设置GridPanel表格文本垂直居中示例
2013/07/15 Javascript
node.js中的path.sep方法使用说明
2014/12/08 Javascript
jQuery自定义动画函数实例详解(附demo源码)
2015/12/10 Javascript
jQuery Uploadify 上传插件出现Http Error 302 错误的解决办法
2015/12/12 Javascript
一种新的javascript对象创建方式Object.create()
2015/12/28 Javascript
AngularJS自动表单验证
2016/02/01 Javascript
分享javascript、jquery实用代码段
2016/10/20 Javascript
js实现返回顶部效果
2017/03/10 Javascript
vue.js项目中实用的小技巧汇总
2017/11/29 Javascript
基于滚动条位置判断的简单实例
2017/12/14 Javascript
如何实现双向绑定mvvm的原理实现
2019/05/28 Javascript
JS实现简单日历特效
2020/01/03 Javascript
antd配置config-overrides.js文件的操作
2020/10/31 Javascript
使用Python的Zato发送AMQP消息的教程
2015/04/16 Python
详解Python的Django框架中的模版相关知识
2015/07/15 Python
Python中的FTP通信模块ftplib的用法整理
2016/07/08 Python
对numpy.append()里的axis的用法详解
2018/06/28 Python
Python计算开方、立方、圆周率,精确到小数点后任意位的方法
2018/07/17 Python
Python判断字符串是否为字母或者数字(浮点数)的多种方法
2018/08/03 Python
浅谈Python3 numpy.ptp()最大值与最小值的差
2019/08/24 Python
python创建学生管理系统
2019/11/22 Python
基于Python 中函数的 收集参数 机制
2019/12/21 Python
对CSS3选择器的研究(详解)
2016/09/16 HTML / CSS
html5 Canvas画图教程(11)—使用lineTo/arc/bezierCurveTo画椭圆形
2013/01/09 HTML / CSS
基于HTML5新特性Mutation Observer实现编辑器的撤销和回退操作
2016/01/11 HTML / CSS
荷兰在线体育用品商店:Avantisport.nl
2018/07/04 全球购物
Java程序开发中如何应用线程
2016/03/03 面试题
会计岗位职责
2013/11/08 职场文书
餐饮业会计岗位职责
2013/12/19 职场文书
应届毕业生如何写求职信
2014/02/16 职场文书
企业诚信承诺书
2014/05/23 职场文书
音乐教育专业自荐信
2014/09/18 职场文书
毕业生自荐求职信书写的技巧
2019/08/26 职场文书
温馨祝福晨语:美丽的一天从我的问候开始
2019/11/28 职场文书