Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之基本内置数据类型介绍
Feb 20 Python
Python的语言类型(详解)
Jun 24 Python
Python使用matplotlib实现绘制自定义图形功能示例
Jan 18 Python
如何优雅地处理Django中的favicon.ico图标详解
Jul 05 Python
pytorch permute维度转换方法
Dec 14 Python
解决pycharm 远程调试 上传 helpers 卡住的问题
Jun 27 Python
详解Python3定时器任务代码
Sep 23 Python
Python 网络编程之TCP客户端/服务端功能示例【基于socket套接字】
Oct 12 Python
pd.DataFrame统计各列数值多少的实例
Dec 05 Python
将keras的h5模型转换为tensorflow的pb模型操作
May 25 Python
python使用nibabel和sitk读取保存nii.gz文件实例
Jul 01 Python
pyqt5 textEdit、lineEdit操作的示例代码
Aug 12 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
php 删除一个数组中的某个值.兼容多维数组!
2012/02/18 PHP
PHP常用开发函数解析之数组篇[未完结]
2012/07/30 PHP
PHPExcel读取EXCEL中的图片并保存到本地的方法
2015/02/14 PHP
Javascript处理DOM元素事件实现代码
2012/05/23 Javascript
基于jquery的DIV随滚动条滚动而滚动的代码
2012/07/20 Javascript
jquery ui对话框实例代码
2013/05/10 Javascript
js获取系统的根路径实现介绍
2013/09/08 Javascript
jquery中的查找parents与closest方法之间的区别
2013/12/02 Javascript
改变隐藏的input中value的值代码
2013/12/30 Javascript
document.execCommand()的用法小结
2014/01/08 Javascript
JS中如何比较两个Json对象是否相等实例代码
2016/07/13 Javascript
详解AngularJS中的表单验证(推荐)
2016/11/17 Javascript
Javascript Event(事件)的传播与冒泡
2017/01/23 Javascript
Node连接mysql数据库方法介绍
2017/02/07 Javascript
微信小程序使用toast消息对话框提示用户忘记输入用户名或密码功能【附源码下载】
2017/12/09 Javascript
优雅的在React项目中使用Redux的方法
2018/11/10 Javascript
详解基于webpack&gettext的前端多语言方案
2019/01/29 Javascript
Vue中常用rules校验规则(实例代码)
2019/11/14 Javascript
微信小程序实现倒计时功能
2020/11/19 Javascript
利用python实现数据分析
2017/01/11 Python
简述Python2与Python3的不同点
2018/01/21 Python
详解Python 协程的详细用法使用和例子
2018/06/15 Python
Python3 获取一大段文本之间两个关键字之间的内容方法
2018/10/11 Python
pandas 数据索引与选取的实现方法
2019/06/21 Python
python3射线法判断点是否在多边形内
2019/06/28 Python
python列表插入append(), extend(), insert()用法详解
2019/09/14 Python
Python项目跨域问题解决方案
2020/06/22 Python
跑步爱好者一站式服务网站:Jack Rabbit
2016/09/01 全球购物
诗普兰迪官方网站:Splendid
2018/09/18 全球购物
医学生自荐信
2013/12/03 职场文书
公司企业表扬信
2014/01/11 职场文书
竞选学生会演讲稿
2014/04/25 职场文书
2014个人年度工作总结范文
2014/12/24 职场文书
乡镇保密工作承诺书
2015/05/04 职场文书
golang定时器
2022/04/14 Golang
python实现双链表
2022/05/25 Python