PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用beaker让Facebook的Bottle框架支持session功能
Apr 23 Python
Python中列表元素转为数字的方法分析
Jun 14 Python
Django框架模板注入操作示例【变量传递到模板】
Dec 19 Python
使用Python的OpenCV模块识别滑动验证码的缺口(推荐)
May 10 Python
Django Channels 实现点对点实时聊天和消息推送功能
Jul 17 Python
Django 源码WSGI剖析过程详解
Aug 05 Python
Python高级特性——详解多维数组切片(Slice)
Nov 26 Python
使用opencv识别图像红色区域,并输出红色区域中心点坐标
Jun 02 Python
怎么快速自学python
Jun 22 Python
Python基于callable函数检测对象是否可被调用
Oct 16 Python
python实现图片,视频人脸识别(opencv版)
Nov 18 Python
python 爬取百度文库并下载(免费文章限定)
Dec 04 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
星际争霸中的热键
2020/03/04 星际争霸
使用session判断用户登录用户权限(超简单)
2013/06/08 PHP
PHP以json或xml格式返回请求数据的方法
2018/05/31 PHP
Laravel 5.1 框架Blade模板引擎用法实例分析
2020/01/04 PHP
PHP safe_mode开启对于PHP系统函数有什么影响
2020/11/10 PHP
window.parent调用父框架时 ie跟火狐不兼容问题
2009/07/30 Javascript
用JavaScript将从数据库中读取出来的日期型格式化为想要的类型。
2009/08/15 Javascript
JavaScript写的一个自定义弹出式对话框代码
2010/01/17 Javascript
javascript+iframe 实现无刷新载入整页的代码
2010/03/17 Javascript
基于jquery的tab切换 js原理
2010/04/01 Javascript
JavaScript 构造函数 面相对象学习必备知识
2010/06/09 Javascript
最短的javascript:地址栏载入脚本代码
2011/10/13 Javascript
iframe子页面获取父页面元素的方法
2013/11/05 Javascript
JS模拟简易滚动条效果代码(附demo源码)
2016/04/05 Javascript
一步一步的了解webpack4的splitChunk插件(小结)
2018/09/17 Javascript
利用Angular2的Observables实现交互控制的方法
2018/12/27 Javascript
泛谈JS逻辑判断选择器 || &&
2019/05/24 Javascript
element的el-table中记录滚动条位置的示例代码
2019/11/06 Javascript
js回调函数仿360开机
2019/12/26 Javascript
vant 自定义 van-dropdown-item的用法
2020/08/05 Javascript
Python2.x和3.x下maketrans与translate函数使用上的不同
2015/04/13 Python
JSONLINT:python的json数据验证库实例解析
2017/11/28 Python
python3使用SMTP发送HTML格式邮件
2018/06/19 Python
python实战串口助手_解决8串口多个发送的问题
2019/06/12 Python
python 哈希表实现简单python字典代码实例
2019/09/27 Python
python实现堆排序的实例讲解
2020/02/21 Python
KIKO MILANO英国官网:意大利知名化妆品和护肤品品牌
2017/09/25 全球购物
三星法国官方网站:Samsung法国
2019/10/31 全球购物
网络安全方面的面试题
2015/11/04 面试题
党员干部2014全国两会学习心得体会
2014/03/10 职场文书
吨的认识教学反思
2014/04/27 职场文书
大学新生军训自我鉴定范文
2014/09/13 职场文书
2016年情人节广告语
2016/01/28 职场文书
Python爬虫基础之爬虫的分类知识总结
2021/05/13 Python
Redis实战高并发之扣减库存项目
2022/04/14 Redis
Win11 Beta 预览版 22621.575 和 22622.575更新补丁KB5016694发布(附更新内容大全)
2022/08/14 数码科技