完美解决TensorFlow和Keras大数据量内存溢出的问题


Posted in Python onJuly 03, 2020

内存溢出问题是参加kaggle比赛或者做大数据量实验的第一个拦路虎。

以前做的练手小项目导致新手产生一个惯性思维——读取训练集图片的时候把所有图读到内存中,然后分批训练。

其实这是有问题的,很容易导致OOM。现在内存一般16G,而训练集图片通常是上万张,而且RGB图,还很大,VGG16的图片一般是224x224x3,上万张图片,16G内存根本不够用。这时候又会想起——设置batch,但是那个batch的输入参数却又是图片,它只是把传进去的图片分批送到显卡,而我OOM的地方恰是那个“传进去”的图片,怎么办?

解决思路其实说来也简单,打破思维定式就好了,不是把所有图片读到内存中,而是只把所有图片的路径一次性读到内存中。

大致的解决思路为:

将上万张图片的路径一次性读到内存中,自己实现一个分批读取函数,在该函数中根据自己的内存情况设置读取图片,只把这一批图片读入内存中,然后交给模型,模型再对这一批图片进行分批训练,因为内存一般大于等于显存,所以内存的批次大小和显存的批次大小通常不相同。

下面代码分别介绍Tensorflow和Keras分批将数据读到内存中的关键函数。Tensorflow对初学者不太友好,所以我个人现阶段更习惯用它的高层API Keras来做相关项目,下面的TF实现是之前不会用Keras分批读时候参考的一些列资料,在模型训练上仍使用Keras,只有分批读取用了TF的API。

Tensorlow

在input.py里写get_batch函数。

def get_batch(X_train, y_train, img_w, img_h, color_type, batch_size, capacity):
  '''
  Args:
    X_train: train img path list
    y_train: train labels list
    img_w: image width
    img_h: image height
    batch_size: batch size
    capacity: the maximum elements in queue
  Returns:
    X_train_batch: 4D tensor [batch_size, width, height, chanel],\
            dtype=tf.float32
    y_train_batch: 1D tensor [batch_size], dtype=int32
  '''
  X_train = tf.cast(X_train, tf.string)

  y_train = tf.cast(y_train, tf.int32)
  
  # make an input queue
  input_queue = tf.train.slice_input_producer([X_train, y_train])

  y_train = input_queue[1]
  X_train_contents = tf.read_file(input_queue[0])
  X_train = tf.image.decode_jpeg(X_train_contents, channels=color_type)

  X_train = tf.image.resize_images(X_train, [img_h, img_w], 
                   tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  X_train_batch, y_train_batch = tf.train.batch([X_train, y_train],
                         batch_size=batch_size,
                         num_threads=64,
                         capacity=capacity)
  y_train_batch = tf.one_hot(y_train_batch, 10)

  return X_train_batch, y_train_batch

在train.py文件中训练(下面不是纯TF代码,model.fit是Keras的拟合,用纯TF的替换就好了)。

X_train_batch, y_train_batch = inp.get_batch(X_train, y_train, 
                       img_w, img_h, color_type, 
                       train_batch_size, capacity)
X_valid_batch, y_valid_batch = inp.get_batch(X_valid, y_valid, 
                       img_w, img_h, color_type, 
                       valid_batch_size, capacity)
with tf.Session() as sess:

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  try:
    for step in np.arange(max_step):
      if coord.should_stop() :
        break
      X_train, y_train = sess.run([X_train_batch, 
                       y_train_batch])
      X_valid, y_valid = sess.run([X_valid_batch,
                       y_valid_batch])
       
      ckpt_path = 'log/weights-{val_loss:.4f}.hdf5'
      ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_path, 
                           monitor='val_loss', 
                           verbose=1, 
                           save_best_only=True, 
                           mode='min')
      model.fit(X_train, y_train, batch_size=64, 
             epochs=50, verbose=1,
             validation_data=(X_valid, y_valid),
             callbacks=[ckpt])
      
      del X_train, y_train, X_valid, y_valid

  except tf.errors.OutOfRangeError:
    print('done!')
  finally:
    coord.request_stop()
  coord.join(threads)
  sess.close()

Keras

keras文档中对fit、predict、evaluate这些函数都有一个generator,这个generator就是解决分批问题的。

关键函数:fit_generator

# 读取图片函数
def get_im_cv2(paths, img_rows, img_cols, color_type=1, normalize=True):
  '''
  参数:
    paths:要读取的图片路径列表
    img_rows:图片行
    img_cols:图片列
    color_type:图片颜色通道
  返回: 
    imgs: 图片数组
  '''
  # Load as grayscale
  imgs = []
  for path in paths:
    if color_type == 1:
      img = cv2.imread(path, 0)
    elif color_type == 3:
      img = cv2.imread(path)
    # Reduce size
    resized = cv2.resize(img, (img_cols, img_rows))
    if normalize:
      resized = resized.astype('float32')
      resized /= 127.5
      resized -= 1. 
    
    imgs.append(resized)
    
  return np.array(imgs).reshape(len(paths), img_rows, img_cols, color_type)

获取批次函数,其实就是一个generator

