如何从csv文件构建Tensorflow的数据集


Posted in Python onSeptember 21, 2020

从csv文件构建Tensorflow的数据集

当我们有一系列CSV文件,如何构建Tensorflow的数据集呢?

基本步骤

  1. 获得一组CSV文件的路径
  2. 将这组文件名,转成文件名对应的dataset => file_dataset
  3. 根据file_dataset中的每个文件名,读取文件内容 生成一个内容的dataset => content_dataset
  4. 这样的多个content_dataset, 拼接起来,形成一整个dataset
  5. 因为读出来的每条记录都是string类型, 所以还需要对每条记录做decode

存在一个这样的变量train_filenames

pprint.pprint(train_filenames)
#	['generate_csv\\train_00.csv',
#	 'generate_csv\\train_01.csv',
#	 'generate_csv\\train_02.csv',
#	 'generate_csv\\train_03.csv',
#	 'generate_csv\\train_04.csv',
#	 'generate_csv\\train_05.csv',
#	 'generate_csv\\train_06.csv',
#	 'generate_csv\\train_07.csv',
#	 'generate_csv\\train_08.csv',
#	 'generate_csv\\train_09.csv',
#	 'generate_csv\\train_10.csv',
#	 'generate_csv\\train_11.csv',
#	 'generate_csv\\train_12.csv',
#	 'generate_csv\\train_13.csv',
#	 'generate_csv\\train_14.csv',
#	 'generate_csv\\train_15.csv',
#	 'generate_csv\\train_16.csv',
#	 'generate_csv\\train_17.csv',
#	 'generate_csv\\train_18.csv',
#	 'generate_csv\\train_19.csv']

接着,我们用提前定义好的API构建文件名数据集file_dataset

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
  print(filename)
#tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)

第三步, 根据每个文件名,去读取文件里面的内容

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5
)

for line in dataset.take(3):
  print(line)

#tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
#tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
#tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)

interleave的作用可以类比map, 对每个元素应用操作,然后还能把结果合起来。
因此,有了interleave, 我们就把第三四步,一起完成了
之所以skip(1),是因为这个csv第一行是header.
cycle_length是并行化构建数据集的线程数

好,第五步,解析每条记录

def parse_csv_line(line, n_fields=9):
  defaults = [tf.constant(np.nan)] * n_fields
  parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
  x = tf.stack(parsed_fields[:-1])
  y = tf.stack(parsed_fields[-1:])
  return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803', 9)

#(<tf.Tensor: shape=(8,), dtype=float32, numpy= array([ 1.2286259 , -1.0806246 , 0.44441614, -0.03521726, 0.9740348 ,-0.00351608, -0.81265247, 0.86560905], dtype=float32)>,<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)

最后,将每条记录都应用这个方法,就完成了构建。

dataset = dataset.map(parse_csv_line)

完整代码

def csv_2_dataset(filenames, n_readers_thread = 5, batch_size = 32, n_parse_thread = 5, shuffle_buffer_size = 10000):
  
  dataset = tf.data.Dataset.list_files(filenames)
  dataset = dataset.repeat()
  dataset = dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=n_readers_thread
  )
  dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(parse_csv_line, num_parallel_calls = n_parse_thread)
  dataset = dataset.batch(batch_size)
  return dataset

如何使用

train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)

model = ...

model.fit(train_set, validation_data=valid_set, 
          steps_per_epoch = 11610 // 32,
          validation_steps = 3870 // 32,
          epochs=100, callbacks=callbacks)

这里的11610 和 3870是什么?

这是train_dataset 和 valid_dataset中数据的数量,需要在训练中手动指定每个batch中参与训练的数据的多少。

