PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python接收Gmail新邮件并发送到gtalk的方法
Mar 10 Python
用Python将动态GIF图片倒放播放的方法
Nov 02 Python
python 计算两个日期相差多少个月实例代码
May 24 Python
python将文本中的空格替换为换行的方法
Mar 19 Python
查看django版本的方法分享
May 14 Python
Python高级用法总结
May 26 Python
python实现flappy bird小游戏
Dec 24 Python
Django打印出在数据库中执行的语句问题
Jul 25 Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 Python
Python getattr()函数使用方法代码实例
Aug 10 Python
解决import tensorflow导致jupyter内核死亡的问题
Feb 06 Python
Python中22个万用公式的小结
Jul 21 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
php中配置文件操作 如config.php文件的读取修改等操作
2012/07/07 PHP
关于Iframe如何跨域访问Cookie和Session的解决方法
2013/04/15 PHP
PHP 魔术变量和魔术函数详解
2015/02/25 PHP
一个简单安全的PHP验证码类 附调用方法
2016/06/24 PHP
php遍历替换目录下文件指定内容的方法
2016/11/10 PHP
老生常谈PHP数组函数array_merge(必看篇)
2017/05/25 PHP
Laravel框架表单验证操作实例分析
2019/09/30 PHP
jQuery页面滚动浮动层智能定位实例代码
2011/08/23 Javascript
关于火狐(firefox)及ie下event获取的两种方法
2012/12/27 Javascript
将两个div左右并列显示并实现点击标题切换内容
2013/10/22 Javascript
jQuery对val和atrr("value")赋值的区别介绍
2014/09/26 Javascript
Javascript 数组排序详解
2014/10/22 Javascript
js的[defer]和[async]属性
2014/11/24 Javascript
jQuery可见性过滤器:hidden和:visibility用法实例
2015/06/24 Javascript
纯js实现瀑布流布局及ajax动态新增数据
2016/04/07 Javascript
详解javascript中对数据格式化的思考
2017/01/23 Javascript
vue+element实现图片上传及裁剪功能
2020/06/29 Javascript
vue 动态创建组件的两种方法
2020/12/31 Vue.js
Python获取linux主机ip的简单实现方法
2016/04/18 Python
python http基本验证方法
2018/12/26 Python
总结python中pass的作用
2019/02/27 Python
Python使用sklearn库实现的各种分类算法简单应用小结
2019/07/04 Python
Python 可变类型和不可变类型及引用过程解析
2019/09/27 Python
澳大利亚第一的设计师礼服租赁网站:GlamCorner
2017/08/13 全球购物
伊芙丽官方旗舰店:中国淑女一线品牌
2017/12/01 全球购物
俄罗斯苹果优质经销商商店:iPort
2020/05/27 全球购物
什么是java序列化,如何实现java序列化
2012/11/14 面试题
啤酒销售实习自我鉴定
2013/09/24 职场文书
开服装店计划书
2014/08/15 职场文书
判缓刑人员个人思想汇报
2014/10/10 职场文书
初中教师个人工作总结
2015/02/10 职场文书
社区端午节活动总结
2015/02/11 职场文书
离婚起诉书范本
2015/05/18 职场文书
音乐之声观后感
2015/06/04 职场文书
哪类餐饮行业,最适合在高校创业?
2019/08/19 职场文书
MySQL单表千万级数据处理的思路分享
2021/06/05 MySQL