PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python回调函数的使用方法
Jan 23 Python
Python中的各种装饰器详解
Apr 11 Python
pygame学习笔记(3):运动速率、时间、事件、文字
Apr 15 Python
Python实现的计数排序算法示例
Nov 29 Python
Python遍历某目录下的所有文件夹与文件路径
Mar 15 Python
在python 中split()使用多符号分割的例子
Jul 15 Python
python 遍历pd.Series的index和value
Nov 26 Python
python虚拟环境模块venv使用及示例
Mar 04 Python
Django后端分离 使用element-ui文件上传方式
Jul 12 Python
python em算法的实现
Oct 03 Python
Python实现简单的2048小游戏
Mar 01 Python
python+opencv实现视频抽帧示例代码
Jun 11 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
PHP简单遍历对象示例
2016/09/28 PHP
PHP实现的mysql主从数据库状态检测功能示例
2017/07/20 PHP
PHP连接MySQL数据库三种实现方法
2020/12/10 PHP
解决jquery .ajax 在IE下卡死问题的解决方法
2009/10/26 Javascript
CSS和Javascript简单复习资料
2010/06/29 Javascript
关于JavaScript定义类和对象的几种方式
2010/11/09 Javascript
JSONP 跨域共享信息
2012/08/16 Javascript
js中 关于undefined和null的区别介绍
2013/04/16 Javascript
jQuery中removeAttr()方法用法实例
2015/01/05 Javascript
js对象的复制继承实例
2015/01/10 Javascript
Javascript核心读书有感之类型、值和变量
2015/02/11 Javascript
bootstrap table 表格中增加下拉菜单末行出现滚动条的快速解决方法
2017/01/05 Javascript
BootStrap表单验证实例代码
2017/01/13 Javascript
javascript中json基础知识详解
2017/01/19 Javascript
遍历json获得数据的几种方法小结
2017/01/21 Javascript
前端图片懒加载(lazyload)的实现方法(提高用户体验)
2017/08/21 Javascript
探索Vue高阶组件的使用
2018/01/08 Javascript
jquery 输入框查找关键字并提亮颜色的实例代码
2018/01/23 jQuery
浅谈Vue.use的使用
2018/08/29 Javascript
微信小程序开发之自定义tabBar的实现
2018/09/06 Javascript
微信小程序中的上拉、下拉菜单功能
2020/03/13 Javascript
解决ant design vue中树形控件defaultExpandAll设置无效的问题
2020/10/26 Javascript
python 打印直角三角形,等边三角形,菱形,正方形的代码
2017/11/21 Python
numpy使用fromstring创建矩阵的实例
2018/06/15 Python
python模块导入的细节详解
2018/12/10 Python
用python生成与调用cntk模型代码演示方法
2019/08/26 Python
Python2与Python3关于字符串编码处理的差别总结
2020/09/07 Python
html5 canvas实现给图片添加平铺水印
2019/08/20 HTML / CSS
一套Java笔试题
2016/08/20 面试题
2013英文求职信模板范文
2013/11/15 职场文书
安全生产目标责任书
2014/04/14 职场文书
效能风暴心得体会
2014/09/04 职场文书
试用期自我评价怎么写
2015/03/10 职场文书
机关保密工作承诺书
2015/05/04 职场文书
干部培训简讯
2015/07/20 职场文书
JavaScript实现优先级队列
2021/12/06 Javascript