PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中for循环的使用方法
May 14 Python
Python实现随机选择元素功能
Sep 14 Python
python cs架构实现简单文件传输
Mar 20 Python
APIStar:一个专为Python3设计的API框架
Sep 26 Python
使用PYTHON解析Wireshark的PCAP文件方法
Jul 23 Python
python多线程+代理池爬取天天基金网、股票数据过程解析
Aug 13 Python
新年福利来一波之Python轻松集齐五福(demo)
Jan 20 Python
Python接口自动化测试框架运行原理及流程
Nov 30 Python
Python中全局变量和局部变量的理解与区别
Feb 07 Python
pytorch 把图片数据转化成tensor的操作
Mar 04 Python
python用字节处理文件实例讲解
Apr 13 Python
Python time库的时间时钟处理
May 02 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
利用PHP将部分内容用星号替换
2020/04/21 PHP
PHP实现清除MySQL死连接的方法
2016/07/23 PHP
yii2安装详细流程
2018/05/23 PHP
PHP使用HTML5 FormData对象提交表单操作示例
2019/07/02 PHP
gearman中任务的优先级和返回状态实例分析
2020/02/27 PHP
基于prototype扩展的JavaScript常用函数库
2010/11/30 Javascript
JS自动适应的图片弹窗实例
2013/06/29 Javascript
jQuery setTimeout传递字符串参数报错的解决方法
2014/06/09 Javascript
escape编码与unescape解码汉字出现乱码的解决方法
2014/07/02 Javascript
dreamweaver 8实现Jquery自动提示
2014/12/04 Javascript
javascript包装对象实例分析
2015/03/27 Javascript
HTML5 Shiv完美解决IE(IE6/IE7/IE8)不兼容HTML5标签的方法
2015/11/25 Javascript
Bootstrap每天必学之折叠
2016/04/12 Javascript
AngularJS控制器继承自另一控制器
2016/05/09 Javascript
全面了解JavaScript对象进阶
2016/07/19 Javascript
JS中微信小程序自定义底部弹出框
2016/12/22 Javascript
JS简单实现获取元素的封装操作示例
2017/04/07 Javascript
Vuex 入门教程
2018/01/10 Javascript
vue 简单自动补全的输入框的示例
2018/03/12 Javascript
Vue 去除路径中的#号
2018/04/19 Javascript
js中的 || 与 && 运算符详解
2018/05/24 Javascript
微信小程序中显示倒计时代码实例
2019/05/09 Javascript
微信小程序实现页面分享onShareAppMessage
2019/08/12 Javascript
Vue 嵌套路由使用总结(推荐)
2020/01/13 Javascript
Python字符串、元组、列表、字典互相转换的方法
2016/01/23 Python
Bottle框架中的装饰器类和描述符应用详解
2017/10/28 Python
python3.X 抓取火车票信息【修正版】
2018/06/19 Python
python去掉 unicode 字符串前面的u方法
2018/10/21 Python
Matplotlib绘制雷达图和三维图的示例代码
2020/01/07 Python
英国豪华针织品牌John Smedley的在线销售商:The Outlet by John Smedley
2018/04/08 全球购物
英国最大的天然和有机产品在线零售商之一:Big Green Smile
2020/05/06 全球购物
咨询公司各岗位职责
2013/12/02 职场文书
工程安全员岗位职责
2014/03/09 职场文书
青春寄语大全
2014/04/09 职场文书
2014年药剂科工作总结
2014/11/26 职场文书
写作技巧:怎样写好一份优秀工作总结?
2019/08/14 职场文书