如何从csv文件构建Tensorflow的数据集


Posted in Python onSeptember 21, 2020

从csv文件构建Tensorflow的数据集

当我们有一系列CSV文件,如何构建Tensorflow的数据集呢?

基本步骤

  1. 获得一组CSV文件的路径
  2. 将这组文件名,转成文件名对应的dataset => file_dataset
  3. 根据file_dataset中的每个文件名,读取文件内容 生成一个内容的dataset => content_dataset
  4. 这样的多个content_dataset, 拼接起来,形成一整个dataset
  5. 因为读出来的每条记录都是string类型, 所以还需要对每条记录做decode

存在一个这样的变量train_filenames

pprint.pprint(train_filenames)
#	['generate_csv\\train_00.csv',
#	 'generate_csv\\train_01.csv',
#	 'generate_csv\\train_02.csv',
#	 'generate_csv\\train_03.csv',
#	 'generate_csv\\train_04.csv',
#	 'generate_csv\\train_05.csv',
#	 'generate_csv\\train_06.csv',
#	 'generate_csv\\train_07.csv',
#	 'generate_csv\\train_08.csv',
#	 'generate_csv\\train_09.csv',
#	 'generate_csv\\train_10.csv',
#	 'generate_csv\\train_11.csv',
#	 'generate_csv\\train_12.csv',
#	 'generate_csv\\train_13.csv',
#	 'generate_csv\\train_14.csv',
#	 'generate_csv\\train_15.csv',
#	 'generate_csv\\train_16.csv',
#	 'generate_csv\\train_17.csv',
#	 'generate_csv\\train_18.csv',
#	 'generate_csv\\train_19.csv']

接着,我们用提前定义好的API构建文件名数据集file_dataset

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
  print(filename)
#tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)

第三步, 根据每个文件名,去读取文件里面的内容

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5
)

for line in dataset.take(3):
  print(line)

#tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
#tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
#tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)

interleave的作用可以类比map, 对每个元素应用操作,然后还能把结果合起来。
因此,有了interleave, 我们就把第三四步,一起完成了
之所以skip(1),是因为这个csv第一行是header.
cycle_length是并行化构建数据集的线程数

好,第五步,解析每条记录

def parse_csv_line(line, n_fields=9):
  defaults = [tf.constant(np.nan)] * n_fields
  parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
  x = tf.stack(parsed_fields[:-1])
  y = tf.stack(parsed_fields[-1:])
  return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803', 9)

#(<tf.Tensor: shape=(8,), dtype=float32, numpy= array([ 1.2286259 , -1.0806246 , 0.44441614, -0.03521726, 0.9740348 ,-0.00351608, -0.81265247, 0.86560905], dtype=float32)>,<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)

最后,将每条记录都应用这个方法,就完成了构建。

dataset = dataset.map(parse_csv_line)

完整代码

def csv_2_dataset(filenames, n_readers_thread = 5, batch_size = 32, n_parse_thread = 5, shuffle_buffer_size = 10000):
  
  dataset = tf.data.Dataset.list_files(filenames)
  dataset = dataset.repeat()
  dataset = dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=n_readers_thread
  )
  dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(parse_csv_line, num_parallel_calls = n_parse_thread)
  dataset = dataset.batch(batch_size)
  return dataset

如何使用

train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)

model = ...

model.fit(train_set, validation_data=valid_set, 
          steps_per_epoch = 11610 // 32,
          validation_steps = 3870 // 32,
          epochs=100, callbacks=callbacks)

这里的11610 和 3870是什么?

这是train_dataset 和 valid_dataset中数据的数量,需要在训练中手动指定每个batch中参与训练的数据的多少。

