keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python生成url短链接的方法
May 04 Python
总结python爬虫抓站的实用技巧
Aug 09 Python
Django如何实现内容缓存示例详解
Sep 24 Python
浅谈python爬虫使用Selenium模拟浏览器行为
Feb 23 Python
python matplotlib画图库学习绘制常用的图
Mar 19 Python
Python中Numpy ndarray的使用详解
May 24 Python
python机器学习库scikit-learn:SVR的基本应用
Jun 26 Python
python实现动态数组的示例代码
Jul 15 Python
Python 实用技巧之利用Shell通配符做字符串匹配
Aug 23 Python
python实现从尾到头打印单链表操作示例
Feb 22 Python
PyQt5连接MySQL及QMYSQL driver not loaded错误解决
Apr 29 Python
python中xlutils库用法浅析
Dec 29 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
一个简单的MySQL数据浏览器
2006/10/09 PHP
一个分页的论坛
2006/10/09 PHP
php下将图片以二进制存入mysql数据库中并显示的实现代码
2010/05/27 PHP
php empty,isset,is_null判断比较(差异与异同)
2010/10/19 PHP
解析关于java,php以及html的所有文件编码与乱码的处理方法汇总
2013/06/24 PHP
PHP图片等比缩放类SimpleImage使用方法和使用实例分享
2014/04/10 PHP
php简单获取文件扩展名的方法
2015/03/24 PHP
php基于jquery的ajax技术传递json数据简单实例
2016/04/15 PHP
php命名空间设计思想、用法与缺点分析
2019/07/17 PHP
Yii Framework框架使用PHPExcel组件的方法示例
2019/07/24 PHP
Prototype Class对象学习
2009/07/19 Javascript
网易JS面试题与Javascript词法作用域说明
2010/11/09 Javascript
跨浏览器的事件对象介绍
2012/06/27 Javascript
javascript使用window.open提示“已经计划系统关机”的原因
2014/08/15 Javascript
第五章之BootStrap 栅格系统
2016/04/25 Javascript
深入理解Angular2 模板语法
2016/08/07 Javascript
JS中的hasOwnProperty()和isPrototypeOf()属性实例详解
2016/08/11 Javascript
webpack打包js文件及部署的实现方法
2017/12/18 Javascript
使用vue-cli导入Element UI组件的方法
2018/05/16 Javascript
angular6.0开发教程之如何安装angular6.0框架
2018/06/29 Javascript
实例分析vue循环列表动态数据的处理方法
2018/09/28 Javascript
python with statement 进行文件操作指南
2014/08/22 Python
简单上手Python中装饰器的使用
2015/07/12 Python
python+mongodb数据抓取详细介绍
2017/10/25 Python
总结python中pass的作用
2019/02/27 Python
Python列表倒序输出及其效率详解
2020/03/04 Python
李维斯德国官方网上商店:Levi’s德国
2016/09/10 全球购物
伦敦剧院门票:London Theatre Direct
2018/11/21 全球购物
shell变量的作用空间是什么
2013/08/17 面试题
医院门卫岗位职责
2013/12/30 职场文书
十八大演讲稿
2014/05/22 职场文书
2014年机关作风建设工作总结
2014/10/23 职场文书
领导班子整改方案
2014/10/25 职场文书
员工工作及收入证明
2014/10/28 职场文书
2015年宣传部工作总结范文
2015/03/31 职场文书
一次Mysql update sql不当引起的生产故障记录
2022/04/01 MySQL