用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程编程(二):启动线程的两种方法
Apr 05 Python
Python中的面向对象编程详解(上)
Apr 13 Python
浅析Python中MySQLdb的事务处理功能
Sep 21 Python
使用Python搭建虚拟环境的配置方法
Feb 28 Python
python实现五子棋小游戏
Mar 25 Python
详解Python Qt的窗体开发的基本操作
Jul 14 Python
Python 字符串处理特殊空格\xc2\xa0\t\n Non-breaking space
Feb 23 Python
python实现梯度法 python最速下降法
Mar 24 Python
利用python 下载bilibili视频
Nov 13 Python
Pycharm在指定目录下生成文件和删除文件的实现
Dec 28 Python
python中yield的用法详解
Jan 13 Python
用Python爬取各大高校并可视化帮弟弟选大学,弟弟直呼牛X
Jun 11 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
用mysql触发器自动更新memcache的实现代码
2009/10/11 PHP
php数组去重的函数代码
2013/02/03 PHP
基于PHP导出Excel的小经验 完美解决乱码问题
2013/06/10 PHP
php配置php-fpm启动参数及配置详解
2013/11/04 PHP
PHP在网页中动态生成PDF文件详细教程
2014/07/05 PHP
php图片水印添加、压缩、剪切的封装类实现
2020/04/18 PHP
PHP基于DOM创建xml文档的方法示例
2017/02/08 PHP
同一个表单 根据要求递交到不同页面的实现方法小结
2009/08/05 Javascript
如何让easyui gridview 宽度自适应窗口改变及fitColumns应用
2013/01/25 Javascript
多种方法实现360浏览器下禁止自动填写用户名密码
2014/06/16 Javascript
微信小程序开发之实现选项卡(窗口顶部TabBar)页面切换
2016/11/25 Javascript
AngularJs实现聊天列表实时刷新功能
2017/06/15 Javascript
Vue中正确使用jQuery的方法
2017/10/30 jQuery
vue移动端路由切换实例分析
2018/05/14 Javascript
深入浅出理解JavaScript闭包的功能与用法
2018/08/01 Javascript
laydate只显示时分 不显示秒的功能实现方法
2019/09/28 Javascript
python文件和目录操作方法大全(含实例)
2014/03/12 Python
对于Python的Django框架使用的一些实用建议
2015/04/03 Python
python+pyqt实现右下角弹出框
2017/10/26 Python
《与孩子一起学编程》python自测题
2018/05/27 Python
django session完成状态保持的方法
2018/11/27 Python
python主线程与子线程的结束顺序实例解析
2019/12/17 Python
Python3 xml.etree.ElementTree支持的XPath语法详解
2020/03/06 Python
Python Pygame实现俄罗斯方块
2021/02/19 Python
英国电子产品购物网站:Tech in the basket
2019/11/08 全球购物
优秀学生自我鉴定范例
2013/12/18 职场文书
五年级音乐教学反思
2014/02/06 职场文书
《草虫的村落》教学反思
2014/02/16 职场文书
团队队名口号大全
2014/06/06 职场文书
新教师培训心得体会
2014/09/02 职场文书
2014年办公室个人工作总结
2014/11/12 职场文书
2015年小学生自我评价范文
2015/03/03 职场文书
装饰技术负责人岗位职责
2015/04/13 职场文书
2015年学生资助工作总结
2015/05/25 职场文书
未婚证明范本
2015/06/15 职场文书
Python实现简单得递归下降Parser
2022/05/02 Python