pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中实现php的var_dump函数功能
Jan 21 Python
在Django框架中伪造捕捉到的URLconf值的方法
Jul 18 Python
Python的Django框架中模板碎片缓存简介
Jul 24 Python
使用Python3 编写简单信用卡管理程序
Dec 21 Python
python中实现迭代器(iterator)的方法示例
Jan 19 Python
python如何重载模块实例解析
Jan 25 Python
python实现列表中由数值查到索引的方法
Jun 27 Python
Python将文字转成语音并读出来的实例详解
Jul 15 Python
详解Python中字符串前“b”,“r”,“u”,“f”的作用
Dec 18 Python
完美解决pycharm导入自己写的py文件爆红问题
Feb 12 Python
使用Python获取当前工作目录和执行命令的位置
Mar 09 Python
使用python求斐波那契数列中第n个数的值示例代码
Jul 26 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
用mysql触发器自动更新memcache的实现代码
2009/10/11 PHP
PHP读书笔记整理_结构语句详解
2016/07/01 PHP
PHP中in_array的隐式转换的解决方法
2018/03/06 PHP
PHP Include文件实例讲解
2019/02/15 PHP
JavaScript进阶教程(第四课第一部分)
2007/04/05 Javascript
javascript 时间比较实现代码
2009/10/28 Javascript
用JS控制回车事件的代码
2011/02/20 Javascript
利用JS实现浏览器的title闪烁
2013/07/08 Javascript
javascript版的in_array函数(判断数组中是否存在特定值)
2014/05/09 Javascript
浅谈javascript面向对象程序设计
2015/01/21 Javascript
jQuery使用removeClass方法删除元素指定Class的方法
2015/03/26 Javascript
jQuery表格插件datatables用法汇总
2016/03/29 Javascript
jQuery表单验证简单示例
2016/10/17 Javascript
js实现刷新页面后回到记录时滚动条的位置【两种方案可选】
2016/12/12 Javascript
利用Vue.js框架实现火车票查询系统(附源码)
2017/02/27 Javascript
简单实现jQuery轮播效果
2017/08/18 jQuery
解析vue中的$mount
2017/12/21 Javascript
利用adb shell和node.js实现抖音自动抢红包功能(推荐)
2018/02/22 Javascript
JavaScript设计模式之责任链模式实例分析
2019/01/16 Javascript
JS实现排行榜文字向上滚动轮播效果
2019/11/26 Javascript
[03:34]2014DOTA2西雅图国际邀请赛 淘汰赛7月15日TOPPLAY
2014/07/15 DOTA
[01:05:59]Mineski vs Secret 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.22
2019/09/05 DOTA
Pycharm学习教程(1) 定制外观
2017/05/02 Python
python数据处理之如何选取csv文件中某几行的数据
2019/09/02 Python
python 下 CMake 安装配置 OPENCV 4.1.1的方法
2019/09/30 Python
基于Python把网站域名解析成ip地址
2020/05/25 Python
python中upper是做什么用的
2020/07/20 Python
基于python实现监听Rabbitmq系统日志代码示例
2020/11/28 Python
python集合的新增元素方法整理
2020/12/07 Python
CSS3中媒体查询结合rem布局适配手机屏幕
2019/06/10 HTML / CSS
June Jacobs尊积帕官网:知名的spa水疗护肤品牌
2019/03/21 全球购物
四年大学生活的自我评价范文
2014/02/07 职场文书
学习十八届四中全会依法治国心得体会
2014/11/03 职场文书
领导干部群众路线对照检查材料
2014/11/05 职场文书
长城的导游词
2015/01/30 职场文书
办公室岗位职责
2015/02/04 职场文书