解决Tensorflow sess.run导致的内存溢出问题


Posted in Python onFebruary 05, 2020

下面是调用模型进行批量测试的代码(出现溢出),开始以为导致溢出的原因是数据读入方式问题引起的,用了tf , PIL和cv等方式读入图片数据,发现越来越慢,内存占用飙升,调试时发现是sess.run这里出了问题(随着for循环进行速度越来越慢)。

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
 
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
 
    raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    
 
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
 
 
    print(time.time() - start)
 
  print(">>>>>>Done")

下面是解决溢出问题的代码(将部分代码放在for循环外

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
  
##############################################################################################################
  raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
  raw_output_up = tf.argmax(raw_output_up, axis=3)
##############################################################################################################
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
    
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
    print(time.time() - start)
 
  print(">>>>>>Done")

总结:

在迭代过程中, 在sess.run的for循环中不要加入tensorflow一些op操作,会增加图节点,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;

建议不要使用tf.gfile.FastGFile(image_path, 'rb').read()读入数据(有可能会造成溢出),用opencv之类读取。

以上这篇解决Tensoflow sess.run导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之类class定义使用方法
Feb 20 Python
详解Python中find()方法的使用
May 18 Python
python utc datetime转换为时间戳的方法
Jan 15 Python
python selenium 弹出框处理的实现
Feb 26 Python
Python pip替换为阿里源的方法步骤
Jul 02 Python
Python3 无重复字符的最长子串的实现
Oct 08 Python
Python模拟FTP文件服务器的操作方法
Feb 18 Python
Python PyQt5整理介绍
Apr 01 Python
利用Python制作动态排名图的实现代码
Apr 09 Python
Python如何截图保存的三种方法(小结)
Sep 01 Python
详解向scrapy中的spider传递参数的几种方法(2种)
Sep 28 Python
Python办公自动化之教你用Python批量识别发票并录入到Excel表格中
Jun 26 Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
tensorflow自定义激活函数实例
Feb 04 #Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
You might like
php压缩HTML函数轻松实现压缩html/js/Css及注意事项
2013/01/27 PHP
php模拟实现斗地主发牌
2020/04/22 PHP
js 优化次数过多的循环 考虑到性能问题
2011/03/05 Javascript
关于JAVASCRIPT urldecode URL解码的问题
2012/01/08 Javascript
form表单action提交的js部分与html部分
2014/01/07 Javascript
浅析javascript中函数声明和函数表达式的区别
2015/02/15 Javascript
jquery插件qrcode在线生成二维码
2015/04/26 Javascript
javascript cookie的简单应用
2016/02/24 Javascript
一个用jquery写的判断div滚动条到底部的方法【推荐】
2016/04/29 Javascript
使用jQuery.Qrcode插件在客户端动态生成二维码并添加自定义Logo
2016/09/01 Javascript
浅谈jquery的html方法里包含特殊字符的处理
2016/11/30 Javascript
JS开发中百度地图+城市联动实现实时触发查询地址功能
2017/04/13 Javascript
详解Angular 中 ngOnInit 和 constructor 使用场景
2017/06/22 Javascript
js仿微信抢红包功能
2020/09/25 Javascript
改变vue请求过来的数据中的某一项值的方法(详解)
2018/03/08 Javascript
vue中element-ui表格缩略图悬浮放大功能的实例代码
2018/06/26 Javascript
vue 自定义提示框(Toast)组件的实现代码
2018/08/17 Javascript
解决layer.msg 不居中 ifram中的问题
2019/09/05 Javascript
微信小程序实现蓝牙打印
2019/09/23 Javascript
原生js实现ajax请求和JSONP跨域请求操作示例
2020/03/14 Javascript
微信小程序仿抖音视频之整屏上下切换功能的实现代码
2020/05/24 Javascript
vue项目中使用多选框的实例代码
2020/07/22 Javascript
Python实现测试磁盘性能的方法
2015/03/12 Python
详解Python的迭代器、生成器以及相关的itertools包
2015/04/02 Python
详解Python编程中包的概念与管理
2015/10/16 Python
简单讲解Python编程中namedtuple类的用法
2016/06/21 Python
Python使用一行代码获取上个月是几月
2018/08/30 Python
使用Scrapy爬取动态数据
2018/10/21 Python
Django 开发环境配置过程详解
2019/07/18 Python
Python 操作mysql数据库查询之fetchone(), fetchmany(), fetchall()用法示例
2019/10/17 Python
基于python计算滚动方差(标准差)talib和pd.rolling函数差异详解
2020/06/08 Python
Python requests接口测试实现代码
2020/09/08 Python
Banana Republic英国官网:香蕉共和国,GAP集团旗下偏贵族风
2018/04/24 全球购物
村官学习十八大感想
2014/01/15 职场文书
习近平在党的群众路线教育实践活动总结大会上的讲话全文
2014/10/25 职场文书
入党自荐书范文
2015/03/05 职场文书