PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的from..import绝对导入语句
Jun 21 Python
详解python中 os._exit() 和 sys.exit(), exit(0)和exit(1) 的用法和区别
Jun 23 Python
Python中函数及默认参数的定义与调用操作实例分析
Jul 25 Python
Python MD5加密实例详解
Aug 02 Python
pyqt5 comboBox获得下标、文本和事件选中函数的方法
Jun 14 Python
Django Channels 实现点对点实时聊天和消息推送功能
Jul 17 Python
Django认证系统实现的web页面实现代码
Aug 12 Python
python 操作hive pyhs2方式
Dec 21 Python
Python实现钉钉订阅消息功能
Jan 14 Python
Python 支持向量机分类器的实现
Jan 15 Python
上手简单,功能强大的Python爬虫框架——feapder
Apr 27 Python
利用Python实现Picgo图床工具
Nov 23 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
关于UEditor编辑器远程图片上传失败的解决办法
2012/08/31 PHP
PHP5.3的垃圾回收机制(动态存储分配方案)深入理解
2012/12/10 PHP
php对大文件进行读取操作的实现代码
2013/01/23 PHP
thinkphp实现数组分页示例
2014/04/13 PHP
PHP统计nginx访问日志中的搜索引擎抓取404链接页面路径
2014/06/30 PHP
php生成图片缩略图的方法
2015/04/07 PHP
Yii2框架dropDownList下拉菜单用法实例分析
2016/07/18 PHP
基于jquery的lazy loader插件实现图片的延迟加载[简单使用]
2011/05/07 Javascript
Javascript 创建类并动态添加属性及方法的简单实现
2016/10/20 Javascript
js实现砖头在页面拖拉效果
2020/11/20 Javascript
JS IOS/iPhone的Safari浏览器不兼容Javascript中的Date()问题如何解决
2016/11/11 Javascript
Vuex之理解Mutations的用法实例
2017/04/19 Javascript
javascript ES6 新增了let命令使用介绍
2017/07/07 Javascript
JS实现根据指定值删除数组中的元素操作示例
2018/08/02 Javascript
nodejs中函数的调用实例详解
2018/10/31 NodeJs
用Fundebug插件记录网络请求异常的方法
2019/02/21 Javascript
Node.js API详解之 console模块用法详解
2020/05/12 Javascript
js实现微信聊天效果
2020/08/09 Javascript
vue 函数调用加括号与不加括号的区别
2020/10/29 Javascript
[05:40]DOTA2荣耀之路6:Wings最后进攻
2018/05/30 DOTA
[01:13:59]LGD vs Mineski Supermajor 胜者组 BO3 第三场 6.5
2018/06/06 DOTA
[56:42]VP vs RNG 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
Python设计模式编程中解释器模式的简单程序示例分享
2016/03/02 Python
对python中的乘法dot和对应分量相乘multiply详解
2018/11/14 Python
python傅里叶变换FFT绘制频谱图
2019/07/19 Python
python对常见数据类型的遍历解析
2019/08/27 Python
浅谈django 重载str 方法
2020/05/19 Python
python实现数字炸弹游戏
2020/07/17 Python
全面总结使用CSS实现水平垂直居中效果的方法
2016/03/10 HTML / CSS
阿迪达斯奥地利官方商城:adidas.at
2016/10/16 全球购物
麦德龙官方海外旗舰店:德国麦德龙超市
2017/12/23 全球购物
教师岗位职责
2013/11/17 职场文书
集体备课反思
2014/02/12 职场文书
物流毕业生个人的自我评价
2014/02/13 职场文书
怎样写工作总结啊!
2019/06/18 职场文书
浅谈MySQL之浅入深出页原理
2021/06/23 MySQL