如何从csv文件构建Tensorflow的数据集


Posted in Python onSeptember 21, 2020

从csv文件构建Tensorflow的数据集

当我们有一系列CSV文件,如何构建Tensorflow的数据集呢?

基本步骤

  1. 获得一组CSV文件的路径
  2. 将这组文件名,转成文件名对应的dataset => file_dataset
  3. 根据file_dataset中的每个文件名,读取文件内容 生成一个内容的dataset => content_dataset
  4. 这样的多个content_dataset, 拼接起来,形成一整个dataset
  5. 因为读出来的每条记录都是string类型, 所以还需要对每条记录做decode

存在一个这样的变量train_filenames

pprint.pprint(train_filenames)
#	['generate_csv\\train_00.csv',
#	 'generate_csv\\train_01.csv',
#	 'generate_csv\\train_02.csv',
#	 'generate_csv\\train_03.csv',
#	 'generate_csv\\train_04.csv',
#	 'generate_csv\\train_05.csv',
#	 'generate_csv\\train_06.csv',
#	 'generate_csv\\train_07.csv',
#	 'generate_csv\\train_08.csv',
#	 'generate_csv\\train_09.csv',
#	 'generate_csv\\train_10.csv',
#	 'generate_csv\\train_11.csv',
#	 'generate_csv\\train_12.csv',
#	 'generate_csv\\train_13.csv',
#	 'generate_csv\\train_14.csv',
#	 'generate_csv\\train_15.csv',
#	 'generate_csv\\train_16.csv',
#	 'generate_csv\\train_17.csv',
#	 'generate_csv\\train_18.csv',
#	 'generate_csv\\train_19.csv']

接着,我们用提前定义好的API构建文件名数据集file_dataset

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
  print(filename)
#tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)

第三步, 根据每个文件名,去读取文件里面的内容

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5
)

for line in dataset.take(3):
  print(line)

#tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
#tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
#tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)

interleave的作用可以类比map, 对每个元素应用操作,然后还能把结果合起来。
因此,有了interleave, 我们就把第三四步,一起完成了
之所以skip(1),是因为这个csv第一行是header.
cycle_length是并行化构建数据集的线程数

好,第五步,解析每条记录

def parse_csv_line(line, n_fields=9):
  defaults = [tf.constant(np.nan)] * n_fields
  parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
  x = tf.stack(parsed_fields[:-1])
  y = tf.stack(parsed_fields[-1:])
  return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803', 9)

#(<tf.Tensor: shape=(8,), dtype=float32, numpy= array([ 1.2286259 , -1.0806246 , 0.44441614, -0.03521726, 0.9740348 ,-0.00351608, -0.81265247, 0.86560905], dtype=float32)>,<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)

最后,将每条记录都应用这个方法,就完成了构建。

dataset = dataset.map(parse_csv_line)

完整代码

def csv_2_dataset(filenames, n_readers_thread = 5, batch_size = 32, n_parse_thread = 5, shuffle_buffer_size = 10000):
  
  dataset = tf.data.Dataset.list_files(filenames)
  dataset = dataset.repeat()
  dataset = dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=n_readers_thread
  )
  dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(parse_csv_line, num_parallel_calls = n_parse_thread)
  dataset = dataset.batch(batch_size)
  return dataset

如何使用

train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)

model = ...

model.fit(train_set, validation_data=valid_set, 
          steps_per_epoch = 11610 // 32,
          validation_steps = 3870 // 32,
          epochs=100, callbacks=callbacks)

这里的11610 和 3870是什么?

这是train_dataset 和 valid_dataset中数据的数量,需要在训练中手动指定每个batch中参与训练的数据的多少。

model.evaluate(test_set, steps=5160//32)

同理,测试的时候,使用这样的数据集,也需要手动指定。
5160是测试数据集的总量。

以上就是如何从csv文件构建Tensorflow的数据集的详细内容,更多关于csv文件构建Tensorflow的数据集的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python获得一个月有多少天的方法
Jun 04 Python
python语言使用技巧分享
May 31 Python
Django实现快速分页的方法实例
Oct 22 Python
Python常见字典内建函数用法示例
May 14 Python
基于Django框架的权限组件rbac实例讲解
Aug 31 Python
解决python彩色螺旋线绘制引发的问题
Nov 23 Python
django列表筛选功能的实现代码
Mar 27 Python
python和go语言的区别是什么
Jul 20 Python
一文带你了解Python 四种常见基础爬虫方法介绍
Dec 04 Python
python 实现两个变量值进行交换的n种操作
Jun 02 Python
python数据可视化使用pyfinance分析证券收益示例详解
Nov 20 Python
python 实现图片特效处理
Apr 03 Python
python打包多类型文件的操作方法
Sep 21 #Python
python 星号(*)的多种用途
Sep 21 #Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
Sep 21 #Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
Sep 21 #Python
python map比for循环快在哪
Sep 21 #Python
通过实例解析Python文件操作实现步骤
Sep 21 #Python
python Paramiko使用示例
Sep 21 #Python
You might like
浅谈php提交form表单
2015/07/01 PHP
PHP代码重构方法漫谈
2018/04/17 PHP
PHP标准库(PHP SPL)详解
2019/03/16 PHP
PHP基于进程控制函数实现多线程
2020/12/09 PHP
jquery 图片截取工具jquery.imagecropper.js
2010/04/09 Javascript
Javascript中的变量使用说明
2010/05/18 Javascript
ExtJS PropertyGrid中使用Combobox选择值问题
2010/06/13 Javascript
jQuery lazyload 的重复加载错误以及修复方法
2010/11/19 Javascript
JS 无限级 Select效果实现代码(json格式)
2011/08/30 Javascript
Jquery插件仿百度搜索关键字自动匹配功能
2016/05/11 Javascript
jQuery DateTimePicker 日期和时间插件示例
2017/01/22 Javascript
关于vue.js组件数据流的问题
2017/07/26 Javascript
9种改善AngularJS性能的方法
2017/11/28 Javascript
使用 vue.js 构建大型单页应用
2018/02/10 Javascript
在vue里面设置全局变量或数据的方法
2018/03/09 Javascript
详解如何webpack使用DllPlugin
2018/09/30 Javascript
小程序hover-class点击态效果实现
2019/02/26 Javascript
javascript实现小型区块链功能
2019/04/03 Javascript
微信小程序页面间跳转传参方式总结
2019/06/13 Javascript
Vue的编码技巧与规范使用详解
2019/08/28 Javascript
关于Vue中axios的封装实例详解
2019/10/20 Javascript
[54:51]Ti4 冒泡赛第二轮LGD vs C9 3
2014/07/14 DOTA
Python装饰器用法实例总结
2018/02/07 Python
基于python-pptx库中文文档及使用详解
2020/02/14 Python
北美领先的智能产品购物网站:Wellbots
2018/06/11 全球购物
澳大利亚波西米亚风情网上商店:Czarina
2019/03/18 全球购物
周生生珠宝香港官网:Chow Sang Sang(香港及海外配送)
2019/09/05 全球购物
英国最大的滑板品牌选择:Route One
2019/09/22 全球购物
Under Armour安德玛意大利官网:美国高端运动科技品牌
2020/01/16 全球购物
长青弘远的面试题
2012/06/09 面试题
什么是继承
2013/12/07 面试题
大学生标准自荐书
2014/06/15 职场文书
农业局党的群众路线教育实践活动整改方案
2014/09/20 职场文书
2014年流动人口工作总结
2014/11/26 职场文书
新生入学欢迎词
2015/01/26 职场文书
MySQ InnoDB和MyISAM存储引擎介绍
2022/04/26 MySQL