Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python读取ini文件、操作mysql、发送邮件实例
Jan 01 Python
Python中的random()方法的使用介绍
May 15 Python
Python之re操作方法(详解)
Jun 14 Python
Django模板变量如何传递给外部js调用的方法小结
Jul 24 Python
Python+Pandas 获取数据库并加入DataFrame的实例
Jul 25 Python
Python爬虫beautifulsoup4常用的解析方法总结
Feb 25 Python
Django中使用session保持用户登陆连接的例子
Aug 06 Python
python 利用jinja2模板生成html代码实例
Oct 10 Python
python实现人像动漫化的示例代码
May 17 Python
python 中的9个实用技巧,助你提高开发效率
Aug 30 Python
python实现KNN近邻算法
Dec 30 Python
Python实现京东抢秒杀功能
Jan 25 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
PHP更新购物车数量(表单部分/PHP处理部分)
2013/05/03 PHP
php 判断网页是否是utf8编码的方法
2014/06/06 PHP
PHP中cookie和session的区别实例分析
2014/08/28 PHP
PHP使用pear自带的mail类库发邮件的方法
2015/07/08 PHP
Zend Framework教程之Bootstrap类用法概述
2016/03/14 PHP
tp5 sum某个字段相加得到总数的例子
2019/10/18 PHP
HTML中Select不用Disabled实现ReadOnly的效果
2008/04/07 Javascript
jQuery学习笔记之jQuery的事件
2010/12/22 Javascript
解决IE6的PNG透明JS插件使用介绍
2013/04/17 Javascript
客户端js性能优化小技巧整理
2013/11/05 Javascript
jQuery选择器源码解读(二):select方法
2015/03/31 Javascript
Bootstrap Table的使用总结
2016/10/08 Javascript
jQuery中的一些小技巧
2017/01/18 Javascript
大白话讲解JavaScript的Promise
2017/04/06 Javascript
vue2.x 父组件监听子组件事件并传回信息的方法
2017/07/17 Javascript
浅谈Vue.js 组件中的v-on绑定自定义事件理解
2017/11/17 Javascript
Angular-UI Bootstrap组件实现警报功能
2018/07/16 Javascript
Vue 列表上下过渡效果的实例代码
2019/06/25 Javascript
使用Bootstrap做一个朝代历史表
2019/12/10 Javascript
react的hooks的用法详解
2020/10/12 Javascript
Python中用Ctrl+C终止多线程程序的问题解决
2013/03/30 Python
分享Pycharm中一些不为人知的技巧
2018/04/03 Python
pandas带有重复索引操作方法
2018/06/08 Python
Python类和对象的定义与实际应用案例分析
2018/12/27 Python
对Python生成器、装饰器、递归的使用详解
2019/07/19 Python
HTML5 实现一个访问本地文件的实例
2012/12/13 HTML / CSS
法国综合购物网站:RueDuCommerce
2016/09/12 全球购物
海蓝之谜(LA MER)澳大利亚官方商城:全球高端奢华护肤品牌
2017/10/27 全球购物
Urban Outfitters德国官网:美国跨国生活方式零售公司
2018/05/21 全球购物
大学生入党自我鉴定
2013/10/31 职场文书
微信营销策划方案
2014/02/24 职场文书
业务员自荐信范文
2014/04/20 职场文书
2015年党风廉政建设责任书
2015/01/29 职场文书
民政局未婚证明
2015/06/15 职场文书
花田少年史观后感
2015/06/16 职场文书
HTML+VUE分页实现炫酷物联网大屏功能
2021/05/27 Vue.js