PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 随机生成中文验证码的实例代码
Mar 20 Python
Python探索之实现一个简单的HTTP服务器
Oct 28 Python
python寻找list中最大值、最小值并返回其所在位置的方法
Jun 27 Python
Django添加feeds功能的示例
Aug 07 Python
python使用pymongo操作mongo的完整步骤
Apr 13 Python
django多对多表的创建,级联删除及手动创建第三张表
Jul 25 Python
Python logging设置和logger解析
Aug 28 Python
Python RabbitMQ实现简单的进程间通信示例
Jul 02 Python
django 将自带的数据库sqlite3改成mysql实例
Jul 09 Python
Python使用正则表达式实现爬虫数据抽取
Aug 17 Python
Java Unsafe类实现原理及测试代码
Sep 15 Python
python面向对象版学生信息管理系统
Jun 24 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
微盾PHP脚本加密专家php解密算法
2020/09/13 PHP
getimagesize获取图片尺寸实例
2014/11/15 PHP
PHP如何将XML转成数组
2016/04/04 PHP
Laravel框架验证码类用法实例分析
2019/09/11 PHP
php生成短网址/短链接原理和用法实例分析
2020/05/29 PHP
javascript不同类型数据之间的运算的转换方法
2014/02/13 Javascript
现如今最流行的JavaScript代码规范
2014/03/08 Javascript
编写高效jQuery代码的4个原则和5个技巧
2014/04/24 Javascript
轻松创建nodejs服务器(6):作出响应
2014/12/18 NodeJs
jQuery基础语法实例入门
2014/12/23 Javascript
Angular中的Promise对象($q介绍)
2015/03/03 Javascript
关于延迟加载JavaScript
2015/05/05 Javascript
详解js中class的多种函数封装方法
2016/01/03 Javascript
jQuery内容过滤选择器用法示例
2016/09/09 Javascript
基于angular2 的 http服务封装的实例代码
2017/06/29 Javascript
详解Vue.js在页面加载时执行某个方法
2018/11/20 Javascript
JS实现单张或多张图片持续无缝滚动的示例代码
2020/05/10 Javascript
js实现批量删除功能
2020/08/27 Javascript
Python实现同时兼容老版和新版Socket协议的一个简单WebSocket服务器
2014/06/04 Python
Python入门篇之文件
2014/10/20 Python
python实现百万答题自动百度搜索答案
2018/01/16 Python
Python实现获取系统临时目录及临时文件的方法示例
2019/06/26 Python
使用python接受tgam的脑波数据实例
2020/04/09 Python
django 利用Q对象与F对象进行查询的实现
2020/05/15 Python
django queryset 去重 .distinct()说明
2020/05/19 Python
Python celery原理及运行流程解析
2020/06/13 Python
整理HTML5中表单的常用属性及新属性
2016/02/19 HTML / CSS
伦敦最著名的老字号百货公司:Selfridges(塞尔福里奇百货)
2016/07/25 全球购物
举例说明类变量和实例变量的区别
2016/06/30 面试题
投标诚信承诺书
2014/05/26 职场文书
2015初中团委工作总结
2015/07/28 职场文书
golang interface判断为空nil的实现代码
2021/04/24 Golang
详解MySQL多版本并发控制机制(MVCC)源码
2021/06/23 MySQL
Go语言应该什么情况使用指针
2021/07/25 Golang
Windows Server 修改远程桌面端口的实现
2022/06/25 Servers
CSS 实现磨砂玻璃(毛玻璃)效果样式
2023/05/21 HTML / CSS