用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动化脚本安装指定版本python环境详解
Sep 14 Python
Python多进程库multiprocessing中进程池Pool类的使用详解
Nov 24 Python
Pycharm设置去除显示的波浪线方法
Oct 28 Python
pycharm运行程序时在Python console窗口中运行的方法
Dec 03 Python
Python matplotlib的使用并自定义colormap的方法
Dec 13 Python
对python实现模板生成脚本的方法详解
Jan 30 Python
python matplotlib折线图样式实现过程
Nov 04 Python
执行Python程序时模块报错问题
Mar 26 Python
python:HDF和CSV存储优劣对比分析
Jun 08 Python
opencv 阈值分割的具体使用
Jul 08 Python
互斥锁解决 Python 中多线程共享全局变量的问题(推荐)
Sep 28 Python
基于Python的EasyGUI学习实践
May 07 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
动漫定律:眯眯眼都是怪物!这些角色狠话不多~
2020/03/03 日漫
通过table标签,PHP输出EXCEL的实现方法
2013/07/24 PHP
PHP获取文件扩展名的方法实例总结
2017/06/10 PHP
Docker搭建自己的PHP开发环境
2018/02/24 PHP
Jquery实现无刷新DropDownList联动实现代码
2010/03/08 Javascript
JS中的this变量的使用介绍
2013/10/21 Javascript
jQuery设置div一直在页面顶部显示的方法
2013/10/24 Javascript
jquery实现页面关键词高亮显示的方法
2015/03/12 Javascript
JS实现仿QQ效果的三级竖向菜单
2015/09/25 Javascript
jQuery实现表格元素动态创建功能
2017/01/09 Javascript
整理关于Bootstrap导航的慕课笔记
2017/03/29 Javascript
two.js之实现动画效果示例
2017/11/06 Javascript
jQuery EasyUI 选项卡面板tabs的使用实例讲解
2017/12/25 jQuery
基于ajax实现上传图片代码示例解析
2020/12/03 Javascript
vue中配置scss全局变量的步骤
2020/12/28 Vue.js
详解vite2.0配置学习(typescript版本)
2021/02/25 Javascript
一波神奇的Python语句、函数与方法的使用技巧总结
2015/12/08 Python
Python中三元表达式的几种写法介绍
2019/03/04 Python
Python实现查找字符串数组最长公共前缀示例
2019/03/27 Python
python判断一个对象是否可迭代的例子
2019/07/22 Python
详解python中的数据类型和控制流
2019/08/08 Python
python GUI库图形界面开发之PyQt5菜单栏控件QMenuBar的详细使用方法与实例
2020/02/28 Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
2020/02/29 Python
python实现飞船大战
2020/04/24 Python
python实现二分查找算法
2020/09/18 Python
澳大利亚办公室装修:JasonL Office Furniture
2019/06/25 全球购物
Java语言的优势
2015/01/10 面试题
心理健康活动总结
2014/04/30 职场文书
小学运动会开幕词
2015/01/28 职场文书
项目经理助理岗位职责
2015/04/13 职场文书
2015年团支部年度工作总结
2015/05/27 职场文书
领导新年致辞2016
2015/07/29 职场文书
python - timeit 时间模块
2021/04/06 Python
Python基础之赋值,浅拷贝,深拷贝的区别
2021/04/30 Python
python利用pandas分析学生期末成绩实例代码
2021/07/09 Python
uniapp引入支付宝原生扫码插件步骤详解
2022/07/23 Javascript