解决Tensorflow sess.run导致的内存溢出问题


Posted in Python onFebruary 05, 2020

下面是调用模型进行批量测试的代码(出现溢出),开始以为导致溢出的原因是数据读入方式问题引起的,用了tf , PIL和cv等方式读入图片数据,发现越来越慢,内存占用飙升,调试时发现是sess.run这里出了问题(随着for循环进行速度越来越慢)。

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
 
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
 
    raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    
 
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
 
 
    print(time.time() - start)
 
  print(">>>>>>Done")

下面是解决溢出问题的代码(将部分代码放在for循环外

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
  
##############################################################################################################
  raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
  raw_output_up = tf.argmax(raw_output_up, axis=3)
##############################################################################################################
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
    
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
    print(time.time() - start)
 
  print(">>>>>>Done")

总结:

在迭代过程中, 在sess.run的for循环中不要加入tensorflow一些op操作,会增加图节点,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;

建议不要使用tf.gfile.FastGFile(image_path, 'rb').read()读入数据(有可能会造成溢出),用opencv之类读取。

以上这篇解决Tensoflow sess.run导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中input()与raw_input()的区别分析
Feb 27 Python
Linux中Python 环境软件包安装步骤
Mar 31 Python
Python编程实现二分法和牛顿迭代法求平方根代码
Dec 04 Python
Python排序搜索基本算法之归并排序实例分析
Dec 08 Python
Django自定义过滤器定义与用法示例
Mar 22 Python
Django使用模板后无法找到静态资源文件问题解决
Jul 19 Python
解析python的局部变量和全局变量
Aug 15 Python
使用pyinstaller逆向.pyc文件
Dec 20 Python
python数据库操作mysql:pymysql、sqlalchemy常见用法详解
Mar 30 Python
解决Python spyder显示不全df列和行的问题
Apr 20 Python
Jupyter Notebook 实现正常显示中文和负号
Apr 24 Python
详解Python+OpenCV绘制灰度直方图
Mar 22 Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
tensorflow自定义激活函数实例
Feb 04 #Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
You might like
php下利用curl判断远程文件是否存在的实现代码
2011/10/08 PHP
YII中assets的使用示例
2014/07/31 PHP
PHP回溯法解决0-1背包问题实例分析
2015/03/23 PHP
PHP实现事件机制实例分析
2015/06/26 PHP
php实现的mysqldb读写分离操作类示例
2017/02/07 PHP
tp5(thinkPHP5框架)使用DB实现批量删除功能示例
2019/05/28 PHP
Jquery读取URL参数小例子
2013/08/30 Javascript
JavaScript极简入门教程(三):数组
2014/10/25 Javascript
javascript自动生成包含数字与字符的随机字符串
2015/02/09 Javascript
Javascript中级语法快速入手
2016/07/30 Javascript
Angular外部使用js调用Angular控制器中的函数方法或变量用法示例
2016/08/05 Javascript
jquery判断类型是不是number类型的实例代码
2016/10/07 Javascript
分分钟学会vue中vuex的应用(入门教程)
2017/09/14 Javascript
VueJS事件处理器v-on的使用方法
2017/09/27 Javascript
详解Angular5 服务端渲染实战
2018/01/04 Javascript
vue中render函数的使用详解
2018/10/12 Javascript
vue.draggable实现表格拖拽排序效果
2018/12/01 Javascript
JS实现处理时间,年月日,星期的公共方法示例
2019/05/31 Javascript
基于html+css+js实现简易计算器代码实例
2020/02/28 Javascript
JS+CSS实现3D切割轮播图
2020/03/21 Javascript
html-webpack-plugin修改页面的title的方法
2020/06/18 Javascript
Vue+Java 通过websocket实现服务器与客户端双向通信操作
2020/09/22 Javascript
python append、extend与insert的区别
2016/10/13 Python
Python运维自动化之nginx配置文件对比操作示例
2018/08/29 Python
python实现汉诺塔算法
2021/03/01 Python
如何卸载python插件
2020/07/08 Python
PyCharm2020最新激活码+激活码补丁(亲测最新版PyCharm2020.2激活成功)
2020/11/25 Python
详解python3 GUI刷屏器(附源码)
2021/02/18 Python
Linux的文件类型
2016/07/05 面试题
学生打架检讨书大全
2014/01/23 职场文书
寻找最美家庭活动方案
2014/08/20 职场文书
团委工作总结2015
2015/04/02 职场文书
旅游项目合作意向书
2015/05/08 职场文书
《法国号》教学反思
2016/02/22 职场文书
SQL实现LeetCode(196.删除重复邮箱)
2021/08/07 MySQL
css3 文字断裂效果
2022/04/22 HTML / CSS