基于Tensorflow批量数据的输入实现方式


Posted in Python onFebruary 05, 2020

基于Tensorflow下的批量数据的输入处理:

1.Tensor TFrecords格式

2.h5py的库的数组方法

在tensorflow的框架下写CNN代码,我在书写过程中,感觉不是框架内容难写, 更多的是我在对图像的预处理和输入这部分花了很多精神。

使用了两种方法:

方法一:

Tensor 以Tfrecords的格式存储数据,如果对数据进行标签,可以同时做到数据打标签。

①创建TFrecords文件

orig_image = '/home/images/train_image/'
gen_image = '/home/images/image_train.tfrecords'
def create_record():
  writer = tf.python_io.TFRecordWriter(gen_image)
  class_path = orig_image
  for img_name in os.listdir(class_path): #读取每一幅图像
    img_path = class_path + img_name 
    img = Image.open(img_path) #读取图像
    #img = img.resize((256, 256)) #设置图片大小, 在这里可以对图像进行处理
    img_raw = img.tobytes() #将图片转化为原声bytes 
    example = tf.train.Example(
         features=tf.train.Features(feature={
             'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])), #打标签
             'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))#存储数据
             }))
    writer.write(example.SerializeToString())
  writer.close()

②读取TFrecords文件

def read_and_decode(filename):
  #创建文件队列,不限读取的数据
  filename_queue = tf.train.string_input_producer([filename])
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)

  features = tf.parse_single_example(
      serialized_example,
      features={
          'label': tf.FixedLenFeature([], tf.int64),
          'img_raw': tf.FixedLenFeature([], tf.string)})
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8) #tf.float32
  img = tf.image.convert_image_dtype(img, dtype=tf.float32)
  img = tf.reshape(img, [256, 256, 1])
  label = tf.cast(label, tf.int32)
  return img, label

③批量读取数据,使用tf.train.batch

min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
num_samples= len(os.listdir(orig_image))
create_record()
img, label = read_and_decode(gen_image)
total_batch = int(num_samples/batch_size)
image_batch, label_batch = tf.train.batch([img, label], batch_size=batch_size,
                      num_threads=32, capacity=capacity) 
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
  sess.run(init_op)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  for i in range(total_batch):
     cur_image_batch, cur_label_batch = sess.run([image_batch, label_batch])
  coord.request_stop()
  coord.join(threads)

方法二:

使用h5py就是使用数组的格式来存储数据

这个方法比较好,在CNN的过程中,会使用到多个数据类存储,比较好用, 比如一个数据进行了两种以上的变化,并且分类存储,我认为这个方法会比较好用。

import os
import h5py
import matplotlib.pyplot as plt
import numpy as np
import random
from scipy.interpolate import griddata
from skimage import img_as_float
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
class_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_gray_0_1/'
for img_name in os.listdir(class_path):
  img_path = class_path + img_name
  img = io.imread(img_path)
  m1 = img_as_float(img)
  m2, m3 = sample_inter1(m1) #一个数据处理的函数
  m1 = m1.reshape([256, 256, 1])
  m2 = m2.reshape([256, 256, 1])
  m3 = m3.reshape([256, 256, 1])
  orig_image.append(m1)
  sample_near.append(m2)
  sample_line.append(m3)

arrorig_image = np.asarray(orig_image) # [?, 256, 256, 1]
arrlsample_near = np.asarray(sample_near) # [?, 256, 256, 1] 
arrlsample_line = np.asarray(sample_line) # [?, 256, 256, 1] 

save_path = '/home/awen/Juanjuan/Python Project/train_BSDS/test_sample/train.h5'
def make_data(path):
  with h5py.File(save_path, 'w') as hf:
     hf.create_dataset('orig_image', data=arrorig_image)
     hf.create_dataset('sample_near', data=arrlsample_near)
     hf.create_dataset('sample_line', data=arrlsample_line)

def read_data(path):
  with h5py.File(path, 'r') as hf:
     orig_image = np.array(hf.get('orig_image')) #一定要对清楚上边的标签名orig_image;
     sample_near = np.array(hf.get('sample_near'))
     sample_line = np.array(hf.get('sample_line'))
  return orig_image, sample_near, sample_line
make_data(save_path)
orig_image1, sample_near1, sample_line1 = read_data(save_path)
total_number = len(orig_image1)
batch_size = 20
batch_index = total_number/batch_size
for i in range(batch_index):
  batch_orig = orig_image1[i*batch_size:(i+1)*batch_size]
  batch_sample_near = sample_near1[i*batch_size:(i+1)*batch_size]
  batch_sample_line = sample_line1[i*batch_size:(i+1)*batch_size]

