用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Python实现对PDF文件的OCR识别
Aug 05 Python
使用pyecharts在jupyter notebook上绘图
Apr 23 Python
Python算法之求n个节点不同二叉树个数
Oct 27 Python
Python最火、R极具潜力 2017机器学习调查报告
Dec 11 Python
Python使用xlwt模块操作Excel的方法详解
Mar 27 Python
Python读取Pickle文件信息并计算与当前时间间隔的方法分析
Jan 30 Python
pytorch多进程加速及代码优化方法
Aug 19 Python
pymysql模块的使用(增删改查)详解
Sep 09 Python
django商品分类及商品数据建模实例详解
Jan 03 Python
Python对象的属性访问过程详解
Mar 05 Python
python 两种方法修改文件的创建时间、修改时间、访问时间
Sep 26 Python
pandas中对文本类型数据的处理小结
Nov 01 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
THINKPHP项目开发中的日志记录实例分析
2014/12/01 PHP
PHP yield关键字功能与用法分析
2019/01/03 PHP
php用xpath解析html的代码实例讲解
2019/02/14 PHP
js 图片缩放(按比例)控制代码
2009/05/27 Javascript
Js日期选择自动填充到输入框(界面漂亮兼容火狐)
2013/08/02 Javascript
使用jQuery和PHP实现类似360功能开关效果
2014/02/12 Javascript
通过正则表达式实现表单验证是否为中文
2014/02/18 Javascript
Javascript MVC框架Backbone.js详解
2014/09/18 Javascript
JavaScript中神奇的call()方法
2015/03/12 Javascript
js实现同一页面多个运动效果的方法
2015/04/10 Javascript
JavaScript点击按钮后弹出透明浮动层的方法
2015/05/11 Javascript
javascript数组随机排序实例分析
2015/07/22 Javascript
理解javascript正则表达式
2016/03/08 Javascript
jsonp跨域请求实现示例
2017/03/13 Javascript
js实现音频控制进度条功能
2017/04/01 Javascript
jquery 校验中国身份证号码实例详解
2017/04/11 jQuery
Angular2实现组件交互的方法分析
2017/12/19 Javascript
详解vue中axios的封装
2018/07/18 Javascript
微信小程序ibeacon三点定位详解
2018/10/31 Javascript
关于layui 下拉列表的change事件详解
2019/09/20 Javascript
[37:02]OG vs INfamous 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
web.py获取上传文件名的正确方法
2014/08/26 Python
Django imgareaselect手动剪切头像实现方法
2015/05/26 Python
PyCharm 常用快捷键和设置方法
2017/12/20 Python
numpy linalg模块的具体使用方法
2019/05/26 Python
Python Websocket服务端通信的使用示例
2020/02/25 Python
python爬虫beautifulsoup解析html方法
2020/12/07 Python
美国奢侈品购物平台:Orchard Mile
2018/05/02 全球购物
公证委托书大全
2014/04/04 职场文书
党员查摆问题及整改措施
2014/10/10 职场文书
2015年小学二年级班主任工作总结
2015/05/21 职场文书
2015大学生暑期实习报告
2015/07/13 职场文书
毕业生入职感言
2015/07/31 职场文书
自考生自我评价
2019/06/21 职场文书
thinkphp 获取控制器及控制器方法
2021/04/16 PHP
MySQL 如何限制一张表的记录数
2021/09/14 MySQL