解决Tensorflow sess.run导致的内存溢出问题


Posted in Python onFebruary 05, 2020

下面是调用模型进行批量测试的代码(出现溢出),开始以为导致溢出的原因是数据读入方式问题引起的,用了tf , PIL和cv等方式读入图片数据,发现越来越慢,内存占用飙升,调试时发现是sess.run这里出了问题(随着for循环进行速度越来越慢)。

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
 
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
 
    raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    
 
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
 
 
    print(time.time() - start)
 
  print(">>>>>>Done")

下面是解决溢出问题的代码(将部分代码放在for循环外

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
  
##############################################################################################################
  raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
  raw_output_up = tf.argmax(raw_output_up, axis=3)
##############################################################################################################
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
    
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
    print(time.time() - start)
 
  print(">>>>>>Done")

总结:

在迭代过程中, 在sess.run的for循环中不要加入tensorflow一些op操作,会增加图节点,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;

建议不要使用tf.gfile.FastGFile(image_path, 'rb').read()读入数据(有可能会造成溢出),用opencv之类读取。

以上这篇解决Tensoflow sess.run导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python判断windows系统是32位还是64位的方法
May 11 Python
python使用邻接矩阵构造图代码示例
Nov 10 Python
Python函数any()和all()的用法及区别介绍
Sep 14 Python
Django Rest framework之认证的实现代码
Dec 17 Python
python实现支付宝转账接口
May 07 Python
Pandas 重塑(stack)和轴向旋转(pivot)的实现
Jul 22 Python
如何使用Python脚本实现文件拷贝
Nov 20 Python
python飞机大战 pygame游戏创建快速入门详解
Dec 17 Python
Python的控制结构之For、While、If循环问题
Jun 30 Python
python 常见的排序算法实现汇总
Aug 21 Python
Python制作简单的剪刀石头布游戏
Dec 10 Python
Python实现王者荣耀自动刷金币的完整步骤
Jan 22 Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
tensorflow自定义激活函数实例
Feb 04 #Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
You might like
PHP中几种常见的超时处理全面总结
2012/09/11 PHP
PHP实现提取一个图像文件并在浏览器上显示的代码
2012/10/06 PHP
实例简介PHP的一些高级面向对象编程的特性
2015/11/27 PHP
php输出控制函数和输出函数生成静态页面
2019/06/27 PHP
xml 封装与解析(javascript和C#中)
2009/07/26 Javascript
基于jQuery的前端数据通用验证库
2011/08/08 Javascript
javascript实现div浮动在网页最顶上并带关闭按钮效果实例
2013/08/13 Javascript
js history对象简单实现返回和前进
2013/10/30 Javascript
Jquery的each里用return true或false代替break或continue
2014/05/21 Javascript
js的touch事件的实际引用
2014/10/13 Javascript
javascript上下方向键控制表格行选中并高亮显示的方法
2015/02/13 Javascript
jQuery实现类似老虎机滚动抽奖效果
2015/08/06 Javascript
JS组件Bootstrap Select2使用方法解析
2016/05/30 Javascript
jQuery实现对无序列表的排序功能(附demo源码下载)
2016/06/25 Javascript
Bootstrap源码解读表单(2)
2016/12/22 Javascript
jQuery实现圣诞节礼物传送(花式轮播)
2016/12/25 Javascript
vue 之 .sync 修饰符示例详解
2018/04/21 Javascript
React native ListView 增加顶部下拉刷新和底下点击刷新示例
2018/04/27 Javascript
详解vue引入子组件方法
2019/02/12 Javascript
VueJS 取得 URL 参数值的方法
2019/07/19 Javascript
微信小程序页面调用自定义组件内的事件详解
2019/09/12 Javascript
pycharm 使用心得(四)显示行号
2014/06/05 Python
python采集博客中上传的QQ截图文件
2014/07/18 Python
Python多进程分块读取超大文件的方法
2016/04/13 Python
深入理解NumPy简明教程---数组1
2016/12/17 Python
pygame游戏之旅 计算游戏中躲过的障碍数量
2018/11/20 Python
Python 正则表达式匹配字符串中的http链接方法
2018/12/25 Python
Python调用服务接口的实例
2019/01/03 Python
快速解决Django关闭Debug模式无法加载media图片与static静态文件
2020/04/07 Python
如何使用Pytorch搭建模型
2020/10/26 Python
使用css如何制作时间ICON方法实践
2012/11/12 HTML / CSS
包装类的功能、种类、常用方法
2012/01/27 面试题
领导四风问题整改措施思想汇报
2014/10/13 职场文书
2015年七一建党节演讲稿
2015/03/19 职场文书
改进工作作风心得体会
2016/01/23 职场文书
电脑只能进入安全模式无法正常启动的解决办法
2022/04/08 数码科技