Keras - GPU ID 和显存占用设定步骤


Posted in Python onJune 22, 2020

初步尝试 Keras (基于 Tensorflow 后端)深度框架时, 发现其对于 GPU 的使用比较神奇, 默认竟然是全部占满显存, 1080Ti 跑个小分类问题, 就一下子满了. 而且是服务器上的两张 1080Ti.

服务器上的多张 GPU 都占满, 有点浪费性能.

因此, 需要类似于 Caffe 等框架的可以设定 GPU ID 和显存自动按需分配.

实际中发现, Keras 还可以限制 GPU 显存占用量.

这里涉及到的内容有:

GPU ID 设定

GPU 显存占用按需分配

GPU 显存占用限制

GPU 显存优化

1. GPU ID 设定

#! -- coding: utf-8 --*--
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

这里将 GPU ID 设为 1.

GPU ID 从 0 开始, GPUID=1 即表示第二块 GPU.

2. GPU 显存占用按需分配

#! -- coding: utf-8 --*--
import tensorflow as tf
import keras.backend.tensorflow_backend as ktf

# GPU 显存自动调用
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
session = tf.Session(config=config)
ktf.set_session(session)

3. GPU 显存占用限制

#! -- coding: utf-8 --*--
import tensorflow as tf
import keras.backend.tensorflow_backend as ktf

# 设定 GPU 显存占用比例为 0.3
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.3
session = tf.Session(config=config)
ktf.set_session(session )

这里虽然是设定了 GPU 显存占用的限制比例(0.3), 但如果训练所需实际显存占用超过该比例, 仍能正常训练, 类似于了按需分配.

设定 GPU 显存占用比例实际上是避免一定的显存资源浪费.

4. GPU ID 设定与显存按需分配

#! -- coding: utf-8 --*--
import os
import tensorflow as tf
import keras.backend.tensorflow_backend as ktf

# GPU 显存自动分配
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
#config.gpu_options.per_process_gpu_memory_fraction = 0.3
session = tf.Session(config=config)
ktf.set_session(session)

# 指定GPUID, 第一块GPU可用
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

5. 利用fit_generator最小化显存占用比例/数据Batch化

#! -- coding: utf-8 --*--

# 将内存中的数据分批(batch_size)送到显存中进行运算
def generate_arrays_from_memory(data_train, labels_train, batch_size):
  x = data_train
  y=labels_train
  ylen=len(y)
  loopcount=ylen // batch_size
  while True:
    i = np.random.randint(0,loopcount)
    yield x[i*batch_size:(i+1)*batch_size],y[i*batch_size:(i+1)*batch_size]

# load数据到内存
data_train=np.loadtxt("./data_train.txt")
labels_train=np.loadtxt('./labels_train.txt')
data_val=np.loadtxt('./data_val.txt')
labels_val=np.loadtxt('./labels_val.txt')

hist=model.fit_generator(generate_arrays_from_memory(data_train,
                           labels_train,
                           batch_size),
             steps_per_epoch=int(train_size/bs),
             epochs=ne,
             validation_data=(data_val,labels_val),
             callbacks=callbacks )

5.1 数据 Batch 化

#! -- coding: utf-8 --*--

def process_line(line): 
  tmp = [int(val) for val in line.strip().split(',')] 
  x = np.array(tmp[:-1]) 
  y = np.array(tmp[-1:]) 
  return x,y 

def generate_arrays_from_file(path,batch_size): 
  while 1: 
    f = open(path) 
    cnt = 0 
    X =[] 
    Y =[] 
    for line in f: 
      # create Numpy arrays of input data 
      # and labels, from each line in the file 
      x, y = process_line(line) 
      X.append(x) 
      Y.append(y) 
      cnt += 1 
      if cnt==batch_size: 
        cnt = 0 
        yield (np.array(X), np.array(Y)) 
        X = [] 
        Y = [] 
  f.close()

补充知识:Keras+Tensorflow指定运行显卡以及关闭session空出显存

Keras - GPU ID 和显存占用设定步骤

Step1: 查看GPU

watch -n 3 nvidia-smi #在命令行窗口中查看当前GPU使用的情况, 3为刷新频率

Step2: 导入模块

导入必要的模块

import os
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
from numba import cuda

Step3: 指定GPU

程序开头指定程序运行的GPU

os.environ['CUDA_VISIBLE_DEVICES'] = '1' # 使用单块GPU,指定其编号即可 (0 or 1or 2 or 3)
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3' # 使用多块GPU,指定其编号即可 (引号中指定即可)

Step4: 创建会话,指定显存使用百分比

创建tensorflow的Session

