Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python map和reduce函数用法示例
Feb 26 Python
Python比较文件夹比另一同名文件夹多出的文件并复制出来的方法
Mar 05 Python
使用Python实现下载网易云音乐的高清MV
Mar 16 Python
django js实现部分页面刷新的示例代码
May 28 Python
Selenium定时刷新网页的实现代码
Oct 31 Python
完美解决keras 读取多个hdf5文件进行训练的问题
Jul 01 Python
pandas.DataFrame.drop_duplicates 用法介绍
Jul 06 Python
python定义类的简单用法
Jul 24 Python
Python中使用aiohttp模拟服务器出现错误问题及解决方法
Oct 31 Python
Python3 用matplotlib绘制sigmoid函数的案例
Dec 11 Python
matplotlib源码解析标题实现(窗口标题,标题,子图标题不同之间的差异)
Feb 22 Python
Python进程间的通信之语法学习
Apr 11 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
PHILIPS L4X25T电路分析和打理
2021/03/02 无线电
PHP删除数组中指定下标的元素方法
2018/02/03 PHP
Laravel程序架构设计思路之使用动作类
2018/06/07 PHP
使用js解决由border属性引起的div宽度问题
2013/11/26 Javascript
sogou地图API用法实例教程
2014/09/11 Javascript
jQuery满意度星级评价插件特效代码分享
2015/08/19 Javascript
阿里巴巴技术文章分享 Javascript继承机制的实现
2016/01/14 Javascript
Avalonjs双向数据绑定与监听的实例代码
2017/06/23 Javascript
使用javaScript实现鼠标拖拽事件
2020/04/03 Javascript
详解vue-element Tree树形控件填坑路
2019/03/26 Javascript
node.js中stream流中可读流和可写流的实现与使用方法实例分析
2020/02/13 Javascript
jQuery实现朋友圈查看图片
2020/09/11 jQuery
用python分割TXT文件成4K的TXT文件
2009/05/23 Python
从零学python系列之从文件读取和保存数据
2014/05/23 Python
Python中字典和JSON互转操作实例
2015/01/19 Python
Python合并多个装饰器小技巧
2015/04/28 Python
使用Mixin设计模式进行Python编程的方法讲解
2016/06/21 Python
浅谈编码,解码,乱码的问题
2016/12/30 Python
Python根据已知邻接矩阵绘制无向图操作示例
2018/06/23 Python
Django模板Templates使用方法详解
2019/07/19 Python
pygame库实现俄罗斯方块小游戏
2019/10/29 Python
Python网络爬虫四大选择器用法原理总结
2020/06/01 Python
python批量修改交换机密码的示例
2020/09/22 Python
HTML5时代CSS设置漂亮字体取代图片
2014/09/04 HTML / CSS
html5中去掉input type date默认样式的方法
2018/09/06 HTML / CSS
canvas 实现 github404动态效果的示例代码
2017/11/15 HTML / CSS
Canvas 像素处理之改变透明度的实现代码
2019/01/08 HTML / CSS
新秀丽拉杆箱美国官方网站:Samsonite美国
2016/07/25 全球购物
ET Mall东森购物网:东森严选
2017/03/06 全球购物
美国第二大连锁书店:Books-A-Million
2017/12/28 全球购物
Kangol帽子官网:坎戈尔袋鼠
2018/09/26 全球购物
CheapTickets香港机票预订网站:CheapTickets.hk
2019/06/26 全球购物
学习雷锋精神演讲稿
2014/05/10 职场文书
公司演讲稿开场白
2014/08/25 职场文书
2014最新开业庆典策划方案(5篇)
2014/09/15 职场文书
侵犯商业秘密的律师函
2015/05/27 职场文书