model.evaluate(test_set, steps=5160//32)

同理,测试的时候,使用这样的数据集,也需要手动指定。
5160是测试数据集的总量。

以上就是如何从csv文件构建Tensorflow的数据集的详细内容,更多关于csv文件构建Tensorflow的数据集的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python基础教程之python消息摘要算法使用示例
Feb 10 Python
python连接mysql并提交mysql事务示例
Mar 05 Python
使用Python的Tornado框架实现一个一对一聊天的程序
Apr 25 Python
详解Python中列表和元祖的使用方法
Apr 25 Python
python中numpy包使用教程之数组和相关操作详解
Jul 30 Python
python使用正则表达式替换匹配成功的组并输出替换的次数
Nov 22 Python
http请求 request失败自动重新尝试代码示例
Jan 25 Python
实践Vim配置python开发环境
Jul 02 Python
Pytorch实现的手写数字mnist识别功能完整示例
Dec 13 Python
Python3 字典dictionary入门基础附实例
Feb 10 Python
python中@property的作用和getter setter的解释
Dec 22 Python
Python机器学习工具scikit-learn的使用笔记
Jan 28 Python
python打包多类型文件的操作方法
Sep 21 #Python
python 星号(*)的多种用途
Sep 21 #Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
Sep 21 #Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
Sep 21 #Python
python map比for循环快在哪
Sep 21 #Python
通过实例解析Python文件操作实现步骤
Sep 21 #Python
python Paramiko使用示例
Sep 21 #Python
You might like
解析mysql中UNIX_TIMESTAMP()函数与php中time()函数的区别
2013/06/24 PHP
PHP函数strip_tags的一个bug浅析
2014/05/22 PHP
PHP判断是否为空的几个函数对比
2015/04/21 PHP
使用PHP接受文件并获得其后缀名的方法
2015/08/05 PHP
使用PHP uniqid函数生成唯一ID
2015/11/18 PHP
JavaScript中的其他对象
2008/01/16 Javascript
javascript 面向对象,实现namespace,class,继承,重载
2009/10/29 Javascript
查看大图功能代码jquery版
2013/11/05 Javascript
ExtJS 刷新后如何默认选中刷新前最后一次选中的节点
2014/04/03 Javascript
JavaScript中提前声明变量或函数例子
2014/11/12 Javascript
jQuery获得包含margin的outerWidth和outerHeight的方法
2015/03/25 Javascript
举例讲解AngularJS中的模块
2015/06/17 Javascript
利用jQuery和CSS将背景图片拉伸
2015/10/16 Javascript
浅析BootStrap栅格系统
2016/06/07 Javascript
js html5 css俄罗斯方块游戏再现
2016/10/17 Javascript
angular源码学习第一篇 setupModuleLoader方法
2016/10/20 Javascript
JavaScript &amp; jQuery完美判断图片是否加载完毕
2017/01/08 Javascript
ionic选择多张图片上传的示例代码
2017/10/10 Javascript
vue自定义指令实现方法详解
2019/02/11 Javascript
详解Vue demo实现商品列表的展示
2019/05/07 Javascript
Vue内部渲染视图的方法
2019/09/02 Javascript
Vue3新特性之在Composition API中使用CSS Modules
2020/07/13 Javascript
[01:52]DOTA2完美大师赛Vega战队趣味视频——kpii老师小课堂
2017/11/25 DOTA
Python图像处理之gif动态图的解析与合成操作详解
2018/12/30 Python
python中Lambda表达式详解
2019/11/20 Python
python 使用递归回溯完美解决八皇后的问题
2020/02/26 Python
Python 批量读取文件中指定字符的实现
2020/03/06 Python
python matplotlib.pyplot.plot()参数用法
2020/04/14 Python
都柏林通行卡/城市通票:The Dublin Pass
2020/02/16 全球购物
继电保护工岗位职责
2014/01/05 职场文书
学校安全责任书
2014/04/14 职场文书
个性车贴标语
2014/06/24 职场文书
法学专业求职信
2014/07/15 职场文书
个人自我剖析材料
2014/09/30 职场文书
毕业赠语大全
2015/06/23 职场文书
2016年离婚协议书范文
2016/03/18 职场文书