用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用 Python 获取 Linux 系统信息的代码
Jul 13 Python
Python使用win32com模块实现数据库表结构自动生成word表格的方法
Jul 17 Python
python递归实现快速排序
Aug 18 Python
Python弹出输入框并获取输入值的实例
Jun 18 Python
Python : turtle色彩控制实例详解
Jan 19 Python
python except异常处理之后不退出,解决异常继续执行的实现
Apr 25 Python
python函数调用,循环,列表复制实例
May 03 Python
Java爬虫技术框架之Heritrix框架详解
Jul 22 Python
python爬虫构建代理ip池抓取数据库的示例代码
Sep 22 Python
Python基于内置函数type创建新类型
Oct 22 Python
一篇文章教你用python画动态爱心表白
Nov 22 Python
如何在向量化NumPy数组上进行移动窗口
May 18 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
example1.php
2006/10/09 PHP
国外PHP程序员的13个好习惯小结
2012/02/20 PHP
php xml常用函数的集合(比较详细)
2013/06/06 PHP
php目录遍历函数opendir用法实例
2014/11/20 PHP
PHP递归创建多级目录
2015/11/05 PHP
从面试题学习Javascript 面向对象(创建对象)
2012/03/30 Javascript
jquery.Jwin.js 基于jquery的弹出层插件代码
2012/05/23 Javascript
js获取location.href的参数实例代码
2013/08/02 Javascript
js和php如何获取当前url的内容
2013/09/22 Javascript
jquery数组封装使用方法分享(jquery数组遍历)
2014/03/25 Javascript
jQuery实现的原图对比窗帘效果
2014/06/15 Javascript
angular.js之路由的选择方法
2016/09/24 Javascript
javascript入门之string对象【新手必看】
2016/11/22 Javascript
Bootstrap select下拉联动(jQuery cxselect)
2017/01/04 Javascript
微信小程序 Button 组件详解及简单实例
2017/01/10 Javascript
基于jQuery实现简单人工智能聊天室
2017/02/10 Javascript
详解.vue文件中监听input输入事件(oninput)
2017/09/19 Javascript
微信小程序实现简单表格
2019/02/14 Javascript
vue中的过滤器实例代码详解
2019/06/06 Javascript
node.js express框架简介与实现
2019/07/23 Javascript
在vue中使用echars实现上浮与下钻效果
2019/11/08 Javascript
JS实现前端动态分页码代码实例
2020/06/02 Javascript
js实现手表表盘时钟与圆周运动
2020/09/18 Javascript
详解Vite的新体验
2021/02/22 Javascript
Python实现的简单文件传输服务器和客户端
2015/04/08 Python
python使用自定义user-agent抓取网页的方法
2015/04/15 Python
Python实现的多线程http压力测试代码
2017/02/08 Python
python try 异常处理(史上最全)
2019/03/07 Python
pytorch实现特殊的Module--Sqeuential三种写法
2020/01/15 Python
树莓派4B安装Tensorflow的方法步骤
2020/07/16 Python
Java面试题:说出如下代码的执行结果
2015/10/30 面试题
自考毕业生自我鉴定
2013/11/04 职场文书
旅游管理专业生自荐信范文
2014/01/02 职场文书
宿舍违规用电检讨书
2014/02/16 职场文书
六年级语文教学反思
2016/03/03 职场文书
django项目、vue项目部署云服务器的详细过程
2022/07/23 Servers