用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字典多条件排序方法实例
Jun 30 Python
在Python的setuptools框架下生成egg的教程
Apr 13 Python
Python发送form-data请求及拼接form-data内容的方法
Mar 05 Python
python实现图像识别功能
Jan 29 Python
Python图片转换成矩阵,矩阵数据转换成图片的实例
Jul 02 Python
在Pycharm中将pyinstaller加入External Tools的方法
Jan 16 Python
详解python3 + Scrapy爬虫学习之创建项目
Apr 12 Python
一行python实现树形结构的方法
Aug 09 Python
pandas 像SQL一样使用WHERE IN查询条件说明
Jun 05 Python
解决python便携版无法直接运行py文件的问题
Sep 01 Python
python opencv实现图像配准与比较
Feb 09 Python
详解OpenCV获取高动态范围(HDR)成像
Apr 29 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
Mysql的常用命令
2006/10/09 PHP
简单采集了yahoo的一些数据
2007/02/14 PHP
Drupal7 form表单二次开发要点与实例
2014/03/02 PHP
PHP缓存机制Output Control详解
2014/07/14 PHP
在修改准备发的批量美化select+可修改select时,在非IE下发现了几个问题
2007/01/09 Javascript
JS面向对象编程 for Cookie
2010/09/19 Javascript
js 获取(接收)地址栏参数值的方法
2013/04/01 Javascript
JavaScript对象学习经验整理
2013/10/12 Javascript
Javascript 绘制 sin 曲线过程附图
2014/08/21 Javascript
谈谈js中的prototype及prototype属性解释和常用方法
2015/11/25 Javascript
AngularJS ng-change 指令的详解及简单实例
2016/07/30 Javascript
Javascript读写cookie的实例源码
2019/03/16 Javascript
Vue 利用指令实现禁止反复发送请求的两种方法
2019/09/15 Javascript
Vue数字输入框组件示例代码详解
2020/01/15 Javascript
js实现盒子移动动画效果
2020/08/09 Javascript
python调用c++ ctype list传数组或者返回数组的方法
2019/02/13 Python
python实现kNN算法识别手写体数字的示例代码
2019/08/16 Python
python单向链表的基本实现与使用方法【定义、遍历、添加、删除、查找等】
2019/10/24 Python
如何设置PyCharm中的Python代码模版(推荐)
2020/11/20 Python
CSS3实现曲线阴影和翘边阴影
2016/05/03 HTML / CSS
HTML5的结构和语义(5):内嵌媒体
2008/10/17 HTML / CSS
Huda Beauty官方商店:化妆和美容产品
2020/09/05 全球购物
矫正人员思想汇报
2014/01/08 职场文书
初一生物教学反思
2014/01/18 职场文书
入股协议书
2014/04/14 职场文书
导师工作推荐信范文
2014/05/17 职场文书
小学教师师德师风自我剖析材料
2014/09/29 职场文书
毕业设计论文致谢词
2015/05/14 职场文书
爱国电影观后感
2015/06/19 职场文书
幼儿园开学温馨提示
2015/07/15 职场文书
2015年庆祝国庆节66周年演讲稿
2015/07/30 职场文书
二年级作文之动物作文
2019/11/13 职场文书
python爬取网页版QQ空间,生成各类图表
2021/06/02 Python
「我的青春恋爱物语果然有问题。-妄言录-」第20卷封面公开
2022/03/21 日漫
安装Ruby和 Rails的详细步骤
2022/04/19 Ruby
MySQL提取JSON字段数据实现查询
2022/04/22 MySQL