pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用python求相邻数的方法示例
Aug 18 Python
详解Python最长公共子串和最长公共子序列的实现
Jul 07 Python
Django中的ajax请求
Oct 19 Python
Python实战之制作天气查询软件
May 14 Python
python 利用浏览器 Cookie 模拟登录的用户访问知乎的方法
Jul 11 Python
Django 接收Post请求数据,并保存到数据库的实现方法
Jul 12 Python
python实现一行输入多个值和一行输出多个值的例子
Jul 16 Python
Python中 CSV格式清洗与转换的实例代码
Aug 29 Python
python爬虫scrapy框架之增量式爬虫的示例代码
Feb 26 Python
详解解Django 多对多表关系的三种创建方式
Aug 23 Python
Pygame Draw绘图函数的具体使用
Nov 17 Python
python百行代码实现汉服圈图片爬取
Nov 23 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
Laravel 5框架学习之环境与配置
2015/04/08 PHP
PHP常用字符串操作函数实例总结(trim、nl2br、addcslashes、uudecode、md5等)
2016/01/09 PHP
PHP 绘制网站登录首页图片验证码
2016/04/12 PHP
PHP 传输会话curl函数的实例详解
2017/09/12 PHP
php的扩展写法总结
2019/05/14 PHP
PHP的Trait机制原理与用法分析
2019/10/18 PHP
laravel邮件发送的实现代码示例
2020/01/31 PHP
Javascript中的Split使用方法与技巧
2007/03/09 Javascript
JS中的log对象获取以及debug的写法介绍
2014/03/03 Javascript
JS获取Table中td值的方法
2015/03/19 Javascript
javascript等号运算符使用详解
2015/04/16 Javascript
使用AngularJS中的SCE来防止XSS攻击的方法
2015/06/18 Javascript
jQuery移动web开发之页面跳转和加载外部页面的实现
2015/12/04 Javascript
ES6概念 ymbol.for()方法
2016/12/25 Javascript
Ionic+AngularJS实现登录和注册带验证功能
2017/02/09 Javascript
JavaScript数组去重的多种方法(四种)
2017/09/19 Javascript
vue动态渲染svg、添加点击事件的实现
2020/03/13 Javascript
[02:32]DOTA2英雄基础教程 祸乱之源
2013/12/23 DOTA
详解Python 模拟实现生产者消费者模式的实例
2017/08/10 Python
Python使用requests发送POST请求实例代码
2018/01/25 Python
Flask框架钩子函数功能与用法分析
2019/08/02 Python
使用Python将字符串转换为格式化的日期时间字符串
2019/09/01 Python
详解pyinstaller selenium python3 chrome打包问题
2019/10/18 Python
Python实现随机取一个矩阵数组的某几行
2019/11/26 Python
Django admin管理工具TabularInline类用法详解
2020/05/14 Python
经典广告词大全
2014/03/14 职场文书
求职自我推荐信
2014/06/25 职场文书
护士医德医风自我评价
2014/09/15 职场文书
幼儿园小班个人工作总结
2015/02/12 职场文书
2015社区健康教育工作总结
2015/05/20 职场文书
祝寿主持词
2015/07/02 职场文书
将图片保存到mysql数据库并展示在前端页面的实现代码
2021/05/02 MySQL
关于CentOS 8 搭建MongoDB4.4分片集群的问题
2021/10/24 MongoDB
angular异步验证器防抖实例详解
2022/03/31 Javascript
Golang 实现 WebSockets 之创建 WebSockets
2022/04/24 Golang
SQL Server中锁的用法
2022/05/20 SQL Server