keras 多gpu并行运行案例


Posted in Python onJune 10, 2020

一、多张gpu的卡上使用keras

有多张gpu卡时,推荐使用tensorflow 作为后端。使用多张gpu运行model,可以分为两种情况,一是数据并行,二是设备并行。

二、数据并行

数据并行将目标模型在多个设备上各复制一份,并使用每个设备上的复制品处理整个数据集的不同部分数据。

利用multi_gpu_model实现

keras.utils.multi_gpu_model(model, gpus=None, cpu_merge=True, cpu_relocation=False)

具体来说,该功能实现了单机多 GPU 数据并行性。 它的工作原理如下:

将模型的输入分成多个子批次。

在每个子批次上应用模型副本。 每个模型副本都在专用 GPU 上执行。

将结果(在 CPU 上)连接成一个大批量。

例如, 如果你的 batch_size 是 64,且你使用 gpus=2, 那么我们将把输入分为两个 32 个样本的子批次, 在 1 个 GPU 上处理 1 个子批次,然后返回完整批次的 64 个处理过的样本。

参数

model: 一个 Keras 模型实例。为了避免OOM错误,该模型可以建立在 CPU 上, 详见下面的使用样例。

gpus: 整数 >= 2 或整数列表,创建模型副本的 GPU 数量, 或 GPU ID 的列表。

cpu_merge: 一个布尔值,用于标识是否强制合并 CPU 范围内的模型权重。

cpu_relocation: 一个布尔值,用来确定是否在 CPU 的范围内创建模型的权重。如果模型没有在任何一个设备范围内定义,您仍然可以通过激活这个选项来拯救它。

返回

一个 Keras Model 实例,它可以像初始 model 参数一样使用,但它将工作负载分布在多个 GPU 上。

例子

import tensorflow as tf
from keras.applications import Xception
from keras.utils import multi_gpu_model
import numpy as np

num_samples = 1000
height = 224
width = 224
num_classes = 1000

# 实例化基础模型(或者「模版」模型)。
# 我们推荐在 CPU 设备范围内做此操作,
# 这样模型的权重就会存储在 CPU 内存中。
# 否则它们会存储在 GPU 上,而完全被共享。
with tf.device('/cpu:0'):
 model = Xception(weights=None,
   input_shape=(height, width, 3),
   classes=num_classes)

# 复制模型到 8 个 GPU 上。
# 这假设你的机器有 8 个可用 GPU。
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.compile(loss='categorical_crossentropy',
   optimizer='rmsprop')

# 生成虚拟数据
x = np.random.random((num_samples, height, width, 3))
y = np.random.random((num_samples, num_classes))

# 这个 `fit` 调用将分布在 8 个 GPU 上。
# 由于 batch size 是 256, 每个 GPU 将处理 32 个样本。
parallel_model.fit(x, y, epochs=20, batch_size=256)

# 通过模版模型存储模型(共享相同权重):
model.save('my_model.h5')

注意:

要保存多 GPU 模型,请通过模板模型(传递给 multi_gpu_model 的参数)调用 .save(fname) 或 .save_weights(fname) 以进行存储,而不是通过 multi_gpu_model 返回的模型。

即要用model来保存,而不是parallel_model来保存。

使用ModelCheckpoint() 遇到的问题

使用ModelCheckpoint()会遇到下面的问题:

TypeError: can't pickle ...(different text at different situation) objects

这个问题和保存问题类似,ModelCheckpoint() 会自动调用parallel_model.save()来保存,而不是model.save(),因此我们要自己写一个召回函数,使得ModelCheckpoint()用model.save()。

修改方法:

class ParallelModelCheckpoint(ModelCheckpoint):
 def __init__(self,model,filepath, monitor='val_loss', verbose=0,
   save_best_only=False, save_weights_only=False,
   mode='auto', period=1):
 self.single_model = model
 super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)

 def set_model(self, model):
 super(ParallelModelCheckpoint,self).set_model(self.single_model)

checkpoint = ParallelModelCheckpoint(original_model)

ParallelModelCheckpoint调用的时候,model应该为原来的model而不是parallel_model。

EarlyStopping 没有此类问题

二、设备并行

设备并行适用于多分支结构,一个分支用一个gpu。

这种并行方法可以通过使用TensorFlow device scopes实现,下面是一个例子:

# Model where a shared LSTM is used to encode two different sequences in parallel
input_a = keras.Input(shape=(140, 256))
input_b = keras.Input(shape=(140, 256))

shared_lstm = keras.layers.LSTM(64)

