keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 连连看连接算法
Nov 22 Python
Saltstack快速入门简单汇总
Mar 01 Python
详细介绍Python的鸭子类型
Sep 12 Python
Python3使用PyQt5制作简单的画板/手写板实例
Oct 19 Python
理解python中生成器用法
Dec 20 Python
Python解决八皇后问题示例
Apr 22 Python
深入理解Python异常处理的哲学
Feb 01 Python
python制作抖音代码舞
Apr 07 Python
解决yum对python依赖版本问题
Jul 05 Python
详解python中的index函数用法
Aug 06 Python
利用python读取YUV文件 转RGB 8bit/10bit通用
Dec 09 Python
python网络编程之五子棋游戏
May 14 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
php代码把全角数字转为半角数字
2007/12/10 PHP
php下判断数组中是否存在相同的值array_unique
2008/03/25 PHP
php通过获取头信息判断图片类型的方法
2015/06/26 PHP
类似CSDN图片切换效果脚本
2009/09/17 Javascript
密码强度检测效果实现原理与代码
2013/01/04 Javascript
js采用map取到id集合组并且实现点击一行选中一行
2013/12/16 Javascript
实现前后端数据交互方法汇总
2015/04/07 Javascript
jquery实现不包含当前项的选择器实例
2015/06/25 Javascript
jQuery 3.0中存在问题及解决办法
2016/07/15 Javascript
点击页面任何位置隐藏div的实现方法
2016/09/05 Javascript
基于JS实现弹出一个隐藏的div窗口body页面变成灰色并且不可被编辑
2016/12/14 Javascript
微信小程序 密码输入(源码下载)
2017/06/27 Javascript
简单实现js进度条加载效果
2020/03/25 Javascript
vue初尝试--项目结构(推荐)
2018/01/30 Javascript
Vue props用法详解(小结)
2018/07/03 Javascript
微信小程序 SOTER 生物认证DEMO 指纹识别功能
2019/12/13 Javascript
[00:43]TI7不朽珍藏III——幽鬼不朽展示
2017/07/15 DOTA
[01:05:30]VP vs TNC 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
python str与repr的区别
2013/03/23 Python
Python爬虫抓取手机APP的传输数据
2016/01/22 Python
Python用zip函数同时遍历多个迭代器示例详解
2016/11/14 Python
Python网络编程之TCP套接字简单用法示例
2018/04/09 Python
python生成每日报表数据(Excel)并邮件发送的实例
2019/02/03 Python
python中 * 的用法详解
2019/07/10 Python
如何在Django配置文件里配置session链接
2019/08/06 Python
自定义Django默认的sitemap站点地图样式
2020/03/04 Python
Python实现Wordcloud生成词云图的示例
2020/03/30 Python
CSS3实现曲线阴影和翘边阴影
2016/05/03 HTML / CSS
当文件系统受到破坏时,如何检查和修复系统?
2012/03/09 面试题
团日活动总结书格式
2014/05/08 职场文书
新书发布会策划方案
2014/06/09 职场文书
小学感恩教育活动总结
2014/07/07 职场文书
财务检查整改报告
2014/11/06 职场文书
简单的离婚协议书范本
2014/11/16 职场文书
幼儿园个人总结
2015/02/28 职场文书
慈善募捐倡议书
2015/04/27 职场文书