pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的Excel文件读写类
Jul 30 Python
Python构建网页爬虫原理分析
Dec 19 Python
python的dataframe转换为多维矩阵的方法
Apr 11 Python
numpy中以文本的方式存储以及读取数据方法
Jun 04 Python
在python中用print()输出多个格式化参数的方法
Jul 16 Python
Django中使用CORS实现跨域请求过程解析
Aug 05 Python
pytorch 在sequential中使用view来reshape的例子
Aug 20 Python
python函数声明和调用定义及原理详解
Dec 02 Python
python 非线性规划方式(scipy.optimize.minimize)
Feb 11 Python
Python3 hashlib密码散列算法原理详解
Mar 30 Python
python matplotlib库的基本使用
Sep 23 Python
pycharm 如何取消连按两下shift出现的全局搜索
Jan 15 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
PHP下对数组进行排序的函数
2010/08/08 PHP
PHP中mysqli_affected_rows作用行数返回值分析
2014/12/26 PHP
PHPMailer发送邮件
2016/12/28 PHP
php+jQuery ajax实现的实时刷新显示数据功能示例
2019/09/12 PHP
自动更新作用
2006/10/08 Javascript
dojo 之基础篇(三)之向服务器发送数据
2007/03/24 Javascript
多浏览器兼容性比较好的复制到剪贴板的js代码
2011/10/09 Javascript
javascript中获取下个月一号,是星期几
2012/06/01 Javascript
js 控制页面跳转的5种方法
2013/09/09 Javascript
使用JQuery快速实现Tab的AJAX动态载入(实例讲解)
2013/12/11 Javascript
用Jquery实现滚动新闻
2014/02/12 Javascript
jQuery中 attr() 方法使用小结
2015/05/03 Javascript
JQuery球队选择实例
2015/05/18 Javascript
jQuery制作圣诞主题页面 更像是爱情影集
2016/08/10 Javascript
javaScript如何跳出多重循环break、continue
2016/09/01 Javascript
AngularJS实现数据列表的增加、删除和上移下移等功能实例
2016/09/05 Javascript
AngularJS指令中的绑定策略实例分析
2016/12/14 Javascript
将鼠标焦点定位到文本框最后(代码分享)
2017/01/11 Javascript
JS获取子、父、兄节点方法小结
2017/08/14 Javascript
jQuery实现倒计时功能 jQuery实现计时器功能
2017/09/19 jQuery
vue2过滤器模糊查询方法
2018/09/16 Javascript
Python实现根据指定端口探测服务器/模块部署的方法
2014/08/25 Python
Python 正则表达式实现计算器功能
2017/04/29 Python
详解python中 os._exit() 和 sys.exit(), exit(0)和exit(1) 的用法和区别
2017/06/23 Python
教你使用python画一朵花送女朋友
2018/03/29 Python
78行Python代码实现现微信撤回消息功能
2018/07/26 Python
CSS3弹性盒模型flex box快速入门心得(必看篇)
2016/05/24 HTML / CSS
土木工程建筑专业毕业生求职信
2013/10/21 职场文书
精细化工应届生求职信
2013/11/17 职场文书
自我评价范文
2013/12/22 职场文书
公司新年寄语
2014/04/04 职场文书
关于环保的建议书
2014/05/12 职场文书
公司外出活动方案
2014/08/14 职场文书
2015年公司工作总结
2015/04/25 职场文书
Java常用函数式接口总结
2021/06/29 Java/Android
weblogic服务建立数据源连接测试更新mysql驱动包的问题及解决方法
2022/01/22 MySQL