keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python输出一个杨辉三角的例子
Jun 13 Python
python实现从网络下载文件并获得文件大小及类型的方法
Apr 28 Python
python实现解数独程序代码
Apr 12 Python
Python实现随机生成有效手机号码及身份证功能示例
Jun 05 Python
python kmeans聚类简单介绍和实现代码
Feb 23 Python
将Dataframe数据转化为ndarry数据的方法
Jun 28 Python
Django中信号signals的简单使用方法
Jul 04 Python
Python3.7+tkinter实现查询界面功能
Dec 24 Python
Python接口测试get请求过程详解
Feb 28 Python
Mac中PyCharm配置Anaconda环境的方法
Mar 04 Python
Python3爬虫里关于识别微博宫格验证码的知识点详解
Jul 30 Python
python logging 重复写日志问题解决办法详解
Aug 04 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
WinXP + Apache +PHP5 + MySQL + phpMyAdmin安装全功略
2006/07/09 PHP
php笔记之常用文件操作
2010/10/12 PHP
apache php模块整合操作指南
2012/11/16 PHP
PHP实现的链式队列结构示例
2017/09/15 PHP
tp5.1 框架数据库常见操作详解【添加、删除、更新、查询】
2020/05/26 PHP
JS宝典学习笔记(下)
2007/01/10 Javascript
javascript+xml技术实现分页浏览
2008/07/27 Javascript
javascript中动态加载js文件多种解决办法总结
2013/11/15 Javascript
JS与C#编码解码
2013/12/03 Javascript
文本框文本自动补全效果示例分享
2014/01/19 Javascript
Javascript 数组排序详解
2014/10/22 Javascript
jQuery事件绑定和委托实例
2014/11/25 Javascript
详解JavaScript中void语句的使用
2015/06/04 Javascript
jQuery插件Validate实现自定义表单验证
2016/01/18 Javascript
Easyui 之 Treegrid 笔记
2016/04/29 Javascript
JS使用JSON作为参数实例分析
2016/06/23 Javascript
JS命令模式例子之菜单程序
2016/10/10 Javascript
JS实现点击网页判断是否安装app并打开否则跳转app store
2016/11/18 Javascript
浅析JavaScriptSerializer类的序列化与反序列化
2016/11/22 Javascript
react native基于FlatList下拉刷新上拉加载实现代码示例
2018/09/30 Javascript
浅谈angularJs函数的使用方法(大小写转换,拷贝,扩充对象)
2018/10/08 Javascript
从零开始用electron手撸一个截屏工具的示例代码
2018/10/10 Javascript
微信小程序列表时间戳转换实现过程解析
2019/10/12 Javascript
vue实现图片裁剪后上传
2020/12/16 Vue.js
jQuery实现全选按钮
2021/01/01 jQuery
Python tempfile模块学习笔记(临时文件)
2014/05/25 Python
python 读取视频,处理后,实时计算帧数fps的方法
2018/07/10 Python
python识别图像并提取文字的实现方法
2019/06/28 Python
python中 * 的用法详解
2019/07/10 Python
遗体告别仪式答谢词
2014/01/23 职场文书
国际贸易专业求职信
2014/06/04 职场文书
如何签定毕业生就业协议书
2014/09/28 职场文书
个人租房协议书
2014/11/28 职场文书
详解Python常用的魔法方法
2021/06/03 Python
Redis 限流器
2022/05/15 Redis
Shell中的单中括号和双中括号的用法详解
2022/12/24 Servers