keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现将一个正整数分解质因数的方法分析
Dec 14 Python
详解Django之admin组件的使用和源码剖析
May 04 Python
python学习之hook钩子的原理和使用
Oct 25 Python
python从zip中删除指定后缀文件(推荐)
Dec 05 Python
Python求平面内点到直线距离的实现
Jan 19 Python
PageFactory设计模式基于python实现
Apr 14 Python
python实现秒杀商品的微信自动提醒功能(代码详解)
Apr 27 Python
python代码实现将列表中重复元素之间的内容全部滤除
May 22 Python
Python matplotlib读取excel数据并用for循环画多个子图subplot操作
Jul 14 Python
Python3基于plotly模块保存图片表格
Aug 03 Python
Python制作运行进度条的实现效果(代码运行不无聊)
Feb 24 Python
Python爬虫之用Xpath获取关键标签实现自动评论盖楼抽奖(二)
Jun 07 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
简单采集了yahoo的一些数据
2007/02/14 PHP
php中的四舍五入函数代码(floor函数、ceil函数、round与intval)
2014/07/14 PHP
php求数组全排列,元素所有组合的方法
2016/05/05 PHP
Thinkphp5框架使用validate实现验证功能的方法
2019/08/27 PHP
Laravel框架查询构造器 CURD操作示例
2019/09/04 PHP
PHP超级全局变量【$GLOBALS,$_SERVER,$_REQUEST等】用法实例分析
2019/12/11 PHP
屏蔽鼠标右键、Ctrl+n、shift+F10、F5刷新、退格键 的javascript代码
2007/04/01 Javascript
JavaScript confirm选择判断
2008/10/18 Javascript
JQuery中的ready函数冲突的解决方法
2010/05/17 Javascript
网站基于flash实现的Banner图切换效果代码
2014/10/14 Javascript
JS逆序遍历实现代码
2014/12/02 Javascript
JavaScript学习笔记之基础语法
2015/01/22 Javascript
JS实现在线统计一个页面内鼠标点击次数的方法
2015/02/28 Javascript
nodejs连接mongodb数据库实现增删改查
2016/12/01 NodeJs
JavaScript奇技淫巧44招【实用】
2016/12/11 Javascript
防止页面url缓存中ajax中post请求的处理方法
2017/10/10 Javascript
对Angular中单向数据流的深入理解
2018/03/31 Javascript
LayerClose弹窗关闭刷新方法
2018/08/17 Javascript
vue添加自定义右键菜单的完整实例
2020/12/08 Vue.js
[03:09]2014DOTA2国际邀请赛 赛场上的美丽风景线 中国Coser也爱DOTA2
2014/07/20 DOTA
python实现定时自动备份文件到其他主机的实例代码
2018/02/23 Python
使用Python从零开始撸一个区块链
2018/03/14 Python
Python去除、替换字符串空格的处理方法
2018/04/01 Python
python使用turtle库绘制树
2018/06/25 Python
对numpy中的transpose和swapaxes函数详解
2018/08/02 Python
Python中浅拷贝copy与深拷贝deepcopy的简单理解
2018/10/26 Python
scikit-learn线性回归,多元回归,多项式回归的实现
2019/08/29 Python
在Python中使用MySQL--PyMySQL的基本使用方法
2019/11/19 Python
python主要用于哪些方向
2020/07/05 Python
python多线程semaphore实现线程数控制的示例
2020/08/10 Python
维多利亚的秘密官方网站:Victoria’s Secret
2018/10/24 全球购物
印度电子产品购物网站:Vijay Sales
2021/02/16 全球购物
mysql有关权限的表都有哪几个
2015/04/22 面试题
支教自我鉴定
2014/01/18 职场文书
Django实现在线无水印抖音视频下载(附源码及地址)
2021/05/06 Python
解决pycharm安装scrapy DLL load failed:找不到指定的程序的问题
2021/06/08 Python