keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
将Python中的数据存储到系统本地的简单方法
Apr 11 Python
C#返回当前系统所有可用驱动器符号的方法
Apr 18 Python
Python运算符重载用法实例
May 28 Python
详解Python中heapq模块的用法
Jun 28 Python
Python面向对象class类属性及子类用法分析
Feb 02 Python
机器学习之KNN算法原理及Python实现方法详解
Jul 09 Python
python随机在一张图像上截取任意大小图片的方法
Jan 24 Python
用python一行代码得到数组中某个元素的个数方法
Jan 28 Python
零基础使用Python读写处理Excel表格的方法
May 02 Python
Python 使用type来定义类的实现
Nov 19 Python
Python实现把多维数组展开成DataFrame
Nov 30 Python
Python如何读写字节数据
Aug 05 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
PHP入门学习的几个不错的实例代码
2008/07/13 PHP
php下mysql数据库操作类(改自discuz)
2010/07/03 PHP
PHP基于单例模式编写PDO类的方法
2016/09/13 PHP
详谈phpAdmin修改密码后拒绝访问的问题
2017/04/03 PHP
使用php自动备份数据库表的实现方法
2017/07/28 PHP
PHP开发api接口安全验证操作实例详解
2020/03/26 PHP
javascript YUI 读码日记之 YAHOO.util.Dom - Part.4
2008/03/22 Javascript
javascript 模拟点击广告
2010/01/02 Javascript
javascript当onmousedown、onmouseup、onclick同时应用于同一个标签节点Element
2010/01/05 Javascript
如何学习Javascript入门指导
2013/11/01 Javascript
javascript实现阻止iOS APP中的链接打开Safari浏览器
2014/06/12 Javascript
JS实现标签页切换效果
2017/05/04 Javascript
JS实现搜索关键词的智能提示功能
2017/07/07 Javascript
Vue下的国际化处理方法
2017/12/18 Javascript
基于nodejs res.end和res.send的区别
2018/05/14 NodeJs
vuex进阶知识点巩固
2018/05/20 Javascript
JS的函数调用栈stack size的计算方法
2018/06/24 Javascript
JavaScript设计模式之装饰者模式实例详解
2019/01/17 Javascript
快速了解Vue父子组件传值以及父调子方法、子调父方法
2020/07/15 Javascript
Vue执行方法,方法获取data值,设置data值,方法传值操作
2020/08/05 Javascript
Python封装原理与实现方法详解
2018/08/28 Python
python创建学生成绩管理系统
2019/11/22 Python
python GUI库图形界面开发之PyQt5线程类QThread详细使用方法
2020/02/26 Python
关于HTML5的22个初级技巧(图文教程)
2012/06/21 HTML / CSS
贝斯特韦斯特酒店集团官网:Best Western
2019/01/03 全球购物
腾讯公司的一个sql题
2013/01/22 面试题
cf收人广告词
2014/03/14 职场文书
十佳青年事迹材料
2014/08/21 职场文书
学校周年庆活动方案
2014/08/22 职场文书
2014年工商所工作总结
2014/12/09 职场文书
个人总结与自我评价2015
2015/03/11 职场文书
2016年基层党组织创先争优承诺书
2016/03/25 职场文书
高中班主任寄语
2019/06/21 职场文书
人生感悟经典句子
2019/08/20 职场文书
Python selenium模拟网页点击爬虫交管12123违章数据
2021/05/26 Python
Pygame Event事件模块的详细示例
2021/11/17 Python