浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点


Posted in Python onJune 08, 2020

batch很好理解,就是batch size。注意在一个epoch中最后一个batch大小可能小于等于batch size

dataset.repeat就是俗称epoch,但在tf中与dataset.shuffle的使用顺序可能会导致个epoch的混合

dataset.shuffle就是说维持一个buffer size 大小的 shuffle buffer,图中所需的每个样本从shuffle buffer中获取,取得一个样本后,就从源数据集中加入一个样本到shuffle buffer中。

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
#源数据集
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

# 通过shuffle batch后取得的样本
[[ 0.4236548  0.64589411]
 [ 0.60276338 0.54488318]
 [ 0.43758721 0.891773 ]
 [ 0.5488135  0.71518937]]
[[ 0.96366276 0.38344152]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.79172504 0.52889492]]
[[ 0.07103606 0.0871293 ]
 [ 0.97861834 0.79915856]
 [ 0.77815675 0.87001215]] #最后一个batch样本个数为3
[[ 0.60276338 0.54488318]
 [ 0.5488135  0.71518937]
 [ 0.43758721 0.891773 ]
 [ 0.79172504 0.52889492]]
[[ 0.4236548  0.64589411]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215]
 [ 0.96366276 0.38344152]
 [ 0.97861834 0.79915856]] #最后一个batch样本个数为3

1、按照shuffle中设置的buffer size,首先从源数据集取得三个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3、shuffle buffer不足三个样本,从源数据集提取一个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5、如此反复。这就意味中如果shuffle 的buffer size=1,数据集不打乱。如果shuffle 的buffer size=数据集样本数量,随机打乱整个数据集

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

注意如果repeat在shuffle之前使用:

官方说repeat在shuffle之前使用能提高性能,但模糊了数据样本的epoch关系

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.56804456 0.92559664]
 [ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152]
 [ 0.43758721 0.891773 ]
 [ 0.43758721 0.891773 ]
 [ 0.77815675 0.87001215]]
[[ 0.79172504 0.52889492]  #出现相同样本出现在同一个batch中
 [ 0.79172504 0.52889492]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]]
[[ 0.07103606 0.0871293 ]
 [ 0.4236548  0.64589411]
 [ 0.96366276 0.38344152]
 [ 0.5488135  0.71518937]]
[[ 0.97861834 0.79915856]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.56804456 0.92559664]]
[[ 0.0202184  0.83261985]
 [ 0.97861834 0.79915856]]     #可以看到最后个batch为2,而前面都是4

使用案例:

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
  print('Parsing', filenames)
  def decode_libsvm(line):
    #columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
    #features = dict(zip(CSV_COLUMNS, columns))
    #labels = features.pop(LABEL_COLUMN)
    columns = tf.string_split([line], ' ')
    labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
    splits = tf.string_split(columns.values[1:], ':')
    id_vals = tf.reshape(splits.values,splits.dense_shape)
    feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)
    feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
    feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
    #feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
    #for i in range(splits.dense_shape.eval()[0]):
    #  feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
    #  feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
    #return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

  # Extract lines from input files using the Dataset API, can pass one filename or filename list
  dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000)  # multi-thread pre-process then prefetch

  # Randomizes input using a window of 256 elements (read into memory)
  if perform_shuffle:
    dataset = dataset.shuffle(buffer_size=256)

  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size) # Batch size to use

  #return dataset.make_one_shot_iterator()
  iterator = dataset.make_one_shot_iterator()
  batch_features, batch_labels = iterator.get_next()
  #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
  return batch_features, batch_labels

到此这篇关于浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点的文章就介绍到这了,更多相关tensorflow中dataset.shuffle和dataset.batch dataset.repeat内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木! 

