解决Tensorflow sess.run导致的内存溢出问题


Posted in Python onFebruary 05, 2020

下面是调用模型进行批量测试的代码(出现溢出),开始以为导致溢出的原因是数据读入方式问题引起的,用了tf , PIL和cv等方式读入图片数据,发现越来越慢,内存占用飙升,调试时发现是sess.run这里出了问题(随着for循环进行速度越来越慢)。

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
 
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
 
    raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    
 
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
 
 
    print(time.time() - start)
 
  print(">>>>>>Done")

下面是解决溢出问题的代码(将部分代码放在for循环外

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
  
##############################################################################################################
  raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
  raw_output_up = tf.argmax(raw_output_up, axis=3)
##############################################################################################################
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
    
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
    print(time.time() - start)
 
  print(">>>>>>Done")

总结:

在迭代过程中, 在sess.run的for循环中不要加入tensorflow一些op操作,会增加图节点,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;

建议不要使用tf.gfile.FastGFile(image_path, 'rb').read()读入数据(有可能会造成溢出),用opencv之类读取。

以上这篇解决Tensoflow sess.run导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的hypot()方法使用简介
May 18 Python
Django objects.all()、objects.get()与objects.filter()之间的区别介绍
Jun 12 Python
pygame 精灵的行走及二段跳的实现方法(必看篇)
Jul 10 Python
Python tkinter模块中类继承的三种方式分析
Aug 08 Python
Django如何开发简单的查询接口详解
May 17 Python
如何使用pyinstaller打包32位的exe程序
May 26 Python
Python学习笔记之变量、自定义函数用法示例
May 28 Python
VSCode中自动为Python文件添加头部注释
Nov 14 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 Python
使用豆瓣源来安装python中的第三方库方法
Jan 26 Python
Python环境搭建过程从安装到Hello World
Feb 05 Python
Python 匹配文本并在其上一行追加文本
May 11 Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
tensorflow自定义激活函数实例
Feb 04 #Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
You might like
PHP 异步执行方法,模拟多线程的应用分析
2013/06/03 PHP
php分页代码学习示例分享
2014/02/20 PHP
php使用fopen创建utf8编码文件的方法
2014/10/31 PHP
PHP中strncmp()函数比较两个字符串前2个字符是否相等的方法
2016/01/07 PHP
基于ThinkPHP5.0实现图片上传插件
2017/09/25 PHP
php设计模式之模板模式实例分析【星际争霸游戏案例】
2020/03/24 PHP
Javascript根据指定下标或对象删除数组元素
2012/12/21 Javascript
javascript打印html内容功能的方法示例
2013/11/28 Javascript
javascript判断是手机还是电脑访问网页的简单实例分享
2014/06/03 Javascript
用js的document.write输出的广告无阻塞加载的方法
2014/06/05 Javascript
jquery实现点击展开列表同时隐藏其他列表
2015/08/10 Javascript
基于jQuery实现仿QQ空间送礼物功能代码
2016/05/24 Javascript
angular实现IM聊天图片发送实例
2017/05/08 Javascript
Vue.js搭建移动端购物车界面
2020/06/28 Javascript
Vue2.x通用编辑组件的封装及应用详解
2019/05/28 Javascript
Vue3 响应式侦听与计算的实现
2020/11/11 Javascript
JavaScript构造函数原理及实现流程解析
2020/11/19 Javascript
[08:02]DOTA2牵红线 zhou神抱得美人归
2014/03/22 DOTA
一波神奇的Python语句、函数与方法的使用技巧总结
2015/12/08 Python
Python模拟登陆实现代码
2017/06/14 Python
Python爬取十篇新闻统计TF-IDF
2018/01/03 Python
使用Python从零开始撸一个区块链
2018/03/14 Python
Python实现的读取文件内容并写入其他文件操作示例
2019/04/09 Python
pycharm激活码有效到2020年11月底
2020/09/18 Python
Python中logging日志库实例详解
2020/02/19 Python
python GUI库图形界面开发之PyQt5控件QTableWidget详细使用方法与属性
2020/02/25 Python
CSS3实现点击放大的动画实例代码
2017/02/27 HTML / CSS
使用html5 canvas绘制圆环动效
2019/06/03 HTML / CSS
施华洛世奇德国官网:SWAROVSKI德国
2017/02/01 全球购物
Stefania Mode美国:奢华设计师和时尚服装
2018/01/07 全球购物
欧洲领先的火车票和大巴票预订平台:Trainline
2018/12/26 全球购物
Notino罗马尼亚网站:购买香水和化妆品
2019/07/20 全球购物
大学生求职计划书
2014/04/30 职场文书
技术岗位竞聘演讲稿
2014/05/16 职场文书
2014年校长工作总结
2014/12/11 职场文书
营运督导岗位职责
2015/04/10 职场文书