解决Tensorflow sess.run导致的内存溢出问题


Posted in Python onFebruary 05, 2020

下面是调用模型进行批量测试的代码(出现溢出),开始以为导致溢出的原因是数据读入方式问题引起的,用了tf , PIL和cv等方式读入图片数据,发现越来越慢,内存占用飙升,调试时发现是sess.run这里出了问题(随着for循环进行速度越来越慢)。

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
 
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
 
    raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    
 
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
 
 
    print(time.time() - start)
 
  print(">>>>>>Done")

下面是解决溢出问题的代码(将部分代码放在for循环外

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
  
##############################################################################################################
  raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
  raw_output_up = tf.argmax(raw_output_up, axis=3)
##############################################################################################################
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
    
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
    print(time.time() - start)
 
  print(">>>>>>Done")

总结:

在迭代过程中, 在sess.run的for循环中不要加入tensorflow一些op操作,会增加图节点,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;

建议不要使用tf.gfile.FastGFile(image_path, 'rb').read()读入数据(有可能会造成溢出),用opencv之类读取。

以上这篇解决Tensoflow sess.run导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用动态变量名的方法
May 06 Python
python访问mysql数据库的实现方法(2则示例)
Jan 06 Python
老生常谈Python之装饰器、迭代器和生成器
Jul 26 Python
python 除法保留两位小数点的方法
Jul 16 Python
python3.7 使用pymssql往sqlserver插入数据的方法
Jul 08 Python
python实现对服务器脚本敏感信息的加密解密功能
Aug 13 Python
Mac 使用python3的matplot画图不显示的解决
Nov 23 Python
基于python连接oracle导并出数据文件
Apr 28 Python
如何在python中执行另一个py文件
Apr 30 Python
Python并发concurrent.futures和asyncio实例
May 04 Python
详解pytorch创建tensor函数
Mar 22 Python
pytorch实现加载保存查看checkpoint文件
Jul 15 Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
tensorflow自定义激活函数实例
Feb 04 #Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
You might like
re0第二季蕾姆被制作组打入冷宫!艾米莉亚女主扶正,原因唏嘘
2020/04/02 日漫
php中文本数据翻页(留言本翻页)
2006/10/09 PHP
ThinkPHP框架实现导出excel数据的方法示例【基于PHPExcel】
2018/05/12 PHP
最近项目写了一些js,水平有待提高
2009/01/31 Javascript
JavaScript Cookie显示用户上次访问的时间和次数
2009/12/08 Javascript
浅析JQuery获取和设置Select选项的常用方法总结
2013/07/04 Javascript
使用JSLint提高JS代码质量方法分享
2013/12/16 Javascript
利用try-catch判断变量是已声明未声明还是未赋值
2014/03/12 Javascript
jQuery中:eq()选择器用法实例
2014/12/29 Javascript
jQuery实现图片渐入渐出切换展示效果
2015/08/15 Javascript
基于cssSlidy.js插件实现响应式手机图片轮播效果
2016/08/30 Javascript
AngularJS入门教程之模块化操作用法示例
2016/11/02 Javascript
使用Vue写一个datepicker的示例
2018/01/27 Javascript
Vue.directive()的用法和实例详解
2018/03/04 Javascript
vue2.0 computed 计算list循环后累加值的实例
2018/03/07 Javascript
每个 JavaScript 工程师都应懂的33个概念
2018/10/22 Javascript
微信小程序如何获取用户头像和昵称
2019/09/23 Javascript
vue使用better-scroll实现滑动以及左右联动
2020/06/30 Javascript
Python中super的用法实例
2015/05/28 Python
如何准确判断请求是搜索引擎爬虫(蜘蛛)发出的请求
2015/10/13 Python
十个Python程序员易犯的错误
2015/12/15 Python
selenium跳过webdriver检测并模拟登录淘宝
2019/06/12 Python
python实现控制COM口的示例
2019/07/03 Python
Django全局启用登陆验证login_required的方法
2020/06/02 Python
详解CSS3的图层阴影和文字阴影效果使用
2016/06/09 HTML / CSS
Bootstrap 学习分享
2012/11/12 HTML / CSS
美国设计师精美珠宝购物网:Netaya
2016/08/28 全球购物
英国皇家造币厂:The Royal Mint
2018/10/05 全球购物
介绍一下gcc特性
2012/01/20 面试题
数控个人求职信范文
2014/02/03 职场文书
奥运会口号
2014/06/13 职场文书
运动会标语
2014/06/21 职场文书
幼儿园小班教师个人工作总结
2015/02/06 职场文书
集团财务总监岗位职责
2015/04/03 职场文书
关于空气污染危害的感想
2015/08/11 职场文书
mysql如何能有效防止删库跑路
2021/10/05 MySQL