pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python核心编程中的浅拷贝与深拷贝
Jan 07 Python
Python遍历pandas数据方法总结
Feb 09 Python
对Python中type打开文件的方式介绍
Apr 28 Python
Python selenium抓取微博内容的示例代码
May 17 Python
Python 读取某个目录下所有的文件实例
Jun 23 Python
Python3.8对可迭代解包的改进及用法详解
Oct 15 Python
Python大数据之使用lxml库解析html网页文件示例
Nov 16 Python
IronPython连接MySQL的方法步骤
Dec 27 Python
Python抓新型冠状病毒肺炎疫情数据并绘制全国疫情分布的代码实例
Feb 05 Python
python读取csv文件指定行的2种方法详解
Feb 13 Python
pandas分组聚合详解
Apr 10 Python
Pymysql实现往表中插入数据过程解析
Jun 02 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
WINXP下apache+php4+mysql
2006/11/25 PHP
收集的php编写大型网站问题集
2007/03/06 PHP
用PHP与XML联手进行网站编程代码实例
2008/07/10 PHP
PHP错误抑制符(@)导致引用传参失败Bug的分析
2011/05/02 PHP
PHP三元运算符的结合性介绍
2012/01/10 PHP
详解PHP编码转换函数应用技巧
2016/10/22 PHP
PHP带节点操作的无限分类实现方法详解
2016/11/09 PHP
在修改准备发的批量美化select+可修改select时,在非IE下发现了几个问题
2007/01/09 Javascript
checkbox 多选框 联动实现代码
2008/10/22 Javascript
javascript 表单规则集合对象
2009/07/21 Javascript
JQuery的自定义事件代码,触发,绑定简单实例
2013/08/01 Javascript
js实现表单Radio切换效果的方法
2015/08/17 Javascript
js模拟淘宝网的多级选择菜单实现方法
2015/08/18 Javascript
JavaScript调用传递变量参数的相关问题及解决办法
2015/11/01 Javascript
javascript的正则匹配方法学习
2016/02/24 Javascript
JavaScript中输出信息的方法(信息确认框-提示输入框-文档流输出)
2016/06/12 Javascript
Angular的$http与$location
2016/12/26 Javascript
Vue学习笔记进阶篇之函数化组件解析
2017/07/21 Javascript
微信小程序实现下拉刷新和轮播图效果
2017/11/21 Javascript
vue.js实现带日期星期的数字时钟功能示例
2018/08/28 Javascript
基于vue实现一个神奇的动态按钮效果
2019/05/15 Javascript
vue的三种图片引入方式代码实例
2019/11/19 Javascript
nuxt 自定义 auth 中间件实现令牌的持久化操作
2020/11/05 Javascript
在nuxt中使用路由重定向的实例
2020/11/06 Javascript
[01:05:24]Ti4 冒泡赛第二天 iG vs NEWBEE 3
2014/07/15 DOTA
在Python中使用next()方法操作文件的教程
2015/05/24 Python
Python正则表达式实现简易计算器功能示例
2019/05/07 Python
关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)
2020/02/20 Python
求职信模板标准格式范文
2014/02/23 职场文书
暑期社会实践感言
2014/02/25 职场文书
我为自己代言广告词
2014/03/18 职场文书
《登鹳雀楼》教学反思
2014/04/09 职场文书
课外小组活动总结
2014/08/27 职场文书
材料采购员岗位职责
2015/04/03 职场文书
药店收银员岗位职责
2015/04/07 职场文书
2021年pycharm的最新安装教程及基本使用图文详解
2021/04/03 Python