Python 相关文章推荐
Python的MongoDB模块PyMongo操作方法集锦
Jan 05 Python
python类的方法属性与方法属性的动态绑定代码详解
Dec 27 Python
Python使用OpenCV进行标定
May 08 Python
Python常见字典内建函数用法示例
May 14 Python
python顺序的读取文件夹下名称有序的文件方法
Jul 11 Python
numpy.where() 用法详解
May 27 Python
python 环境搭建 及python-3.4.4的下载和安装过程
Jul 20 Python
Python浮点数四舍五入问题的分析与解决方法
Nov 19 Python
python datetime处理时间小结
Apr 16 Python
python GUI模拟实现计算器
Jun 22 Python
Python用户自定义异常的实现
Dec 25 Python
python定义具名元组实例操作
Feb 28 Python
Python3通过chmod修改目录或文件权限的方法示例
Jun 08 #Python
win10下python3.8的PIL库安装过程
Jun 08 #Python
python rolling regression. 使用 Python 实现滚动回归操作
Jun 08 #Python
Python selenium爬虫实现定时任务过程解析
Jun 08 #Python
python:HDF和CSV存储优劣对比分析
Jun 08 #Python
Python实现一个简单的毕业生信息管理系统的示例代码
Jun 08 #Python
Python while true实现爬虫定时任务
Jun 08 #Python
You might like
PHP 批量删除 sql语句
2009/06/05 PHP
解决PHP在DOS命令行下却无法链接MySQL的技术笔记
2010/12/29 PHP
PHP仿博客园 个人博客(1) 数据库与界面设计
2013/07/05 PHP
PHP保留两位小数并且四舍五入及不四舍五入的方法
2013/09/22 PHP
解密ThinkPHP3.1.2版本之独立分组功能应用
2014/06/19 PHP
实例讲解php将字符串输出到HTML
2019/01/27 PHP
Laravel框架源码解析之反射的使用详解
2020/05/14 PHP
JS 添加网页桌面快捷方式的代码详细整理
2012/12/27 Javascript
jquery怎样实现ajax联动框(一)
2013/03/08 Javascript
输入自动提示搜索提示功能的javascript:sugggestion.js
2013/09/02 Javascript
在javascript中执行任意html代码的方法示例解读
2013/12/25 Javascript
javascript如何使用bind指定接收者
2014/05/04 Javascript
javascript中解析四则运算表达式的算法和示例
2014/08/11 Javascript
影响jQuery使用的14个方面
2014/09/01 Javascript
两种方法基于jQuery实现IE浏览器兼容placeholder效果
2014/10/14 Javascript
原生JS实现美图瀑布流布局赏析
2015/09/07 Javascript
微信小程序 页面跳转传递值几种方法详解
2017/01/12 Javascript
微信小程序组件 marquee实例详解
2017/06/23 Javascript
jq.ajax+php+mysql实现关键字模糊查询(示例讲解)
2018/01/02 Javascript
jquery获取元素到屏幕四周可视距离的方法
2018/09/05 jQuery
vue中使用带隐藏文本信息的图片、图片水印的方法
2020/04/24 Javascript
详解微信小程序入门从这里出发(登录注册、开发工具、文件及结构介绍)
2020/07/21 Javascript
[01:40]2014DOTA2国际邀请赛 三冰SOLO赛后采访恶搞
2014/07/09 DOTA
pymongo为mongodb数据库添加索引的方法
2015/05/11 Python
Python2.X/Python3.X中urllib库区别讲解
2017/12/19 Python
python最小生成树kruskal与prim算法详解
2019/01/17 Python
Python面向对象程序设计类变量与成员变量、类方法与成员方法用法分析
2019/04/12 Python
Python二元赋值实用技巧解析
2019/10/25 Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
2020/01/14 Python
浅谈HTML5 & CSS3的新交互特性
2016/07/19 HTML / CSS
英国在线电子和小工具商店:TecoBuy
2018/10/06 全球购物
联想智利官方网站:Lenovo Chile
2020/06/03 全球购物
linux面试相关问题
2013/04/28 面试题
幼师专业求职推荐信
2013/11/08 职场文书
2015年七一建党节慰问信
2015/03/23 职场文书
如何给HttpServletRequest增加消息头
2021/06/30 Java/Android