Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现从一组颜色中找出与给定颜色最接近颜色的方法
Mar 19 Python
python 网络爬虫初级实现代码
Feb 27 Python
Python利用Beautiful Soup模块修改内容方法示例
Mar 27 Python
PyQt4实现下拉菜单可供选择并打印出来
Apr 20 Python
Python生成器定义与简单用法实例分析
Apr 30 Python
python中pygame安装过程(超级详细)
Aug 04 Python
pygame实现俄罗斯方块游戏(AI篇2)
Oct 29 Python
安装Pycharm2019以及配置anconda教程的方法步骤
Nov 11 Python
python使用pygame实现笑脸乒乓球弹珠球游戏
Nov 25 Python
python爬虫学习笔记之pyquery模块基本用法详解
Apr 09 Python
用python批量下载apk
Dec 29 Python
pandas针对excel处理的实现
Jan 15 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
彻底删除thinkphp3.1案例blog标签的方法
2014/12/05 PHP
PHP用mb_string函数库处理与windows相关中文字符及Win环境下开启PHP Mb_String方法
2015/11/11 PHP
Android AsyncTack 异步任务实例详解
2016/11/02 PHP
Yii框架的布局文件实例分析
2019/09/04 PHP
JavaScript定义类或函数的几种方式小结
2011/01/09 Javascript
ext前台接收action传过来的json数据示例
2014/06/17 Javascript
JavaScript实现的in_array函数
2014/08/27 Javascript
jQuery插件jRumble实现网页元素抖动
2015/06/05 Javascript
基于ajax实现文件上传并显示进度条
2015/08/03 Javascript
jQuery幻灯片特效代码分享--鼠标滑过按钮时切换(2)
2020/11/18 Javascript
基于javascript显示当前时间以及倒计时功能
2016/03/18 Javascript
使用CSS+JavaScript或纯js实现半透明遮罩效果的实例分享
2016/05/09 Javascript
Javascript中字符串相关常用的使用方法总结
2017/03/13 Javascript
从零开始学习Node.js系列教程之SQLite3和MongoDB用法分析
2017/04/13 Javascript
vue实现多级菜单效果
2019/10/19 Javascript
Vue实现圆环进度条的示例
2021/02/06 Vue.js
python基础教程之面向对象的一些概念
2014/08/29 Python
在Python中使用Neo4j数据库的教程
2015/04/16 Python
分享一下Python数据分析常用的8款工具
2018/04/29 Python
Django中STATIC_ROOT和STATIC_URL及STATICFILES_DIRS浅析
2018/05/08 Python
对numpy中shape的深入理解
2018/06/15 Python
详解python Todo清单实战
2018/11/01 Python
Python自动发送邮件的方法实例总结
2018/12/08 Python
python实现公司年会抽奖程序
2019/01/22 Python
python 实现单通道转3通道
2019/12/03 Python
Python使用plt.boxplot() 参数绘制箱线图
2020/06/04 Python
django美化后台django-suit的安装配置操作
2020/07/12 Python
Python requests接口测试实现代码
2020/09/08 Python
HTML5新表单元素_动力节点Java学院整理
2017/07/12 HTML / CSS
高中生毕业自我鉴定
2013/10/10 职场文书
教你打造完美的创业计划书
2014/01/06 职场文书
中学生自我评价范文
2014/02/08 职场文书
材料专业毕业生求职信
2014/02/26 职场文书
社会实践先进工作者事迹材料
2014/05/06 职场文书
机械电子工程专业求职信
2014/06/22 职场文书
MySQL实例精讲单行函数以及字符数学日期流程控制
2021/10/15 MySQL