model.evaluate(test_set, steps=5160//32)

同理,测试的时候,使用这样的数据集,也需要手动指定。
5160是测试数据集的总量。

以上就是如何从csv文件构建Tensorflow的数据集的详细内容,更多关于csv文件构建Tensorflow的数据集的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python编程中的文件读写及相关的文件对象方法讲解
Jan 19 Python
pygame实现弹力球及其变速效果
Jul 03 Python
Python实现获取照片拍摄日期并重命名的方法
Sep 30 Python
mac下给python3安装requests库和scrapy库的实例
Jun 13 Python
python 通过类中一个方法获取另一个方法变量的实例
Jan 22 Python
Python Selenium 之关闭窗口close与quit的方法
Feb 13 Python
pandas DataFrame 行列索引及值的获取的方法
Jul 02 Python
python range实例用法分享
Feb 06 Python
4款Python 类型检查工具,你选择哪个呢?
Oct 30 Python
django注册用邮箱发送验证码的实现
Apr 18 Python
python字典进行运算原理及实例分享
Aug 02 Python
详解Python中__new__方法的作用
Mar 31 Python
python打包多类型文件的操作方法
Sep 21 #Python
python 星号(*)的多种用途
Sep 21 #Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
Sep 21 #Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
Sep 21 #Python
python map比for循环快在哪
Sep 21 #Python
通过实例解析Python文件操作实现步骤
Sep 21 #Python
python Paramiko使用示例
Sep 21 #Python
You might like
PHP中对数据库操作的封装
2006/10/09 PHP
给初学者的30条PHP最佳实践(荒野无灯)
2011/08/02 PHP
PHP实现UTF-8文件BOM自动检测与移除实例
2014/11/05 PHP
6个超实用的PHP代码片段
2015/08/10 PHP
JavaScript自定义DateDiff函数(兼容所有浏览器)
2012/03/01 Javascript
javascript改变position值实现菜单滚动至顶部后固定
2013/01/18 Javascript
js setTimeout()函数介绍及应用以倒计时为例
2013/12/12 Javascript
ie8下修改input的type属性报错的解决方法
2014/09/16 Javascript
JS实现常见的TAB、弹出层效果(TAB标签,斑马线,遮罩层等)
2015/10/08 Javascript
Node.js中JavaScript操作MySQL的常用方法整理
2016/03/01 Javascript
详谈jQuery unbind 删除绑定事件 / 移除标签方法
2017/03/02 Javascript
Angular2下使用pdf插件的方法详解
2017/04/29 Javascript
webpack学习教程之publicPath路径问题详解
2017/06/17 Javascript
nodejs中安装ghost出错的原因及解决方法
2017/10/23 NodeJs
微信小程序 多行文本显示...+显示更多按钮和收起更多按钮功能
2019/09/26 Javascript
使用JavaScript计算前一天和后一天的思路详解
2019/12/20 Javascript
vue在线动态切换主题色方案
2020/03/26 Javascript
vue.js中使用微信扫一扫解决invalid signature问题(完美解决)
2020/04/11 Javascript
react使用CSS实现react动画功能示例
2020/05/18 Javascript
解决vue路由name同名,路由重复的问题
2020/08/05 Javascript
JavaScript array常用方法代码实例详解
2020/09/02 Javascript
探究Python中isalnum()方法的使用
2015/05/18 Python
深入解读Python解析XML的几种方式
2016/02/16 Python
详解Python读取配置文件模块ConfigParser
2017/05/11 Python
解决新django中的path不能使用正则表达式的问题
2018/12/18 Python
Python使用正则表达式分割字符串的实现方法
2019/07/16 Python
利用Python实现kNN算法的代码
2019/08/16 Python
Python descriptor(描述符)的实现
2020/11/15 Python
一款利用css3的鼠标经过动画显示详情特效的实例教程
2014/12/29 HTML / CSS
师范应届生教师求职信
2013/11/05 职场文书
统计学专业毕业生的自我评价分享
2013/11/28 职场文书
委托协议书范本
2014/04/22 职场文书
高等学院职业生涯规划书范文
2014/09/16 职场文书
个人承诺书格式范文
2015/04/29 职场文书
2015婚礼主持词开场白
2015/05/28 职场文书
学校教代会开幕词
2016/03/04 职场文书