def get_train_batch(X_train, y_train, batch_size, img_w, img_h, color_type, is_argumentation):
  '''
  参数:
    X_train:所有图片路径列表
    y_train: 所有图片对应的标签列表
    batch_size:批次
    img_w:图片宽
    img_h:图片高
    color_type:图片类型
    is_argumentation:是否需要数据增强
  返回: 
    一个generator,x: 获取的批次图片 y: 获取的图片对应的标签
  '''
  while 1:
    for i in range(0, len(X_train), batch_size):
      x = get_im_cv2(X_train[i:i+batch_size], img_w, img_h, color_type)
      y = y_train[i:i+batch_size]
      if is_argumentation:
        # 数据增强
        x, y = img_augmentation(x, y)
      # 最重要的就是这个yield,它代表返回,返回以后循环还是会继续,然后再返回。就比如有一个机器一直在作累加运算,但是会把每次累加中间结果告诉你一样,直到把所有数加完
      yield({'input': x}, {'output': y})

训练函数

result = model.fit_generator(generator=get_train_batch(X_train, y_train, train_batch_size, img_w, img_h, color_type, True), 
     steps_per_epoch=1351, 
     epochs=50, verbose=1,
     validation_data=get_train_batch(X_valid, y_valid, valid_batch_size,img_w, img_h, color_type, False),
     validation_steps=52,
     callbacks=[ckpt, early_stop],
     max_queue_size=capacity,
     workers=1)

就是这么简单。但是当初从0到1的过程很难熬,每天都没有进展,没有头绪,急躁占据了思维的大部,熬过了这个阶段,就会一切顺利,不是运气,而是踩过的从0到1的每个脚印累积的灵感的爆发,从0到1的脚印越多,后面的路越顺利。

以上这篇完美解决TensorFlow和Keras大数据量内存溢出的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现在txt指定行追加文本的方法
Apr 29 Python
python requests 测试代理ip是否生效
Jul 25 Python
用python爬取租房网站信息的代码
Dec 14 Python
pybind11和numpy进行交互的方法
Jul 04 Python
详解Pandas之容易让人混淆的行选择和列选择
Jul 10 Python
10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径
Aug 12 Python
Python Django 前后端分离 API的方法
Aug 28 Python
python加密解密库cryptography使用openSSL生成的密匙加密解密
Feb 11 Python
python 通过文件夹导入包的操作
Jun 01 Python
python爬虫爬取淘宝商品比价(附淘宝反爬虫机制解决小办法)
Dec 03 Python
linux系统下pip升级报错的解决方法
Jan 31 Python
Python中字符串对象语法分享
Feb 24 Python
Keras 在fit_generator训练方式中加入图像random_crop操作
Jul 03 #Python
keras的三种模型实现与区别说明
Jul 03 #Python
Keras中 ImageDataGenerator函数的参数用法
Jul 03 #Python
python程序如何进行保存
Jul 03 #Python
keras的ImageDataGenerator和flow()的用法说明
Jul 03 #Python
python如何安装下载后的模块
Jul 03 #Python
python中id函数运行方式
Jul 03 #Python
You might like
编写漂亮的代码 - 将后台程序与前端程序分开
2008/04/23 PHP
用php实现让页面只能被百度gogole蜘蛛访问的方法
2009/12/29 PHP
探讨Hessian在PHP中的使用分析
2013/06/13 PHP
php中如何使对象可以像数组一样进行foreach循环
2013/08/09 PHP
简单解析PHP程序的运行流程
2016/06/23 PHP
php头像上传预览实例代码
2017/05/02 PHP
PHP安装memcache扩展的步骤讲解
2019/02/14 PHP
PHP实现的权重算法示例【可用于游戏根据权限来随机物品】
2019/02/15 PHP
jQuery oLoader实现的加载图片和页面效果
2015/03/14 Javascript
微信小程序 POST请求(网络请求)详解及实例代码
2016/11/16 Javascript
Vim快速合并行及vim 将文件所有行合并到一行
2017/11/27 Javascript
详解vue-cli之webpack3构建全面提速优化
2017/12/25 Javascript
vue2.0项目实现路由跳转的方法详解
2018/06/21 Javascript
vue百度地图 + 定位的详解
2019/05/13 Javascript
通过javascript实现扫雷游戏代码实例
2020/02/09 Javascript
[04:11]DOTA2亚洲邀请赛小组赛第一日 TOP10精彩集锦
2015/01/30 DOTA
python正则分组的应用
2013/11/10 Python
python函数返回多个值的示例方法
2013/12/04 Python
Python实现替换文件中指定内容的方法
2018/03/19 Python
Django使用Mysql数据库已经存在的数据表方法
2018/05/27 Python
PyQt5显示GIF图片的方法
2019/06/17 Python
python解析xml简单示例
2019/06/21 Python
python命令 -u参数用法解析
2019/10/24 Python
Python函数参数分类原理详解
2020/05/28 Python
keras中epoch,batch,loss,val_loss用法说明
2020/07/02 Python
Python批量删除mysql中千万级大量数据的脚本分享
2020/12/03 Python
html5 worker 实例(一) 为什么测试不到效果
2013/06/24 HTML / CSS
美国户外运动商店:Sun & Ski
2018/08/23 全球购物
万宝龙英国官网:Montblanc手表、书写工具、皮革和珠宝
2018/10/16 全球购物
美国围栏公司:Walpole Outdoors
2019/11/19 全球购物
物理专业本科生自荐信
2014/01/30 职场文书
党支部承诺书范文
2014/03/28 职场文书
小学生三分钟演讲稿
2014/08/18 职场文书
综合素质评价自我评价
2015/03/06 职场文书
关于实现中国梦的心得体会
2016/01/05 职场文书
六年级上册《闻官军收河南河北》的教学设计
2019/11/15 职场文书