keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解在Python程序中使用Cookie的教程
Apr 30 Python
python获取指定路径下所有指定后缀文件的方法
May 26 Python
Python黑魔法Descriptor描述符的实例解析
Jun 02 Python
python类中super()和__init__()的区别
Oct 18 Python
解决seaborn在pycharm中绘图不出图的问题
May 24 Python
Python 实现微信防撤回功能
Apr 29 Python
python实现列表中最大最小值输出的示例
Jul 09 Python
使用wxpy实现自动发送微信消息功能
Feb 28 Python
django的403/404/500错误自定义页面的配置方式
May 21 Python
什么是python的必选参数
Jun 21 Python
Django REST Swagger实现指定api参数
Jul 07 Python
解决python存数据库速度太慢的问题
Apr 23 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
一步一步学习PHP(5) 类和对象
2010/02/16 PHP
PHP 截取字符串函数整理(支持gb2312和utf-8)
2010/02/16 PHP
CI框架文件上传类及图像处理类用法分析
2016/05/18 PHP
thinkPHP中session()方法用法详解
2016/12/08 PHP
php实现base64图片上传方式实例代码
2017/02/22 PHP
jquery中交替点击事件toggle方法的使用示例
2013/12/08 Javascript
扩展jQuery对象时如何扩展成员变量具体怎么实现
2014/04/25 Javascript
JS获取浏览器语言动态加载JS文件示例代码
2014/10/31 Javascript
对比分析AngularJS中的$http.post与jQuery.post的区别
2015/02/27 Javascript
javascript中JSON对象与JSON字符串相互转换实例
2015/07/11 Javascript
IE10中flexigrid无法显示数据的解决方法
2015/07/26 Javascript
jquery按回车键实现表单提交的简单实例
2016/05/25 Javascript
node.js报错:Cannot find module 'ejs'的解决办法
2016/12/14 Javascript
JavaScript简介_动力节点Java学院整理
2017/06/26 Javascript
推荐10款扩展Web表单的JS插件
2017/12/25 Javascript
Javascript的console['']常用输入方法汇总
2018/04/26 Javascript
微信小程序配置服务器提示验证token失败的解决方法
2019/04/03 Javascript
微信小程序中显示倒计时代码实例
2019/05/09 Javascript
小程序实现搜索界面 小程序实现推荐搜索列表效果
2019/05/18 Javascript
[02:36]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Magma 选手采访
2021/03/11 DOTA
tensorflow学习笔记之简单的神经网络训练和测试
2018/04/15 Python
python使用代理ip访问网站的实例
2018/05/07 Python
Pycharm 实现下一个文件引用另外一个文件的方法
2019/01/17 Python
Python 通过微信控制实现app定位发送到个人服务器再转发微信服务器接收位置信息
2019/08/05 Python
Python项目跨域问题解决方案
2020/06/22 Python
美国网上花店:JustFlowers
2017/02/12 全球购物
汇集了世界上最好的天然和有机美容产品:LoveLula
2018/02/05 全球购物
微软台湾官方网站:Microsoft台湾
2018/08/15 全球购物
C语言中break与continue的区别
2012/07/12 面试题
教导处工作制度
2014/01/18 职场文书
领班岗位职责范文
2014/02/06 职场文书
2014年圣诞节促销方案
2014/03/14 职场文书
2014年平安夜寄语
2014/12/08 职场文书
婚前保证书范文
2015/02/28 职场文书
nginx简单配置多个server的方法
2021/03/31 Servers
mysql优化之query_cache_limit参数说明
2021/07/01 MySQL