用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从零学python系列之数据处理编程实例(二)
May 22 Python
python操作redis方法总结
Jun 06 Python
解决pip install xxx报错SyntaxError: invalid syntax的问题
Nov 30 Python
Python饼状图的绘制实例
Jan 15 Python
在Python中,不用while和for循环遍历列表的实例
Feb 20 Python
利用python numpy+matplotlib绘制股票k线图的方法
Jun 26 Python
Django中Middleware中的函数详解
Jul 18 Python
python自动化UI工具发送QQ消息的实例
Aug 27 Python
使用python绘制温度变化雷达图
Oct 18 Python
pycharm安装及如何导入numpy
Apr 03 Python
python邮件中附加文字、html、图片、附件实现方法
Jan 04 Python
Python tensorflow卷积神经Inception V3网络结构
May 06 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
php smarty截取中文字符乱码问题?gb2312/utf-8
2011/11/07 PHP
php漏洞之跨网站请求伪造与防止伪造方法
2013/08/15 PHP
php采用curl访问域名返回405 method not allowed提示的解决方法
2014/06/26 PHP
详解Yii实现分页的两种方法
2017/01/14 PHP
PHP7如何开启Opcode打造强悍性能详解
2018/05/11 PHP
Laravel 5.5 实现禁用用户注册示例
2019/10/24 PHP
thinkphp5 + ajax 使用formdata提交数据(包括文件上传) 后台返回json完整实例
2020/03/02 PHP
setTimeout()与setInterval()方法区别介绍
2013/12/24 Javascript
详解JavaScript编程中的数组结构
2015/10/24 Javascript
javascript实现简单加载随机色方块
2015/12/25 Javascript
jQuery模拟物体自由落体运动(附演示与demo源码下载)
2016/01/21 Javascript
关于vue.js v-bind 的一些理解和思考
2017/06/06 Javascript
JS自定义滚动条效果简单实现代码
2020/10/27 Javascript
react.js 父子组件数据绑定实时通讯的示例代码
2017/09/25 Javascript
element-ui中select组件绑定值改变,触发change事件方法
2018/08/24 Javascript
微信小程序学习笔记之登录API与获取用户信息操作图文详解
2019/03/29 Javascript
如何获取vue单文件自身源码路径
2019/05/06 Javascript
layui的面包屑或者表单不显示的解决方法
2019/09/05 Javascript
JavaScript封装单向链表的示例代码
2020/09/17 Javascript
vue 使用 v-model 双向绑定父子组件的值遇见的问题及解决方案
2021/03/01 Vue.js
[58:57]2018DOTA2亚洲邀请赛3月29日小组赛B组 Effect VS VGJ.T
2018/03/30 DOTA
Python人脸识别初探
2017/12/21 Python
Python实现购物车购物小程序
2018/04/18 Python
利用python对Excel中的特定数据提取并写入新表的方法
2018/06/14 Python
Python matplotlib修改默认字体的操作
2020/03/05 Python
使用OpenCV实现人脸图像卡通化的示例代码
2021/01/15 Python
css3动画 小球滚动 js控制动画暂停
2019/11/29 HTML / CSS
html5指南-2.如何操作document metadata
2013/01/07 HTML / CSS
丝芙兰加拿大官方网站:SEPHORA加拿大
2018/11/20 全球购物
德国自然时尚和有机产品购物网站:Waschbär
2019/05/29 全球购物
中医药大学市场营销专业自荐信
2013/09/29 职场文书
实习老师离校感言
2014/02/03 职场文书
中学生运动会入场词
2014/02/12 职场文书
公司领导班子召开党的群众路线教育实践活动总结大会新闻稿
2014/10/21 职场文书
2014年绩效考核工作总结
2014/12/11 职场文书
家庭经济困难证明
2015/06/23 职场文书