keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教大家玩转Python字符串处理的七种技巧
Mar 31 Python
浅谈numpy中linspace的用法 (等差数列创建函数)
Jun 07 Python
Python编程之string相关操作实例详解
Jul 22 Python
Django框架教程之正则表达式URL误区详解
Jan 28 Python
python tools实现视频的每一帧提取并保存
Mar 20 Python
pandas DataFrame的修改方法(值、列、索引)
Aug 02 Python
Python函数中的可变长参数详解
Sep 12 Python
django框架使用views.py的函数对表进行增删改查内容操作详解【models.py中表的创建、views.py中函数的使用,基于对象的跨表查询】
Dec 12 Python
python序列类型种类详解
Feb 26 Python
pycharm软件实现设置自动保存操作
Jun 08 Python
Python使用urlretrieve实现直接远程下载图片的示例代码
Aug 17 Python
python 如何将office文件转换为PDF
Sep 22 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
PL-880隐藏功能
2021/03/01 无线电
php检测iis环境是否支持htaccess的方法
2014/02/18 PHP
PHP经典算法集锦【经典收藏】
2016/09/14 PHP
Yii2框架BootStrap样式的深入理解
2016/11/07 PHP
PHP实现的DES加密解密封装类完整实例
2017/04/29 PHP
javascript延时重复执行函数 lLoopRun.js
2007/06/29 Javascript
javascript 语法基础 想学习js的朋友可以看看
2009/12/16 Javascript
小议Javascript中的this指针
2010/03/18 Javascript
jQuery 中DOM 操作详解
2015/01/13 Javascript
javascript模拟评分控件实现方法
2015/05/13 Javascript
JavaScript+CSS实现仿Mootools竖排弹性动画菜单效果
2015/10/14 Javascript
Bootstrap精简教程
2015/11/27 Javascript
jQuery Validate验证框架详解(推荐)
2016/12/17 Javascript
值得分享和收藏的xmlplus组件学习教程
2017/05/05 Javascript
解决Vue2.0自带浏览器里无法打开的原因(兼容处理)
2017/07/28 Javascript
Vue.js watch监视属性知识点总结
2019/11/11 Javascript
jQuery+PHP+Ajax实现动态数字统计展示功能
2019/12/25 jQuery
[03:48]DOTA2完美大师赛主赛事第二日精彩集锦
2017/11/24 DOTA
Python是编译运行的验证方法
2015/01/30 Python
Python最火、R极具潜力 2017机器学习调查报告
2017/12/11 Python
Python语言描述连续子数组的最大和
2018/01/04 Python
Python3字符串encode与decode的讲解
2019/04/02 Python
Python3.5集合及其常见运算实例详解
2019/05/01 Python
python常用库之NumPy和sklearn入门
2019/07/11 Python
Python3内置模块random随机方法小结
2019/07/13 Python
Python实现密钥密码(加解密)实例详解
2020/04/26 Python
澳大利亚婴儿、幼儿和儿童在线设计师商店:Smooch Baby
2019/02/16 全球购物
自查自纠工作总结
2014/10/15 职场文书
2014年煤矿工人工作总结
2014/12/08 职场文书
应聘教师求职信范文
2015/03/20 职场文书
2015年前台文员工作总结
2015/05/18 职场文书
Python实现Telnet自动连接检测密码的示例
2021/04/16 Python
对Golang中的FORM相关字段理解
2021/05/02 Golang
详解JS ES6编码规范
2021/05/07 Javascript
Python 中的Sympy详细使用
2021/08/07 Python
Apache Hudi 加速传统的批处理模式
2022/04/24 Servers