如何从csv文件构建Tensorflow的数据集


Posted in Python onSeptember 21, 2020

从csv文件构建Tensorflow的数据集

当我们有一系列CSV文件,如何构建Tensorflow的数据集呢?

基本步骤

  1. 获得一组CSV文件的路径
  2. 将这组文件名,转成文件名对应的dataset => file_dataset
  3. 根据file_dataset中的每个文件名,读取文件内容 生成一个内容的dataset => content_dataset
  4. 这样的多个content_dataset, 拼接起来,形成一整个dataset
  5. 因为读出来的每条记录都是string类型, 所以还需要对每条记录做decode

存在一个这样的变量train_filenames

pprint.pprint(train_filenames)
#	['generate_csv\\train_00.csv',
#	 'generate_csv\\train_01.csv',
#	 'generate_csv\\train_02.csv',
#	 'generate_csv\\train_03.csv',
#	 'generate_csv\\train_04.csv',
#	 'generate_csv\\train_05.csv',
#	 'generate_csv\\train_06.csv',
#	 'generate_csv\\train_07.csv',
#	 'generate_csv\\train_08.csv',
#	 'generate_csv\\train_09.csv',
#	 'generate_csv\\train_10.csv',
#	 'generate_csv\\train_11.csv',
#	 'generate_csv\\train_12.csv',
#	 'generate_csv\\train_13.csv',
#	 'generate_csv\\train_14.csv',
#	 'generate_csv\\train_15.csv',
#	 'generate_csv\\train_16.csv',
#	 'generate_csv\\train_17.csv',
#	 'generate_csv\\train_18.csv',
#	 'generate_csv\\train_19.csv']

接着,我们用提前定义好的API构建文件名数据集file_dataset

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
  print(filename)
#tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)

第三步, 根据每个文件名,去读取文件里面的内容

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5
)

for line in dataset.take(3):
  print(line)

#tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
#tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
#tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)

interleave的作用可以类比map, 对每个元素应用操作,然后还能把结果合起来。
因此,有了interleave, 我们就把第三四步,一起完成了
之所以skip(1),是因为这个csv第一行是header.
cycle_length是并行化构建数据集的线程数

好,第五步,解析每条记录

def parse_csv_line(line, n_fields=9):
  defaults = [tf.constant(np.nan)] * n_fields
  parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
  x = tf.stack(parsed_fields[:-1])
  y = tf.stack(parsed_fields[-1:])
  return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803', 9)

#(<tf.Tensor: shape=(8,), dtype=float32, numpy= array([ 1.2286259 , -1.0806246 , 0.44441614, -0.03521726, 0.9740348 ,-0.00351608, -0.81265247, 0.86560905], dtype=float32)>,<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)

最后,将每条记录都应用这个方法,就完成了构建。

dataset = dataset.map(parse_csv_line)

完整代码

def csv_2_dataset(filenames, n_readers_thread = 5, batch_size = 32, n_parse_thread = 5, shuffle_buffer_size = 10000):
  
  dataset = tf.data.Dataset.list_files(filenames)
  dataset = dataset.repeat()
  dataset = dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=n_readers_thread
  )
  dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(parse_csv_line, num_parallel_calls = n_parse_thread)
  dataset = dataset.batch(batch_size)
  return dataset

如何使用

train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)

model = ...

model.fit(train_set, validation_data=valid_set, 
          steps_per_epoch = 11610 // 32,
          validation_steps = 3870 // 32,
          epochs=100, callbacks=callbacks)

这里的11610 和 3870是什么?

这是train_dataset 和 valid_dataset中数据的数量,需要在训练中手动指定每个batch中参与训练的数据的多少。

