如何从csv文件构建Tensorflow的数据集


Posted in Python onSeptember 21, 2020

从csv文件构建Tensorflow的数据集

当我们有一系列CSV文件,如何构建Tensorflow的数据集呢?

基本步骤

  1. 获得一组CSV文件的路径
  2. 将这组文件名,转成文件名对应的dataset => file_dataset
  3. 根据file_dataset中的每个文件名,读取文件内容 生成一个内容的dataset => content_dataset
  4. 这样的多个content_dataset, 拼接起来,形成一整个dataset
  5. 因为读出来的每条记录都是string类型, 所以还需要对每条记录做decode

存在一个这样的变量train_filenames

pprint.pprint(train_filenames)
#	['generate_csv\\train_00.csv',
#	 'generate_csv\\train_01.csv',
#	 'generate_csv\\train_02.csv',
#	 'generate_csv\\train_03.csv',
#	 'generate_csv\\train_04.csv',
#	 'generate_csv\\train_05.csv',
#	 'generate_csv\\train_06.csv',
#	 'generate_csv\\train_07.csv',
#	 'generate_csv\\train_08.csv',
#	 'generate_csv\\train_09.csv',
#	 'generate_csv\\train_10.csv',
#	 'generate_csv\\train_11.csv',
#	 'generate_csv\\train_12.csv',
#	 'generate_csv\\train_13.csv',
#	 'generate_csv\\train_14.csv',
#	 'generate_csv\\train_15.csv',
#	 'generate_csv\\train_16.csv',
#	 'generate_csv\\train_17.csv',
#	 'generate_csv\\train_18.csv',
#	 'generate_csv\\train_19.csv']

接着,我们用提前定义好的API构建文件名数据集file_dataset

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
  print(filename)
#tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)

第三步, 根据每个文件名,去读取文件里面的内容

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5
)

for line in dataset.take(3):
  print(line)

#tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
#tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
#tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)

interleave的作用可以类比map, 对每个元素应用操作,然后还能把结果合起来。
因此,有了interleave, 我们就把第三四步,一起完成了
之所以skip(1),是因为这个csv第一行是header.
cycle_length是并行化构建数据集的线程数

好,第五步,解析每条记录

def parse_csv_line(line, n_fields=9):
  defaults = [tf.constant(np.nan)] * n_fields
  parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
  x = tf.stack(parsed_fields[:-1])
  y = tf.stack(parsed_fields[-1:])
  return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803', 9)

#(<tf.Tensor: shape=(8,), dtype=float32, numpy= array([ 1.2286259 , -1.0806246 , 0.44441614, -0.03521726, 0.9740348 ,-0.00351608, -0.81265247, 0.86560905], dtype=float32)>,<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)

最后,将每条记录都应用这个方法,就完成了构建。

dataset = dataset.map(parse_csv_line)

完整代码

def csv_2_dataset(filenames, n_readers_thread = 5, batch_size = 32, n_parse_thread = 5, shuffle_buffer_size = 10000):
  
  dataset = tf.data.Dataset.list_files(filenames)
  dataset = dataset.repeat()
  dataset = dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=n_readers_thread
  )
  dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(parse_csv_line, num_parallel_calls = n_parse_thread)
  dataset = dataset.batch(batch_size)
  return dataset

如何使用

train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)

model = ...

model.fit(train_set, validation_data=valid_set, 
          steps_per_epoch = 11610 // 32,
          validation_steps = 3870 // 32,
          epochs=100, callbacks=callbacks)

这里的11610 和 3870是什么?

这是train_dataset 和 valid_dataset中数据的数量,需要在训练中手动指定每个batch中参与训练的数据的多少。

model.evaluate(test_set, steps=5160//32)

同理,测试的时候,使用这样的数据集,也需要手动指定。
5160是测试数据集的总量。

以上就是如何从csv文件构建Tensorflow的数据集的详细内容,更多关于csv文件构建Tensorflow的数据集的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python批量修改文件后缀示例代码分享
Dec 24 Python
Python聚类算法之基本K均值实例详解
Nov 20 Python
谈谈如何手动释放Python的内存
Dec 17 Python
浅谈python中列表、字符串、字典的常用操作
Sep 19 Python
python3基于OpenCV实现证件照背景替换
Jul 18 Python
python实现本地图片转存并重命名的示例代码
Oct 27 Python
一文带你了解Python中的字符串是什么
Nov 20 Python
浅谈Pycharm中的Python Console与Terminal
Jan 17 Python
树莓派极简安装OpenCv的方法步骤
Oct 10 Python
python如何获取apk的packagename和activity
Jan 10 Python
使用python的turtle函数绘制一个滑稽表情
Feb 28 Python
解决pyecharts运行后产生的html文件用浏览器打开空白
Mar 11 Python
python打包多类型文件的操作方法
Sep 21 #Python
python 星号(*)的多种用途
Sep 21 #Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
Sep 21 #Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
Sep 21 #Python
python map比for循环快在哪
Sep 21 #Python
通过实例解析Python文件操作实现步骤
Sep 21 #Python
python Paramiko使用示例
Sep 21 #Python
You might like
解析PHP汉字转换拼音的类
2013/06/18 PHP
JavaScript 数组的 uniq 方法
2008/01/23 Javascript
JS控制日期显示的小例子
2013/11/23 Javascript
js代码实现的加入收藏效果并兼容主流浏览器
2014/06/23 Javascript
jQuery实现区域打印功能代码详解
2016/06/17 Javascript
浅谈JS的基础类型与引用类型
2016/09/13 Javascript
Vue.JS入门教程之自定义指令
2016/12/08 Javascript
简单的jQuery拖拽排序效果的实现(增强动态)
2017/02/09 Javascript
layui分页效果实现代码
2017/05/19 Javascript
ES6 系列之 Generator 的自动执行的方法示例
2018/10/19 Javascript
微信二次分享报错invalid signature问题及解决方法
2019/04/01 Javascript
Vue CLI2升级至Vue CLI3的方法步骤
2019/05/20 Javascript
微信小程序个人中心的列表控件实现代码
2020/04/26 Javascript
[01:02:03]2014 DOTA2华西杯精英邀请赛 5 24 NewBee VS VG
2014/05/26 DOTA
python将html转成PDF的实现代码(包含中文)
2013/03/04 Python
Python实现求笛卡尔乘积的方法
2017/09/16 Python
python的中异常处理机制
2018/08/30 Python
Python编程深度学习计算库之numpy
2018/12/28 Python
python抓取多种类型的页面方法实例
2019/11/20 Python
最新pycharm安装教程
2020/11/18 Python
完美解决torch.cuda.is_available()一直返回False的玄学方法
2021/02/06 Python
CSS3中的@keyframes关键帧动画的选择器绑定
2016/06/13 HTML / CSS
美国主要的特色咖啡和茶公司:Peet’s Coffee
2020/02/14 全球购物
Eclipse面试题
2014/03/22 面试题
农民入党思想汇报
2014/01/03 职场文书
如何写好优秀的创业计划书
2014/01/30 职场文书
护士的自我鉴定
2014/02/07 职场文书
2014年人民警察入党思想汇报
2014/10/12 职场文书
英语教师个人总结
2015/02/09 职场文书
防汛通知
2015/04/25 职场文书
2015年加油站工作总结
2015/05/13 职场文书
亲戚关系证明
2015/06/24 职场文书
工程主管竞聘书
2015/09/15 职场文书
实习报告范文之电话客服岗位
2019/07/26 职场文书
Css预编语言及区别详解
2021/04/25 HTML / CSS
Python必备技巧之字符数据操作详解
2022/03/23 Python