config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.1 # 设定显存的利用率
set_session(tf.Session(config=config))

Step5: 释放显存

确保Volatile GPU-Util显示0%

程序运行完毕,关闭Session

K.clear_session() # 方法一:如果不关闭,则会一直占用显存

cuda.select_device(1) # 方法二:选择GPU1
cuda.close() #关闭选择的GPU

Keras - GPU ID 和显存占用设定步骤

以上这篇Keras - GPU ID 和显存占用设定步骤就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中用altzone()方法处理时区的教程
May 22 Python
python实现实时监控文件的方法
Aug 26 Python
win系统下为Python3.5安装flask-mongoengine 库
Dec 20 Python
Windows系统下多版本pip的共存问题详解
Oct 10 Python
python中numpy.zeros(np.zeros)的使用方法
Nov 07 Python
Python利用splinter实现浏览器自动化操作方法
May 11 Python
pycharm重置设置,恢复默认设置的方法
Oct 22 Python
Python字典的核心底层原理讲解
Jan 24 Python
python GUI库图形界面开发之PyQt5树形结构控件QTreeWidget详细使用方法与实例
Mar 02 Python
Python安装第三方库攻略(pip和Anaconda)
Oct 15 Python
python将下载到本地m3u8视频合成MP4的代码详解
Nov 24 Python
Python之基础函数案例详解
Aug 30 Python
Python 基于jwt实现认证机制流程解析
Jun 22 #Python
python中format函数如何使用
Jun 22 #Python
Tensorflow与Keras自适应使用显存方式
Jun 22 #Python
python数据类型强制转换实例详解
Jun 22 #Python
keras 指定程序在某块卡上训练实例
Jun 22 #Python
python Socket网络编程实现C/S模式和P2P
Jun 22 #Python
Python手动或自动协程操作方法解析
Jun 22 #Python
You might like
PHP+MYSQL的文章管理系统(二)
2006/10/09 PHP
php下载文件的代码示例
2012/06/29 PHP
ThinkPHP验证码和分页实例教程
2014/08/22 PHP
php 把数字转换成汉字的代码
2015/07/21 PHP
nicejforms——美化表单不用愁
2007/02/20 Javascript
JavaScript继承方式实例
2010/10/29 Javascript
基于jQuery的左右滚动实现代码
2010/12/03 Javascript
Iframe自适应高度绝对好使的代码 兼容IE,遨游,火狐
2011/01/27 Javascript
jquery ui bootstrap 实现自定义风格
2014/11/14 Javascript
JavaScript里实用的原生API汇总
2015/05/14 Javascript
jQuery+css实现的蓝色水平二级导航菜单效果代码
2015/09/11 Javascript
JavaScript提升性能的常用技巧总结【经典】
2016/06/20 Javascript
如何安装控制器JavaScript生成插件详解
2018/10/21 Javascript
Vue使用Proxy监听所有接口状态的方法实现
2019/06/07 Javascript
vue使用keep-alive实现组件切换时保存原组件数据方法
2020/10/30 Javascript
[45:16]完美世界DOTA2联赛循环赛 IO vs FTD BO2第二场 11.05
2020/11/06 DOTA
python实现自动更换ip的方法
2015/05/05 Python
5款非常棒的Python工具
2018/01/05 Python
Pycharm导入Python包,模块的图文教程
2018/06/13 Python
解决python3读取Python2存储的pickle文件问题
2018/10/25 Python
利用Python对文件夹下图片数据进行批量改名的代码实例
2019/02/21 Python
windows下python虚拟环境virtualenv安装和使用详解
2019/07/16 Python
Python:slice与indices的用法
2019/11/25 Python
Python Pickle 实现在同一个文件中序列化多个对象
2019/12/30 Python
基于python+selenium的二次封装的实现
2020/01/06 Python
Python标准库json模块和pickle模块使用详解
2020/03/10 Python
Python约瑟夫生者死者小游戏实例讲解
2021/01/04 Python
css3 transform及原生js实现鼠标拖动3D立方体旋转
2016/06/20 HTML / CSS
泰国排名第一的家居用品中心:HomePro
2020/11/18 全球购物
垃圾回收的优点和原理。并考虑2种回收机制
2016/10/16 面试题
数控技术专业毕业自荐书范文
2014/02/05 职场文书
《口技》教学反思
2014/02/21 职场文书
《爱如茉莉》教后反思
2014/04/12 职场文书
2015年教学管理工作总结
2015/05/20 职场文书
go语言中fallthrough的用法说明
2021/05/06 Golang
MySQL中CURRENT_TIMESTAMP的使用方式
2021/11/27 MySQL