model.evaluate(test_set, steps=5160//32)

同理,测试的时候,使用这样的数据集,也需要手动指定。
5160是测试数据集的总量。

以上就是如何从csv文件构建Tensorflow的数据集的详细内容,更多关于csv文件构建Tensorflow的数据集的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python判断windows隐藏文件的方法
Mar 21 Python
python处理PHP数组文本文件实例
Sep 18 Python
Python常用知识点汇总
May 08 Python
Windows下安装python2和python3多版本教程
Mar 30 Python
Python三级菜单的实例
Sep 13 Python
python删除本地夹里重复文件的方法
Nov 19 Python
Python 实现子类获取父类的类成员方法
Jan 11 Python
Python 元组操作总结
Sep 18 Python
keras 自定义loss model.add_loss的使用详解
Jun 22 Python
Python获取百度热搜的完整代码
Apr 07 Python
Python基础之常用库常用方法整理
Apr 30 Python
利用python调用摄像头的实例分析
Jun 07 Python
python打包多类型文件的操作方法
Sep 21 #Python
python 星号(*)的多种用途
Sep 21 #Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
Sep 21 #Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
Sep 21 #Python
python map比for循环快在哪
Sep 21 #Python
通过实例解析Python文件操作实现步骤
Sep 21 #Python
python Paramiko使用示例
Sep 21 #Python
You might like
可快速识别放射性物质-国外大神教你diy一个开放式辐射探测器
2020/03/12 无线电
PHP实现的汉字拼音转换和公历农历转换类及使用示例
2014/07/01 PHP
php文件读取方法实例分析
2015/06/20 PHP
php使用strip_tags()去除html标签仍有空白的解决方法
2016/07/28 PHP
针对多用户实现头像上传功能PHP代码 适用于登陆页面制作
2016/08/17 PHP
js获取html参数及向swf传递参数应用介绍
2013/02/18 Javascript
jquery利用ajax调用后台方法实例
2013/08/23 Javascript
浅析用prototype定义自己的方法
2013/11/14 Javascript
Ajax局部更新导致JS事件重复触发问题的解决方法
2014/10/14 Javascript
JS中多步骤多分步的StepJump组件实例详解
2016/04/01 Javascript
JS动态遍历json中所有键值对的方法(不知道属性名的情况)
2016/12/28 Javascript
vue slot 在子组件中显示父组件传递的模板
2018/03/02 Javascript
详解在HTTPS 项目中使用百度地图 API
2019/04/26 Javascript
[01:05:41]EG vs Optic Supermajor 败者组 BO3 第二场 6.6
2018/06/07 DOTA
pyqt4教程之widget使用示例分享
2014/03/07 Python
Python+tkinter使用40行代码实现计算器功能
2018/01/30 Python
解决Pycharm下面出现No R interpreter defined的问题
2018/10/29 Python
python 读取竖线分隔符的文本方法
2018/12/20 Python
Python3 执行Linux Bash命令的方法
2019/07/12 Python
浅谈Python3 numpy.ptp()最大值与最小值的差
2019/08/24 Python
安装2019Pycharm最新版本的教程详解
2019/10/22 Python
python实现的分析并统计nginx日志数据功能示例
2019/12/21 Python
python 控制台单行刷新,多行刷新实例
2020/02/19 Python
Python3爬虫发送请求的知识点实例
2020/07/30 Python
用css3实现转换过渡和动画效果
2020/03/13 HTML / CSS
MAC Cosmetics官方网站:魅可专业艺术彩妆
2019/04/10 全球购物
泰海淘:泰国king Power王权免税集团旗下跨境海淘综合型电商
2020/07/26 全球购物
环境科学毕业生自荐信
2013/11/21 职场文书
企业厂长岗位职责
2013/12/17 职场文书
军人违纪检讨书
2014/02/04 职场文书
2014年新生军训方案
2014/05/01 职场文书
北京故宫导游词
2015/01/31 职场文书
美丽的大脚观后感
2015/06/03 职场文书
创业计划书之家政服务
2019/09/18 职场文书
JavaScript流程控制(循环)
2021/12/06 Javascript
windows11选中自动复制怎么开启? Win11自动复制所选内容的方法
2022/07/23 数码科技