pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现simhash算法实例
Apr 25 Python
python中的__init__ 、__new__、__call__小结
Apr 25 Python
举例讲解Python编程中对线程锁的使用
Jul 12 Python
简单实现python聊天程序
Apr 01 Python
对Python 3.2 迭代器的next函数实例讲解
Oct 18 Python
python+mysql实现学生信息查询系统
Feb 21 Python
Python实现九宫格式的朋友圈功能内附“马云”朋友圈
May 07 Python
python函数声明和调用定义及原理详解
Dec 02 Python
Python集成开发工具Pycharm的安装和使用详解
Mar 18 Python
PyQt5 界面显示无响应的实现
Mar 26 Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 Python
用Python提取PDF表格的方法
Apr 11 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
一个更简单的无限级分类菜单代码
2007/01/16 PHP
利用ThinkPHP内置的ThinkAjax实现异步传输技术的实现方法
2011/12/19 PHP
php里array_work用法实例分析
2015/07/13 PHP
php实现面包屑导航例子分享
2015/12/19 PHP
Yii使用技巧大汇总
2015/12/29 PHP
使用laravel的migrate创建数据表的方法
2019/09/30 PHP
Alliance vs Liquid BO3 第三场2.13
2021/03/10 DOTA
通过继承IHttpHandle实现JS插件的组织与管理
2010/07/13 Javascript
JS实现的手机端精简幻灯片效果
2016/09/05 Javascript
手动初始化Angular的模块与控制器
2016/12/26 Javascript
解析jquery easyui tree异步加载子节点问题
2017/03/08 Javascript
JS 中document.write()的用法和清空的原因浅析
2017/12/04 Javascript
js经验分享 JavaScript反调试技巧
2018/03/10 Javascript
Vue filter格式化时间戳时间成标准日期格式的方法
2018/09/16 Javascript
小程序实现授权登陆的解决方案
2018/12/02 Javascript
JS与SQL方式随机生成高强度密码示例
2018/12/29 Javascript
详解如何写出一个利于扩展的vue路由配置
2019/05/16 Javascript
[46:42]DOTA2-DPC中国联赛正赛 Aster vs Magma BO3 第二场 3月5日
2021/03/11 DOTA
Python3字符串学习教程
2015/08/20 Python
最大K个数问题的Python版解法总结
2016/06/16 Python
python中字符串内置函数的用法总结
2018/09/13 Python
python多线程并发让两个LED同时亮的方法
2019/02/18 Python
Python中PyQt5/PySide2的按钮控件使用实例
2019/08/17 Python
Python基于gevent实现文件字符串查找器
2020/08/11 Python
Pretty Little Thing爱尔兰:时尚女性服饰
2017/03/27 全球购物
迪拜航空官方网站:flydubai
2017/04/20 全球购物
发现世界上最好的珠宝设计师:JewelStreet
2017/12/17 全球购物
Myprotein台湾官方网站:全球领先的运动营养品牌
2018/12/10 全球购物
.NET程序员的数据库面试题
2012/10/10 面试题
卫校中专生的自我评价
2014/01/15 职场文书
申报材料格式
2014/12/30 职场文书
单位租车协议书
2015/01/29 职场文书
行政经理岗位职责
2015/04/15 职场文书
2019年房屋委托租赁合同范本(通用版)!
2019/07/17 职场文书
人生感悟经典句子
2019/08/20 职场文书
python基础入门之字典和集合
2021/06/13 Python