在使用h5py的时候,生成的文件巨大的时候,读取数据显示错误:ioerror: unable to open file (bad object header version number)

基本就是这个生成的文件不能使用,适当的减少存储的数据,即可。

以上这篇基于Tensorflow批量数据的输入实现方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解如何用OpenCV + Python 实现人脸识别
Oct 20 Python
python issubclass 和 isinstance函数
Jul 25 Python
Pytorch反向求导更新网络参数的方法
Aug 17 Python
python Dijkstra算法实现最短路径问题的方法
Sep 19 Python
基于python中__add__函数的用法
Nov 25 Python
使用Python进行防病毒免杀解析
Dec 13 Python
Pytorch释放显存占用方式
Jan 13 Python
python GUI库图形界面开发之PyQt5表单布局控件QFormLayout详细使用方法与实例
Mar 06 Python
JupyterNotebook 输出窗口的显示效果调整实现
Sep 22 Python
python实现网页录音效果
Oct 26 Python
python 写一个水果忍者游戏
Jan 13 Python
python疲劳驾驶困倦低头检测功能的实现
Apr 04 Python
Python操作注册表详细步骤介绍
Feb 05 #Python
Python类继承和多态原理解析
Feb 05 #Python
Python模块 _winreg操作注册表
Feb 05 #Python
python3操作注册表的方法(Url protocol)
Feb 05 #Python
Python tkinter模版代码实例
Feb 05 #Python
Python Scrapy框架第一个入门程序示例
Feb 05 #Python
python lambda函数及三个常用的高阶函数
Feb 05 #Python
You might like
微信开发之获取JSAPI TICKET
2017/07/07 PHP
JavaScript性能优化 创建文档碎片(document.createDocumentFragment)
2010/07/13 Javascript
javascript代码运行不出来执行错误的可能情况整理
2013/10/18 Javascript
Nodejs学习笔记之Global Objects全局对象
2015/01/13 NodeJs
jQuery性能优化技巧分析
2015/02/20 Javascript
JS 实现Base64编码与解码实例详解
2016/11/07 Javascript
AngularJS中的JSONP实例解析
2016/12/01 Javascript
jQuery File Upload文件上传插件使用详解
2016/12/06 Javascript
BootStrap的select2既可以查询又可以输入的实现代码
2017/02/17 Javascript
React Native基础入门之调试React Native应用的一小步
2018/07/02 Javascript
微信小程序有旋转动画效果的音乐组件实例代码
2018/08/22 Javascript
如何解决vue2.0下IE浏览器白屏问题
2018/09/13 Javascript
微信小程序页面间传值与页面取值操作实例分析
2019/04/30 Javascript
[01:55]2014DOTA2国际邀请赛快报:国土生病 紧急去医院治疗
2014/07/10 DOTA
Python批量查询域名是否被注册过
2017/06/21 Python
解决Tensorflow安装成功,但在导入时报错的问题
2018/06/13 Python
Python中创建二维数组
2018/10/17 Python
python3使用flask编写注册post接口的方法
2018/12/28 Python
Python实现批量执行同目录下的py文件方法
2019/01/11 Python
python 搜索大文件的实例代码
2019/07/08 Python
Docker部署Python爬虫项目的方法步骤
2020/01/19 Python
python函数中将变量名转换成字符串实例
2020/05/11 Python
完美解决Django2.0中models下的ForeignKey()问题
2020/05/19 Python
python3通过subprocess模块调用脚本并和脚本交互的操作
2020/12/05 Python
Python实现Kerberos用户的增删改查操作
2020/12/14 Python
CSS3中文字镂空、透明值、阴影效果设置示例小结
2016/03/07 HTML / CSS
详解如何在css中引入自定义字体(font-face)
2018/05/17 HTML / CSS
Nike法国官方网站:Nike.com FR
2018/07/22 全球购物
网上常见的一份Linux面试题(多项选择部分)
2015/02/07 面试题
军训的自我鉴定
2013/12/10 职场文书
户外婚礼策划方案
2014/02/08 职场文书
大学应届毕业生求职信
2014/05/24 职场文书
高三教师工作总结2015
2015/07/21 职场文书
小数乘法教学反思
2016/02/22 职场文书
提升Nginx性能的一些建议
2021/03/31 Servers
pycharm 如何查看某一函数源码的快捷键
2021/05/12 Python