用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单的连接MySQL与Python的Bottle框架的方法
Apr 30 Python
python中黄金分割法实现方法
May 06 Python
Django的session中对于用户验证的支持
Jul 23 Python
总结网络IO模型与select模型的Python实例讲解
Jun 27 Python
python利用lxml读写xml格式的文件
Aug 10 Python
python导入csv文件出现SyntaxError问题分析
Dec 15 Python
python对绑定事件的鼠标、按键的判断实例
Jul 17 Python
Python的Lambda函数用法详解
Sep 03 Python
Pycharm创建项目时如何自动添加头部信息
Nov 14 Python
Python装饰器用法与知识点小结
Mar 09 Python
pandas数据分组groupby()和统计函数agg()的使用
Mar 04 Python
python 爬取豆瓣网页的示例
Apr 13 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
怎样使用php与jquery设置和读取cookies
2013/08/08 PHP
php安装swoole扩展的方法
2015/03/19 PHP
使用php-timeit估计php函数的执行时间
2015/09/06 PHP
php遍历解析xml字符串的方法
2016/05/05 PHP
PHP实现找出链表中环的入口节点
2018/01/16 PHP
thinkphp5框架实现的自定义扩展类操作示例
2019/05/16 PHP
javascript 操作cookies及正确使用cookies的属性
2009/10/15 Javascript
js将当前时间格式转换成时间搓(自写)
2013/09/26 Javascript
JS的get和set使用示例
2014/02/20 Javascript
js网页右下角提示框实例
2014/10/14 Javascript
jQuery实现隔行背景色变色
2014/11/24 Javascript
理解javascript闭包
2015/12/15 Javascript
js中数组结合字符串实现查找(屏蔽广告判断url等)
2016/03/30 Javascript
AngularJS的ng-repeat指令与scope继承关系实例详解
2017/01/21 Javascript
layui表格实现代码
2017/05/20 Javascript
微信小程序多列选择器range-key使用详解
2020/03/30 Javascript
vue 双向数据绑定的实现学习之监听器的实现方法
2018/11/30 Javascript
关于微信小程序获取小程序码并接受buffer流保存为图片的方法
2019/06/07 Javascript
使用layui实现的左侧菜单栏以及动态操作tab项方法
2019/09/10 Javascript
用js限制网页只在微信浏览器中打开(或者只能手机端访问)
2020/12/24 Javascript
Javascript摸拟自由落体与上抛运动原理与实现方法详解
2020/04/08 Javascript
JS替换字符串中指定位置的字符(多种方法)
2020/05/28 Javascript
Python获取暗黑破坏神3战网前1000命位玩家的英雄技能统计
2016/07/04 Python
python 寻找优化使成本函数最小的最优解的方法
2017/12/28 Python
django解决跨域请求的问题
2018/11/11 Python
python批量修改文件夹及其子文件夹下的文件内容
2019/03/15 Python
python中的global关键字的使用方法
2019/08/20 Python
Python读取分割压缩TXT文本文件实例
2020/02/14 Python
浅析Django 接收所有文件,前端展示文件(包括视频,文件,图片)ajax请求
2020/03/09 Python
市场营销专业毕业生自荐信
2013/11/02 职场文书
个人自我评价和职业目标
2014/01/24 职场文书
《海底世界》教学反思
2014/04/16 职场文书
环境卫生倡议书
2014/08/29 职场文书
2014年督导工作总结
2014/11/19 职场文书
警示教育片观后感
2015/06/17 职场文书
Win10鼠标轨迹怎么开 Win10显示鼠标轨迹方法
2022/04/06 数码科技