基于Tensorflow批量数据的输入实现方式


Posted in Python onFebruary 05, 2020

基于Tensorflow下的批量数据的输入处理:

1.Tensor TFrecords格式

2.h5py的库的数组方法

在tensorflow的框架下写CNN代码,我在书写过程中,感觉不是框架内容难写, 更多的是我在对图像的预处理和输入这部分花了很多精神。

使用了两种方法:

方法一:

Tensor 以Tfrecords的格式存储数据,如果对数据进行标签,可以同时做到数据打标签。

①创建TFrecords文件

orig_image = '/home/images/train_image/'
gen_image = '/home/images/image_train.tfrecords'
def create_record():
  writer = tf.python_io.TFRecordWriter(gen_image)
  class_path = orig_image
  for img_name in os.listdir(class_path): #读取每一幅图像
    img_path = class_path + img_name 
    img = Image.open(img_path) #读取图像
    #img = img.resize((256, 256)) #设置图片大小, 在这里可以对图像进行处理
    img_raw = img.tobytes() #将图片转化为原声bytes 
    example = tf.train.Example(
         features=tf.train.Features(feature={
             'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])), #打标签
             'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))#存储数据
             }))
    writer.write(example.SerializeToString())
  writer.close()

②读取TFrecords文件

def read_and_decode(filename):
  #创建文件队列,不限读取的数据
  filename_queue = tf.train.string_input_producer([filename])
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)

  features = tf.parse_single_example(
      serialized_example,
      features={
          'label': tf.FixedLenFeature([], tf.int64),
          'img_raw': tf.FixedLenFeature([], tf.string)})
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8) #tf.float32
  img = tf.image.convert_image_dtype(img, dtype=tf.float32)
  img = tf.reshape(img, [256, 256, 1])
  label = tf.cast(label, tf.int32)
  return img, label

③批量读取数据,使用tf.train.batch

min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
num_samples= len(os.listdir(orig_image))
create_record()
img, label = read_and_decode(gen_image)
total_batch = int(num_samples/batch_size)
image_batch, label_batch = tf.train.batch([img, label], batch_size=batch_size,
                      num_threads=32, capacity=capacity) 
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
  sess.run(init_op)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  for i in range(total_batch):
     cur_image_batch, cur_label_batch = sess.run([image_batch, label_batch])
  coord.request_stop()
  coord.join(threads)

方法二:

使用h5py就是使用数组的格式来存储数据

这个方法比较好,在CNN的过程中,会使用到多个数据类存储,比较好用, 比如一个数据进行了两种以上的变化,并且分类存储,我认为这个方法会比较好用。

import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import random
from scipy.interpolate import griddata
from skimage import img_as_float
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
class_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_gray_0_1/'
for img_name in os.listdir(class_path):
  img_path = class_path + img_name
  img = io.imread(img_path)
  m1 = img_as_float(img)
  m2, m3 = sample_inter1(m1) #一个数据处理的函数
  m1 = m1.reshape([256, 256, 1])
  m2 = m2.reshape([256, 256, 1])
  m3 = m3.reshape([256, 256, 1])
  orig_image.append(m1)
  sample_near.append(m2)
  sample_line.append(m3)

arrorig_image = np.asarray(orig_image) # [?, 256, 256, 1]
arrlsample_near = np.asarray(sample_near) # [?, 256, 256, 1] 
arrlsample_line = np.asarray(sample_line) # [?, 256, 256, 1] 

save_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_sample/train.h5'
def make_data(path):
  with h5py.File(save_path, 'w') as hf:
     hf.create_dataset('orig_image', data=arrorig_image)
     hf.create_dataset('sample_near', data=arrlsample_near)
     hf.create_dataset('sample_line', data=arrlsample_line)

def read_data(path):
  with h5py.File(path, 'r') as hf:
     orig_image = np.array(hf.get('orig_image')) #一定要对清楚上边的标签名orig_image;
     sample_near = np.array(hf.get('sample_near'))
     sample_line = np.array(hf.get('sample_line'))
  return orig_image, sample_near, sample_line
make_data(save_path)
orig_image1, sample_near1, sample_line1 = read_data(save_path)
total_number = len(orig_image1)
batch_size = 20
batch_index = total_number/batch_size
for i in range(batch_index):
  batch_orig = orig_image1[i*batch_size:(i+1)*batch_size]
  batch_sample_near = sample_near1[i*batch_size:(i+1)*batch_size]
  batch_sample_line = sample_line1[i*batch_size:(i+1)*batch_size]

