Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python类参数self使用示例
Feb 17 Python
Python程序员开发中常犯的10个错误
Jul 07 Python
基于python元祖与字典与集合的粗浅认识
Aug 23 Python
对python以16进制打印字节数组的方法详解
Jan 24 Python
Python中print和return的作用及区别解析
May 05 Python
python3.7简单的爬虫实例详解
Jul 08 Python
使用Django搭建一个基金模拟交易系统教程
Nov 18 Python
如何基于Python创建目录文件夹
Dec 31 Python
SpringBoot实现登录注册常见问题解决方案
Mar 04 Python
Python3.9.1中使用split()的处理方法(推荐)
Feb 07 Python
python中random模块详解
Mar 01 Python
pytorch 如何使用float64训练
May 24 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
php求一个网段开始与结束IP地址的方法
2015/07/09 PHP
实现PHP搜索加分页
2016/10/12 PHP
PHP调试及性能分析工具Xdebug详解
2017/02/09 PHP
js函数般调用正则
2008/04/08 Javascript
javascript:以前写的xmlhttp池,代码
2008/05/18 Javascript
JS模拟的QQ面板上的多级可展开的菜单
2009/10/10 Javascript
jquery之超简单的div显示和隐藏特效demo(分享)
2013/07/09 Javascript
浅谈JavaScript函数参数的可修改性问题
2013/12/05 Javascript
兼容最新firefox、chrome和IE的javascript图片预览实现代码
2014/08/08 Javascript
jQuery Ajax使用实例
2015/04/16 Javascript
JQuery实现超链接鼠标提示效果的方法
2015/06/10 Javascript
基于jQuery实现滚动切换效果
2016/12/02 Javascript
jQuery插件zTree实现单独选中根节点中第一个节点示例
2017/03/08 Javascript
vue使用stompjs实现mqtt消息推送通知
2017/06/22 Javascript
简述jQuery Easyui一些用法
2017/08/01 jQuery
vue.js根据代码运行环境选择baseurl的方法
2018/02/28 Javascript
vue addRoutes实现动态权限路由菜单的示例
2018/05/15 Javascript
koa2 从入门到精通(小结)
2019/07/23 Javascript
详解JavaScript中的Object.is()与"==="运算符总结
2020/06/17 Javascript
基于vue的video播放器的实现示例
2021/02/19 Vue.js
python实现k均值算法示例(k均值聚类算法)
2014/03/16 Python
Python中time模块与datetime模块在使用中的不同之处
2015/11/24 Python
Python字典中的键映射多个值的方法(列表或者集合)
2018/10/17 Python
我喜欢你 抖音表白程序python版
2019/04/07 Python
Python Dict找出value大于某值或key大于某值的所有项方式
2020/06/05 Python
keras Lambda自定义层实现数据的切片方式,Lambda传参数
2020/06/11 Python
python3 kubernetes api的使用示例
2021/01/12 Python
Python中使用Selenium环境安装的方法步骤
2021/02/22 Python
CSS3 box-sizing属性
2009/04/17 HTML / CSS
Foot Locker加拿大官网:美国知名运动产品零售商
2019/07/21 全球购物
试述DBMS的主要功能
2016/11/13 面试题
经典演讲稿汇总
2014/05/19 职场文书
弘扬焦裕禄精神走群众路线思想汇报
2014/09/12 职场文书
小学德育工作总结2015
2015/05/12 职场文书
入党申请书怎么写?
2019/06/21 职场文书
浅谈PostgreSQL表分区的三种方式
2021/06/29 PostgreSQL