Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置函数bin() oct()等实现进制转换
Dec 30 Python
收藏整理的一些Python常用方法和技巧
May 18 Python
Python开发如何在ubuntu 15.10 上配置vim
Jan 25 Python
Python语言实现获取主机名根据端口杀死进程
Mar 31 Python
Python正则表达式匹配中文用法示例
Jan 17 Python
python实现员工管理系统
Jan 11 Python
Python实现按逗号分隔列表的方法
Oct 23 Python
Django unittest 设置跳过某些case的方法
Dec 26 Python
django 做 migrate 时 表已存在的处理方法
Aug 31 Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 Python
Python基本数据类型之字符串str
Jul 21 Python
Python实现批量将文件复制到新的目录中再修改名称
Apr 12 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
PHP遍历文件夹与文件类及处理类用法实例
2014/09/23 PHP
php使用curl打开https网站的方法
2015/06/17 PHP
随机显示经典句子或诗歌的javascript脚本
2007/08/04 Javascript
原生JS可拖动弹窗效果实例代码
2013/11/09 Javascript
js判断undefined类型示例代码
2014/02/10 Javascript
Express.JS使用详解
2014/07/17 Javascript
js监控IE火狐浏览器关闭、刷新、回退、前进事件
2014/07/23 Javascript
node.js中的path.join方法使用说明
2014/12/08 Javascript
浅谈Unicode与JavaScript的发展史
2015/01/19 Javascript
EditPlus 正则表达式 实战(3)
2016/12/15 Javascript
Node.js设置CORS跨域请求中多域名白名单的方法
2017/03/28 Javascript
Vue axios设置访问基础路径方法
2018/09/19 Javascript
Web安全之XSS攻击与防御小结
2018/12/13 Javascript
vue中使用props传值的方法
2019/05/08 Javascript
微信小程序云开发获取文件夹下所有文件(推荐)
2019/11/14 Javascript
JS将指定的某个字符全部转换为其他字符实例代码
2020/10/13 Javascript
vue3使用vue-count-to组件的实现
2020/12/25 Vue.js
Python3实现连接SQLite数据库的方法
2014/08/23 Python
python去除空格和换行符的实现方法(推荐)
2017/01/04 Python
python爬虫使用cookie登录详解
2017/12/27 Python
Python一键安装全部依赖包的方法
2019/08/12 Python
python3.x 生成3维随机数组实例
2019/11/28 Python
解决tensorboard多个events文件显示紊乱的问题
2020/02/15 Python
天美时手表加拿大官网:Timex加拿大
2016/09/01 全球购物
西班牙鞋子和箱包在线销售网站:zapatos.es
2020/02/17 全球购物
捷科时代的软件测试笔试题
2015/11/09 面试题
教师演讲稿范文
2014/01/08 职场文书
《守株待兔》教学反思
2014/03/01 职场文书
体育课课后反思
2014/04/24 职场文书
村班子对照检查材料
2014/08/18 职场文书
个人委托书范文
2015/01/28 职场文书
爱鸟护鸟的宣传语
2015/07/13 职场文书
2019年教师入党申请书
2019/06/27 职场文书
使用python+pygame开发消消乐游戏附完整源码
2021/06/10 Python
CSS几步实现赛博朋克2077风格视觉效果
2021/06/16 HTML / CSS
php去除deprecated的实例方法
2021/11/17 PHP