在使用h5py的时候,生成的文件巨大的时候,读取数据显示错误:ioerror: unable to open file (bad object header version number)

基本就是这个生成的文件不能使用,适当的减少存储的数据,即可。

以上这篇基于Tensorflow批量数据的输入实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详细介绍Python函数中的默认参数
Mar 30 Python
Python中尝试多线程编程的一个简明例子
Apr 07 Python
python实现的文件同步服务器实例
Jun 02 Python
Python工程师面试题 与Python基础语法相关
Jan 14 Python
PyQt编程之如何在屏幕中央显示窗体的实例
Jun 18 Python
详解Python 字符串相似性的几种度量方法
Aug 29 Python
基于pandas中expand的作用详解
Dec 17 Python
Python中base64与xml取值结合问题
Dec 22 Python
Python3 pywin32模块安装的详细步骤
May 26 Python
python3爬虫中异步协程的用法
Jul 10 Python
浅谈matplotlib 绘制梯度下降求解过程
Jul 12 Python
python基于tkinter制作m3u8视频下载工具
Apr 24 Python
Python操作注册表详细步骤介绍
Feb 05 #Python
Python类继承和多态原理解析
Feb 05 #Python
Python模块 _winreg操作注册表
Feb 05 #Python
python3操作注册表的方法(Url protocol)
Feb 05 #Python
Python tkinter模版代码实例
Feb 05 #Python
Python Scrapy框架第一个入门程序示例
Feb 05 #Python
python lambda函数及三个常用的高阶函数
Feb 05 #Python
You might like
深入PHP magic quotes的详解
2013/06/17 PHP
php-fpm添加service服务的例子
2018/04/27 PHP
php模拟实现斗地主发牌
2020/04/22 PHP
PHP 实现 WebSocket 协议原理与应用详解
2020/04/22 PHP
jquery $(document).ready() 与window.onload的区别
2009/12/28 Javascript
jQuery 表格插件整理
2010/04/27 Javascript
Javascript面向对象编程(三) 非构造函数的继承
2011/08/28 Javascript
利用JQuery制作符合Web标准的QQ弹出消息
2014/01/14 Javascript
jQuery中click事件用法实例
2014/12/26 Javascript
JS前向后瞻正则表达式定义与用法示例
2016/12/27 Javascript
JavaScript运动框架 多物体任意值运动(三)
2017/05/17 Javascript
NodeJS如何实现同步的方法示例
2018/08/24 NodeJs
vue实现弹框遮罩点击其他区域弹框关闭及v-if与v-show的区别介绍
2018/09/29 Javascript
详解vue中使用微信jssdk
2019/04/19 Javascript
在layui tab控件中载入外部html页面的方法
2019/09/04 Javascript
IE11下CKEditor在Bootstrap Modal中下拉问题的解决
2019/09/25 Javascript
使用python实现ANN
2017/12/20 Python
Python数据处理numpy.median的实例讲解
2018/04/02 Python
python 以16进制打印输出的方法
2018/07/09 Python
解决tensorflow测试模型时NotFoundError错误的问题
2018/07/27 Python
pyqt 多窗口之间的相互调用方法
2019/06/19 Python
python3 打印输出字典中特定的某个key的方法示例
2019/07/06 Python
Django中使用CORS实现跨域请求过程解析
2019/08/05 Python
Python imread、newaxis用法详解
2019/11/04 Python
python3实现用turtle模块画一棵随机樱花树
2019/11/21 Python
python pyenv多版本管理工具的使用
2019/12/23 Python
python实现密码强度校验
2020/03/18 Python
HTML5标签与HTML4标签的区别示例介绍
2013/07/18 HTML / CSS
找到您丢失的钥匙、钱包和手机:Tile
2017/05/19 全球购物
2014年宣传部个人工作总结
2014/12/06 职场文书
2014年小学少先队工作总结
2014/12/18 职场文书
第一节英语课开场白
2015/06/01 职场文书
2016猴年春节问候语
2015/11/11 职场文书
2016年小学六一儿童节活动总结
2016/04/06 职场文书
2019奶茶店创业计划书范本,值得你借鉴
2019/08/14 职场文书
手把手教你从零开始react+antd搭建项目
2021/06/03 Javascript