pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python压缩解压缩zip文件及破解zip文件密码的方法
Nov 04 Python
Python3实现并发检验代理池地址的方法
Sep 18 Python
python 3利用BeautifulSoup抓取div标签的方法示例
May 28 Python
python实现图书管理系统
Mar 12 Python
python实现关键词提取的示例讲解
Apr 28 Python
Python中的TCP socket写法示例
May 11 Python
对python 中class与变量的使用方法详解
Jun 26 Python
详解python websocket获取实时数据的几种常见链接方式
Jul 01 Python
Django基础三之视图函数的使用方法
Jul 18 Python
调用其他python脚本文件里面的类和方法过程解析
Nov 15 Python
Python如何将装饰器定义为类
Jul 30 Python
python爬虫中抓取指数的实例讲解
Dec 01 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
thinkphp实现like模糊查询实例
2014/10/29 PHP
5款适合PHP使用的HTML编辑器推荐
2015/07/03 PHP
thinkPHP下的widget扩展用法实例分析
2015/12/26 PHP
Zend Framework实现Zend_View集成Smarty模板系统的方法
2016/03/05 PHP
CI框架中redis缓存相关操作文件示例代码
2016/05/17 PHP
thinkphp框架类库扩展操作示例
2019/11/26 PHP
js 3秒后跳转页面的实现代码
2014/03/10 Javascript
JavaScript调试技巧之console.log()详解
2014/03/19 Javascript
JSONP跨域的原理解析及其实现介绍
2014/03/22 Javascript
jquery实现根据浏览器窗口大小自动缩放图片的方法
2015/07/17 Javascript
基于javascript制作微博发布栏效果
2016/04/04 Javascript
关于JS中的方法是否加括号的问题
2016/07/27 Javascript
pc加载更多功能和移动端下拉刷新加载数据
2016/11/07 Javascript
Bootstrap 3浏览器兼容性问题及解决方案
2017/04/11 Javascript
详解Nodejs之npm&package.json
2017/06/15 NodeJs
axios全局请求参数设置,请求及返回拦截器的方法
2018/03/05 Javascript
JavaScript实现简单的文本逐字打印效果示例
2018/04/12 Javascript
React中使用async validator进行表单验证的实例代码
2018/08/17 Javascript
JavaScript快速调试的两个技巧
2020/11/04 Javascript
[41:12]Liquid vs Secret 2019国际邀请赛淘汰赛 败者组 BO3 第一场 8.24
2019/09/10 DOTA
使用优化器来提升Python程序的执行效率的教程
2015/04/02 Python
利用 Monkey 命令操作屏幕快速滑动
2016/12/07 Python
python如何重载模块实例解析
2018/01/25 Python
python使用flask与js进行前后台交互的例子
2019/07/19 Python
django rest framework serializer返回时间自动格式化方法
2020/03/31 Python
解决windows下python3使用multiprocessing.Pool出现的问题
2020/04/08 Python
无惧面试,带你搞懂python 装饰器
2020/08/17 Python
python实现三壶谜题的示例详解
2020/11/02 Python
html5 Canvas绘制线条 closePath()实例代码
2012/05/10 HTML / CSS
2014新课程改革心得体会
2014/03/10 职场文书
文明社区申报材料
2014/08/21 职场文书
夫妻房产协议书的格式
2014/10/11 职场文书
考试作弊万能检讨书
2014/10/19 职场文书
2015年乡镇工会工作总结
2015/05/19 职场文书
新郎父母婚礼答谢词
2015/09/29 职场文书
使用Spring处理x-www-form-urlencoded方式
2021/11/02 Java/Android