PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之编写类之二方法
Oct 11 Python
Python按行读取文件的简单实现方法
Jun 22 Python
Python操作Access数据库基本步骤分析
Sep 19 Python
python中利用xml.dom模块解析xml的方法教程
May 24 Python
R vs. Python 数据分析中谁与争锋?
Oct 18 Python
Python中pandas模块DataFrame创建方法示例
Jun 20 Python
Python双向循环链表实现方法分析
Jul 30 Python
Python实现求两个数组交集的方法示例
Feb 23 Python
分析运行中的 Python 进程详细解析
Jun 22 Python
mac系统下Redis安装和使用步骤详解
Jul 09 Python
利用Python实现字幕挂载(把字幕文件与视频合并)思路详解
Oct 21 Python
5道关于python基础 while循环练习题
Nov 27 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
简单的方法让你的后台登录更加安全(php中加session验证)
2012/08/22 PHP
PHP多进程编程实例
2014/10/15 PHP
PHP PDOStatement::getAttribute讲解
2019/02/01 PHP
PHP获取当前时间不准确问题解决方案
2020/08/14 PHP
漂亮的widgets,支持换肤和后期开发新皮肤(2007-4-27已更新1.7alpha)
2007/04/27 Javascript
jquery遍历input取得input的name
2009/04/27 Javascript
js操作iframe的一些方法介绍
2013/06/25 Javascript
利用jquery包将字符串生成二维码图片
2013/09/12 Javascript
javascript预加载图片、css、js的方法示例介绍
2013/10/14 Javascript
js中如何复制一个对象并获取其所有属性和属性对应的值
2013/10/24 Javascript
php is_numberic函数造成的SQL注入漏洞
2014/03/10 Javascript
javascript中interval与setTimeOut的区别示例介绍
2014/03/14 Javascript
JavaScript实现大数的运算
2014/11/24 Javascript
jQuery中delegate()方法用法实例
2015/01/19 Javascript
提交按钮的name='submit'引起的js失效问题及原因
2015/02/25 Javascript
JavaScript三元运算符的多种使用技巧
2015/04/16 Javascript
javascript实现网页屏蔽Backspace事件,输入框不屏蔽
2015/07/21 Javascript
Bootstrap组件(一)之菜单
2016/05/11 Javascript
vue addRoutes实现动态权限路由菜单的示例
2018/05/15 Javascript
angular 实现的输入框数字千分位及保留几位小数点功能示例
2018/06/19 Javascript
微信小程序非跳转式组件授权登录的方法示例
2019/05/22 Javascript
vue-cli3项目展示本地Markdown文件的方法
2019/06/07 Javascript
在Python下尝试多线程编程
2015/04/28 Python
Python下载懒人图库JavaScript特效
2015/05/28 Python
Python实现一个转存纯真IP数据库的脚本分享
2017/05/21 Python
Python使用ffmpy将amr格式的音频转化为mp3格式的例子
2019/08/08 Python
YUV转为jpg图像的实现
2019/12/09 Python
3种python调用其他脚本的方法
2020/01/06 Python
基于python实现操作redis及消息队列
2020/08/27 Python
Big Green Smile法国:领先的英国有机和天然产品在线商店
2021/01/02 全球购物
俄罗斯首家面向中国消费者的一站式购物网站:Wruru
2020/05/08 全球购物
服务生自我鉴定
2014/01/22 职场文书
企业员工爱岗敬业演讲稿
2014/08/26 职场文书
学习张丽丽心得体会
2014/09/03 职场文书
教你用Python matplotlib库制作简单的动画
2021/06/11 Python
Python编写冷笑话生成器
2022/04/20 Python