Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
mac安装scrapy并创建项目的实例讲解
Jun 13 Python
Python对象属性自动更新操作示例
Jun 15 Python
python 正确保留多位小数的实例
Jul 16 Python
python通过paramiko复制远程文件及文件目录到本地
Apr 30 Python
python移位运算的实现
Jul 15 Python
Python+AutoIt实现界面工具开发过程详解
Aug 07 Python
tensorflow如何批量读取图片
Aug 29 Python
python中dict()的高级用法实现
Nov 13 Python
详解用Pytest+Allure生成漂亮的HTML图形化测试报告
Mar 31 Python
最新版 Windows10上安装Python 3.8.5的步骤详解
Nov 28 Python
Django扫码抽奖平台的配置过程详解
Jan 14 Python
单身狗福利?Python爬取某婚恋网征婚数据
Jun 03 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
PHP邮件专题
2006/10/09 PHP
BBS(php & mysql)完整版(二)
2006/10/09 PHP
php+mysql写的简单留言本实例代码
2008/07/25 PHP
yii2中添加验证码的实现方法
2016/01/09 PHP
示例详解Laravel重置密码代码重构
2016/08/10 PHP
php版微信开发之接收消息,自动判断及回复相应消息的方法
2016/09/23 PHP
redirect_uri参数错误的解决方法(必看)
2017/02/16 PHP
详解php中serialize()和unserialize()函数
2017/07/08 PHP
PHP生成随机字符串实例代码(字母+数字)
2019/09/11 PHP
javascript学习随笔(使用window和frame)的技巧
2007/03/08 Javascript
国外大牛IE版本检测!现在IE都到9了,IE检测代码
2012/01/04 Javascript
对Web开发中前端框架与前端类库的一些思考
2015/03/27 Javascript
JavaScript小技巧整理
2015/12/30 Javascript
jquery中validate与form插件提交的方式小结
2016/03/26 Javascript
JavaScript操作表单实例讲解(上)
2016/06/20 Javascript
使用ES6语法重构React代码详解
2017/05/09 Javascript
D3.js实现拓扑图的示例代码
2018/06/30 Javascript
基于vue-cli、elementUI的Vue超简单入门小例子(推荐)
2019/04/17 Javascript
阿望教你用vue写扫雷小游戏
2020/01/20 Javascript
JavaScript队列结构Queue实现过程解析
2020/03/07 Javascript
剖析Python的Tornado框架中session支持的实现代码
2015/08/21 Python
Python类反射机制使用实例解析
2019/12/30 Python
keras绘制acc和loss曲线图实例
2020/06/15 Python
Python scrapy爬取小说代码案例详解
2020/07/09 Python
Python加速程序运行的方法
2020/07/29 Python
python程序实现BTC(比特币)挖矿的完整代码
2021/01/20 Python
美国瑜伽品牌:Gaiam
2017/10/31 全球购物
马德里运动鞋商店:Nigra Mercato
2020/02/16 全球购物
创新比赛获奖感言
2014/02/13 职场文书
四风查摆问题及整改措施
2014/10/10 职场文书
婚礼家长致辞
2015/07/27 职场文书
小学体育队列队形教学反思
2016/02/16 职场文书
Nginx的rewrite模块详解
2021/03/31 Servers
python正则表达式re.search()的基本使用教程
2021/05/21 Python
SqlServer数据库远程连接案例教程
2021/07/15 SQL Server
TV动画《神废柴☆偶像》公布先导PV
2022/03/20 日漫