Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中sets模块的用法实例
Sep 30 Python
Python中使用select模块实现非阻塞的IO
Feb 03 Python
SQLite3中文编码 Python的实现
Jan 11 Python
Python实现的求解最大公约数算法示例
May 03 Python
python 保存float类型的小数的位数方法
Oct 17 Python
解决python通过cx_Oracle模块连接Oracle乱码的问题
Oct 18 Python
PyTorch的深度学习入门教程之构建神经网络
Jun 27 Python
使用Python为中秋节绘制一块美味的月饼
Sep 11 Python
Python连接SQLite数据库并进行增册改查操作方法详解
Feb 18 Python
Python flask框架如何显示图像到web页面
Jun 03 Python
python面向对象版学生信息管理系统
Jun 24 Python
python使用matplotlib绘制图片时x轴的刻度处理
Aug 30 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
php中3des加密代码(完全与.net中的兼容)
2012/08/02 PHP
php实现的二叉树遍历算法示例
2017/06/15 PHP
从零开始学习jQuery (二) 万能的选择器
2010/10/01 Javascript
js中scrollHeight,scrollWidth,scrollLeft,scrolltop等差别介绍
2012/05/16 Javascript
Function.prototype.call.apply结合用法分析示例
2013/07/03 Javascript
异步动态加载js与css文件的js代码
2013/09/15 Javascript
原生Javascript封装的一个AJAX函数分享
2014/10/11 Javascript
JS实现可展开折叠层的鼠标拖曳效果
2015/10/09 Javascript
thinkphp实现无限分类(使用递归)
2015/12/19 Javascript
从vue基础开始创建一个简单的增删改查的实例代码(推荐)
2018/02/11 Javascript
使用layer弹窗和layui表单实现新增功能
2018/08/09 Javascript
js中获取URL参数的共用方法getRequest()方法实例详解
2018/10/24 Javascript
PM2自动部署代码步骤流程总结
2018/12/10 Javascript
Vue触发隐藏input file的方法实例详解
2019/08/14 Javascript
Vue中实现回车键切换焦点的方法
2020/02/19 Javascript
js实现无缝轮播图
2020/03/09 Javascript
python encode和decode的妙用
2009/09/02 Python
Python笔记(叁)继续学习
2012/10/24 Python
python魔法方法-自定义序列详解
2016/07/21 Python
Python下的Softmax回归函数的实现方法(推荐)
2017/01/26 Python
理解python中生成器用法
2017/12/20 Python
Python使用一行代码获取上个月是几月
2018/08/30 Python
Python lxml解析HTML并用xpath获取元素的方法
2019/01/02 Python
python实现全盘扫描搜索功能的方法
2019/02/14 Python
Pycharm修改python路径过程图解
2020/05/22 Python
HTML5混合开发二维码扫描以及调用本地摄像头
2017/12/27 HTML / CSS
canvas 如何绘制线段的实现方法
2018/07/12 HTML / CSS
香港莎莎官网Sasa.com:亚洲著名国际化妆品商城
2019/11/10 全球购物
写出程序把一个链表中的接点顺序倒排
2014/04/28 面试题
外贸业务员岗位职责
2013/11/24 职场文书
社区党务公开实施方案
2014/03/18 职场文书
师德师风学习材料
2014/12/19 职场文书
2015小学教师德育工作总结
2015/05/12 职场文书
七年级英语教学反思
2016/02/15 职场文书
MySQL中rank() over、dense_rank() over、row_number() over用法介绍
2022/03/23 MySQL
MySQL脏读,幻读和不可重复读
2022/05/11 MySQL