# Process the first sequence on one GPU
with tf.device_scope('/gpu:0'):
 encoded_a = shared_lstm(tweet_a)
# Process the next sequence on another GPU
with tf.device_scope('/gpu:1'):
 encoded_b = shared_lstm(tweet_b)

# Concatenate results on CPU
with tf.device_scope('/cpu:0'):
 merged_vector = keras.layers.concatenate([encoded_a, encoded_b],
      axis=-1)

三、分布式运行

keras的分布式是利用TensorFlow实现的,要想完成分布式的训练,你需要将Keras注册在连接一个集群的TensorFlow会话上:

server = tf.train.Server.create_local_server()
sess = tf.Session(server.target)

from keras import backend as K
K.set_session(sess)

以上这篇keras 多gpu并行运行案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用itertools模块中的组合函数的教程
Apr 13 Python
python解析xml文件实例分析
May 27 Python
python批量制作雷达图的实现方法
Jul 26 Python
Python 中urls.py:URL dispatcher(路由配置文件)详解
Mar 24 Python
Python实现对百度云的文件上传(实例讲解)
Oct 21 Python
selenium+python环境配置教程详解
May 28 Python
python 计算两个列表的相关系数的实现
Aug 29 Python
Pycharm最新激活码2019(推荐)
Dec 31 Python
为什么黑客都用python(123个黑客必备的Python工具)
Jan 31 Python
Python连接SQLite数据库并进行增册改查操作方法详解
Feb 18 Python
JupyterNotebook 输出窗口的显示效果调整实现
Sep 22 Python
Python 如何解决稀疏矩阵运算
May 26 Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
keras 获取某层的输入/输出 tensor 尺寸操作
Jun 10 #Python
Python 字典中的所有方法及用法
Jun 10 #Python
You might like
无法在发生错误时创建会话,请检查 PHP 或网站服务器日志,并正确配置 PHP 安装最快的解决办法
2010/08/01 PHP
PHP编码规范的深入探讨
2013/06/06 PHP
php生成数组的使用示例 php全组合算法
2014/01/16 PHP
Yii分页用法实例详解
2014/12/04 PHP
php实现在多维数组中查找特定value的方法
2015/07/29 PHP
PHP实现验证码校验功能
2017/11/16 PHP
laravel框架数据库操作、查询构建器、Eloquent ORM操作实例分析
2019/12/20 PHP
ImageZoom 图片放大镜效果(多功能扩展篇)
2010/04/14 Javascript
js 页面关闭前的出现提示的实现代码
2011/05/25 Javascript
不使用XMLHttpRequest实现异步加载 Iframe和script
2012/10/29 Javascript
基于jquery实现的定时显示与隐藏div广告的实现代码
2013/08/22 Javascript
javascript alert乱码的解决方法
2013/11/05 Javascript
vue.js删除动态绑定的radio的指定项
2017/06/02 Javascript
详解微信第三方小程序代开发
2017/06/23 Javascript
vue input 输入校验字母数字组合且长度小于30的实现代码
2018/05/16 Javascript
JavaScript fetch接口案例解析
2018/08/30 Javascript
在Express中提供静态文件的实现方法
2019/10/17 Javascript
详解Vue template 如何支持多个根结点
2020/02/10 Javascript
[50:45]2018DOTA2亚洲邀请赛 4.6 淘汰赛 VP vs TNC 第一场
2018/04/10 DOTA
[47:36]Optic vs Newbee 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/18 DOTA
[00:59]DOTA2背景故事第二期之四大基本法则
2020/07/07 DOTA
Python json 错误xx is not JSON serializable解决办法
2017/03/15 Python
python实现二分查找算法
2017/09/21 Python
轻松实现TensorFlow微信跳一跳的AI
2018/01/05 Python
python实现求解列表中元素的排列和组合问题
2018/03/15 Python
Python运维之获取系统CPU信息的实现方法
2018/06/11 Python
tensorflow 查看梯度方式
2020/02/04 Python
django中的数据库迁移的实现
2020/03/16 Python
关于tf.matmul() 和tf.multiply() 的区别说明
2020/06/18 Python
python爬虫利器之requests库的用法(超全面的爬取网页案例)
2020/12/17 Python
HTML5使用DOM进行自定义控制示例代码
2013/06/08 HTML / CSS
新西兰第一的行李箱网站:luggage.co.nz
2019/07/22 全球购物
介绍一下SQL中union,intersect和minus
2012/04/05 面试题
项目资料员岗位职责
2013/12/10 职场文书
优质服务口号
2014/06/11 职场文书
党员“一帮一”活动总结
2015/05/07 职场文书