pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python cookielib 登录人人网的实现代码
Dec 19 Python
python实现带错误处理功能的远程文件读取方法
Apr 29 Python
python删除过期文件的方法
May 29 Python
python九九乘法表的实例
Sep 26 Python
python2.7+selenium2实现淘宝滑块自动认证功能
Feb 24 Python
python实现嵌套列表平铺的两种方法
Nov 08 Python
Python批量修改图片分辨率的实例代码
Jul 04 Python
Flask框架学习笔记之消息提示与异常处理操作详解
Aug 15 Python
python批量处理文件或文件夹
Jul 28 Python
python编程进阶之异常处理用法实例分析
Feb 21 Python
Python延迟绑定问题原理及解决方案
Aug 04 Python
pycharm配置安装autopep8自动规范代码的实现
Mar 02 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
Apache设置虚拟WEB
2006/10/09 PHP
php网页后退不再出现过期
2007/03/08 PHP
php 计划任务 检测用户连接状态
2012/03/29 PHP
php 数据结构之链表队列
2017/10/17 PHP
Laravel 已登陆用户再次查看登陆页面的自动跳转设置方法
2019/09/30 PHP
Auntion-TableSort国人写的一个javascript表格排序的东西
2007/11/12 Javascript
javascript数组去掉重复
2011/05/12 Javascript
使用js写的一个简易的投票
2013/11/27 Javascript
ExtJS4如何自动生成控制grid的列显示、隐藏的checkbox
2014/05/02 Javascript
nodejs npm包管理的配置方法及常用命令介绍
2014/06/05 NodeJs
javascript闭包的理解
2015/04/01 Javascript
javascript同步服务器时间和同步倒计时小技巧
2015/09/24 Javascript
微信小程序 SocketIO 实例讲解
2016/10/13 Javascript
纯JS实现图片验证码功能并兼容IE6-8(推荐)
2017/04/19 Javascript
jQuery无冲突模式详解
2019/01/17 jQuery
微信小程序Page中data数据操作和函数调用方法
2019/05/08 Javascript
layui数据表格跨行自动合并的例子
2019/09/02 Javascript
Vue解析剪切板图片并实现发送功能
2020/02/04 Javascript
用js编写留言板
2020/03/17 Javascript
Python发送Email方法实例
2014/08/21 Python
Python中的Descriptor描述符学习教程
2016/06/02 Python
Python面向对象之类的封装操作示例
2019/06/08 Python
python列表推导和生成器表达式知识点总结
2020/01/10 Python
OpenCV python sklearn随机超参数搜索的实现
2020/01/17 Python
利用Tensorflow的队列多线程读取数据方式
2020/02/05 Python
python3 sorted 如何实现自定义排序标准
2020/03/12 Python
Python 如何创建一个简单的REST接口
2020/07/30 Python
selenium+超级鹰实现模拟登录12306
2021/01/24 Python
使用HTML5里的classList操作CSS类
2016/06/28 HTML / CSS
极度干燥澳大利亚官方网站:Superdry澳大利亚
2019/03/28 全球购物
俄罗斯在线大型超市:ТутПросто
2021/01/08 全球购物
英文求职信范文
2014/05/23 职场文书
美食节目策划方案
2014/05/31 职场文书
个人工作作风整改措施思想汇报
2014/10/13 职场文书
2014小学教师个人工作总结
2014/11/10 职场文书
2015年农村党员公开承诺事项
2015/04/28 职场文书