PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
一文总结学习Python的14张思维导图
Oct 17 Python
Python决策树分类算法学习
Dec 22 Python
Python将list中的string批量转化成int/float的方法
Jun 26 Python
Python神奇的内置函数locals的实例讲解
Feb 22 Python
Python 一键制作微信好友图片墙的方法
May 16 Python
python 猴子补丁(monkey patch)
Jun 26 Python
浅析Python与Mongodb数据库之间的操作方法
Jul 01 Python
这可能是最好玩的python GUI入门实例(推荐)
Jul 19 Python
python Django 反向访问器的外键冲突解决
May 20 Python
利用Python实现某OA系统的自动定位功能
May 27 Python
Keras 数据增强ImageDataGenerator多输入多输出实例
Jul 03 Python
Python使用sys.exc_info()方法获取异常信息
Jul 23 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
Search Engine Friendly的URL设计
2006/10/09 PHP
PHP实现QQ快速登录的方法
2016/09/28 PHP
javascript import css实例代码
2008/07/18 Javascript
js no-repeat写法 背景不重复
2009/03/18 Javascript
js 页面关闭前的出现提示的实现代码
2011/05/25 Javascript
javascript parseInt() 函数的进制转换注意细节
2013/01/08 Javascript
jquery实现输入框动态增减的实例代码
2013/07/14 Javascript
带左右箭头图片轮播的JS代码
2013/12/18 Javascript
jquery中attr和prop的区别分析
2015/03/16 Javascript
jQuery+PHP+MySQL实现无限级联下拉框效果
2016/02/19 Javascript
jQuery animate和CSS3相结合实现缓动追逐效果附源码下载
2016/04/18 Javascript
jQuery移动端日期(datedropper)和时间(timedropper)选择器附源码下载
2016/04/19 Javascript
JavaScript中Promise的使用详解
2017/02/26 Javascript
JavaScript ES6中export、import与export default的用法和区别
2017/03/14 Javascript
新手vue构建单页面应用实例代码
2017/09/18 Javascript
结合Vue控制字符和字节的显示个数的示例
2018/05/17 Javascript
Vue中常用rules校验规则(实例代码)
2019/11/14 Javascript
JavaScript对象原型链原理解析
2020/01/22 Javascript
[01:59]游戏“zheng”当时试玩会
2019/08/21 DOTA
Python xlrd读取excel日期类型的2种方法
2015/04/28 Python
Windows下实现Python2和Python3两个版共存的方法
2015/06/12 Python
pygame游戏之旅 添加碰撞效果的方法
2018/11/20 Python
pycharm的console输入实现换行的方法
2019/01/16 Python
python遍历文件目录、批量处理同类文件
2019/08/31 Python
Python IDLE或shell中切换路径的操作
2020/03/09 Python
python学习将数据写入文件并保存方法
2020/06/07 Python
HTML5和以前HTML4的区别整理
2013/10/20 HTML / CSS
巴西本土电商平台:Americanas
2020/06/21 全球购物
工商企业管理实习自我鉴定
2013/12/04 职场文书
心得体会开头
2014/01/01 职场文书
《神奇的克隆》教学反思
2014/04/10 职场文书
乡镇民主生活会发言材料
2014/10/20 职场文书
铁路安全反思材料
2014/12/24 职场文书
离婚案件原告代理词
2015/05/23 职场文书
2015大学生暑假调查报告
2015/07/13 职场文书
十个Python自动化常用操作,即拿即用
2021/05/10 Python