PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中遍历文件的3个方法
Sep 02 Python
python线程池(threadpool)模块使用笔记详解
Nov 17 Python
Python Web程序部署到Ubuntu服务器上的方法
Feb 22 Python
python自动发邮件库yagmail的示例代码
Feb 23 Python
Python系统监控模块psutil功能与经典用法分析
May 24 Python
Python2 Selenium元素定位的实现(8种)
Feb 25 Python
django多种支付、并发订单处理实例代码
Dec 13 Python
Python continue语句实例用法
Feb 06 Python
python网络编程之五子棋游戏
May 14 Python
Python类的继承super相关原理解析
Oct 22 Python
pycharm 关闭search everywhere的解决操作
Jan 15 Python
Python使用Beautiful Soup(BS4)库解析HTML和XML
Jun 05 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
PHP语法速查表
2007/01/02 PHP
php下载远程文件类(支持断点续传)
2008/11/14 PHP
PHP chmod 函数与批量修改文件目录权限
2010/05/10 PHP
php学习之数据类型之间的转换介绍
2011/06/09 PHP
3个PHP多维数组转为一维数组的方法实例
2014/03/13 PHP
PHP严重致命错误处理:php Fatal error: Cannot redeclare class or function
2017/02/05 PHP
Yii框架小部件(Widgets)用法实例详解
2020/05/15 PHP
动态改变textbox的宽高的js
2006/10/26 Javascript
Jqgrid表格随窗口大小改变而改变的简单实例
2013/12/28 Javascript
jquery实现勾选复选框触发事件给input赋值
2015/02/01 Javascript
js实现仿百度风云榜可重复多次调用的TAB切换选项卡效果
2015/08/31 Javascript
使用jQuery+EasyUI实现CheckBoxTree的级联选中特效
2015/12/06 Javascript
基于react后端渲染模板引擎noox发布使用
2018/01/11 Javascript
vue中实现在外部调用methods的方法(推荐)
2018/02/08 Javascript
vue多页面开发和打包正确处理方法
2018/04/20 Javascript
详解Vue+axios+Node+express实现文件上传(用户头像上传)
2018/08/10 Javascript
RequireJS用法简单示例
2018/08/20 Javascript
react-native滑动吸顶效果的实现过程
2019/06/03 Javascript
[02:21]十步杀一人,千里不留行——DOTA2全新英雄天涯墨客展示
2018/08/29 DOTA
Python 解析XML文件
2009/04/15 Python
举例讲解Python面相对象编程中对象的属性与类的方法
2016/01/19 Python
python通过elixir包操作mysql数据库实例代码
2018/01/31 Python
PyQt5实现拖放功能
2018/04/25 Python
python seaborn heatmap可视化相关性矩阵实例
2020/06/03 Python
Django限制API访问频率常用方法解析
2020/10/12 Python
IE浏览器单独写CSS样式的几种方法
2014/10/14 HTML / CSS
Camper鞋西班牙官方网上商店:西班牙马略卡岛的鞋类品牌
2019/03/14 全球购物
高校生生产实习自我鉴定
2013/09/21 职场文书
领导检查欢迎词
2014/01/14 职场文书
小学生评语集锦
2014/04/18 职场文书
三方协议书范本
2014/04/22 职场文书
任命书范本大全
2014/06/06 职场文书
2015年国庆节寄语
2015/08/17 职场文书
2016年三严三实党课学习心得体会
2016/01/06 职场文书
技术入股协议书
2016/03/22 职场文书
创业计划书之游泳馆
2019/09/16 职场文书