Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现每次处理一个字符的三种方法
Oct 09 Python
零基础写python爬虫之urllib2使用指南
Nov 05 Python
python版简单工厂模式
Oct 16 Python
恢复百度云盘本地误删的文件脚本(简单方法)
Oct 21 Python
python3.6.3转化为win-exe文件发布的方法
Oct 31 Python
Python快速转换numpy数组中Nan和Inf的方法实例说明
Feb 21 Python
python实现七段数码管和倒计时效果
Nov 23 Python
在python中创建指定大小的多维数组方式
Nov 28 Python
opencv3/Python 稠密光流calcOpticalFlowFarneback详解
Dec 11 Python
python游戏开发的五个案例分享
Mar 09 Python
keras读取h5文件load_weights、load代码操作
Jun 12 Python
深入理解python协程
Jun 15 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
wordpress之wp-settings.php
2007/08/17 PHP
php中var_export与var_dump的区别分析
2010/08/21 PHP
YII路径的用法总结
2014/07/09 PHP
php实现用已经过去多长时间的方式显示时间
2015/06/05 PHP
PHP实现的简单适配器模式示例
2017/06/22 PHP
PHPMAILER实现PHP发邮件功能
2018/04/18 PHP
php如何获取Http请求
2020/04/30 PHP
用nodejs写的一个简单项目打包工具
2013/05/11 NodeJs
js用拖动滑块来控制图片大小的方法
2015/02/27 Javascript
jQuery制作可自定义大小的拼图游戏
2015/03/30 Javascript
获取JavaScript异步函数的返回值
2016/12/21 Javascript
使用jQuery实现页面定时弹出广告效果
2017/08/24 jQuery
nodejs对express中next函数的一些理解
2017/09/08 NodeJs
vue.js基于v-for实现批量渲染 Json数组对象列表数据示例
2019/08/03 Javascript
小程序中使用css var变量(使js可以动态设置css样式属性)
2020/03/31 Javascript
在Chrome DevTools中调试JavaScript的实现
2020/04/07 Javascript
js实现查询商品案例
2020/07/22 Javascript
Python标准库之循环器(itertools)介绍
2014/11/25 Python
尝试用最短的Python代码来实现服务器和代理服务器
2016/06/23 Python
python 字典(dict)按键和值排序
2016/06/28 Python
200 行python 代码实现 2048 游戏
2018/01/12 Python
Python+Turtle动态绘制一棵树实例分享
2018/01/16 Python
Django contenttypes 框架详解(小结)
2018/08/13 Python
python实现人工智能Ai抠图功能
2019/09/05 Python
TensorFlow tf.nn.conv2d_transpose是怎样实现反卷积的
2020/04/20 Python
pycharm第三方库安装失败的问题及解决经验分享
2020/05/09 Python
解决python中0x80072ee2错误的方法
2020/07/19 Python
Python使用urlretrieve实现直接远程下载图片的示例代码
2020/08/17 Python
css3中新增的样式使用示例附效果图
2014/08/19 HTML / CSS
canvas环形倒计时组件的示例代码
2018/06/14 HTML / CSS
html5移动端价格输入键盘的实现
2019/09/16 HTML / CSS
马来西亚最大的在线隐形眼镜商店:MrLens
2019/03/27 全球购物
网络工程专业毕业生推荐信
2013/10/28 职场文书
高二学年自我鉴定范文(2篇)
2014/09/26 职场文书
2014年医院科室工作总结
2014/12/20 职场文书
postgresql之greenplum字符串去重拼接方式
2023/05/08 PostgreSQL