浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点


Posted in Python onJune 08, 2020

batch很好理解,就是batch size。注意在一个epoch中最后一个batch大小可能小于等于batch size

dataset.repeat就是俗称epoch,但在tf中与dataset.shuffle的使用顺序可能会导致个epoch的混合

dataset.shuffle就是说维持一个buffer size 大小的 shuffle buffer,图中所需的每个样本从shuffle buffer中获取,取得一个样本后,就从源数据集中加入一个样本到shuffle buffer中。

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
#源数据集
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

# 通过shuffle batch后取得的样本
[[ 0.4236548  0.64589411]
 [ 0.60276338 0.54488318]
 [ 0.43758721 0.891773 ]
 [ 0.5488135  0.71518937]]
[[ 0.96366276 0.38344152]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.79172504 0.52889492]]
[[ 0.07103606 0.0871293 ]
 [ 0.97861834 0.79915856]
 [ 0.77815675 0.87001215]] #最后一个batch样本个数为3
[[ 0.60276338 0.54488318]
 [ 0.5488135  0.71518937]
 [ 0.43758721 0.891773 ]
 [ 0.79172504 0.52889492]]
[[ 0.4236548  0.64589411]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215]
 [ 0.96366276 0.38344152]
 [ 0.97861834 0.79915856]] #最后一个batch样本个数为3

1、按照shuffle中设置的buffer size,首先从源数据集取得三个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3、shuffle buffer不足三个样本,从源数据集提取一个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5、如此反复。这就意味中如果shuffle 的buffer size=1,数据集不打乱。如果shuffle 的buffer size=数据集样本数量,随机打乱整个数据集

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

注意如果repeat在shuffle之前使用:

官方说repeat在shuffle之前使用能提高性能,但模糊了数据样本的epoch关系

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.56804456 0.92559664]
 [ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152]
 [ 0.43758721 0.891773 ]
 [ 0.43758721 0.891773 ]
 [ 0.77815675 0.87001215]]
[[ 0.79172504 0.52889492]  #出现相同样本出现在同一个batch中
 [ 0.79172504 0.52889492]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]]
[[ 0.07103606 0.0871293 ]
 [ 0.4236548  0.64589411]
 [ 0.96366276 0.38344152]
 [ 0.5488135  0.71518937]]
[[ 0.97861834 0.79915856]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.56804456 0.92559664]]
[[ 0.0202184  0.83261985]
 [ 0.97861834 0.79915856]]     #可以看到最后个batch为2,而前面都是4

使用案例:

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
  print('Parsing', filenames)
  def decode_libsvm(line):
    #columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
    #features = dict(zip(CSV_COLUMNS, columns))
    #labels = features.pop(LABEL_COLUMN)
    columns = tf.string_split([line], ' ')
    labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
    splits = tf.string_split(columns.values[1:], ':')
    id_vals = tf.reshape(splits.values,splits.dense_shape)
    feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)
    feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
    feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
    #feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
    #for i in range(splits.dense_shape.eval()[0]):
    #  feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
    #  feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
    #return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

  # Extract lines from input files using the Dataset API, can pass one filename or filename list
  dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000)  # multi-thread pre-process then prefetch

  # Randomizes input using a window of 256 elements (read into memory)
  if perform_shuffle:
    dataset = dataset.shuffle(buffer_size=256)

  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size) # Batch size to use

  #return dataset.make_one_shot_iterator()
  iterator = dataset.make_one_shot_iterator()
  batch_features, batch_labels = iterator.get_next()
  #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
  return batch_features, batch_labels

到此这篇关于浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点的文章就介绍到这了,更多相关tensorflow中dataset.shuffle和dataset.batch dataset.repeat内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木! 

Python 相关文章推荐
Python配置文件解析模块ConfigParser使用实例
Apr 13 Python
使用Python构建Hopfield网络的教程
Apr 14 Python
浅谈Python基础之I/O模型
May 11 Python
python+matplotlib绘制饼图散点图实例代码
Jan 20 Python
python实现C4.5决策树算法
Aug 29 Python
详解Python正则表达式re模块
Mar 19 Python
使用python实现简单五子棋游戏
Jun 18 Python
python处理自动化任务之同时批量修改word里面的内容的方法
Aug 23 Python
python3 xpath和requests应用详解
Mar 06 Python
如何在VSCode下使用Jupyter的教程详解
Jul 13 Python
python用分数表示矩阵的方法实例
Jan 11 Python
聊聊Python String型列表求最值的问题
Jan 18 Python
Python3通过chmod修改目录或文件权限的方法示例
Jun 08 #Python
win10下python3.8的PIL库安装过程
Jun 08 #Python
python rolling regression. 使用 Python 实现滚动回归操作
Jun 08 #Python
Python selenium爬虫实现定时任务过程解析
Jun 08 #Python
python:HDF和CSV存储优劣对比分析
Jun 08 #Python
Python实现一个简单的毕业生信息管理系统的示例代码
Jun 08 #Python
Python while true实现爬虫定时任务
Jun 08 #Python
You might like
php读取二进制流(C语言结构体struct数据文件)的深入解析
2013/06/13 PHP
在PHP语言中使用JSON和将json还原成数组的方法
2016/07/19 PHP
yii插入数据库防并发的简单代码
2017/05/27 PHP
PHP defined()函数的使用图文详解
2019/07/20 PHP
PHP生成随机字符串实例代码(字母+数字)
2019/09/11 PHP
PHP fopen中文文件名乱码问题解决方案
2020/10/28 PHP
一些常用的Javascript函数
2006/12/22 Javascript
JQuery live函数
2010/12/24 Javascript
JQuery制作的放大效果的popup对话框(未添加任何jquery plugin)分享
2013/04/28 Javascript
javascript在子页面中函数无法调试问题解决方法
2014/01/17 Javascript
javascript定义变量时加var与不加var的区别
2014/12/22 Javascript
Jquery组件easyUi实现表单验证示例
2016/08/23 Javascript
Javascript中判断一个值是否为undefined的方法详解
2016/09/28 Javascript
用AngularJS来实现监察表单按钮的禁用效果
2016/11/02 Javascript
100多个基础常用JS函数和语法集合大全
2017/02/16 Javascript
JS基于正则实现数字千分位用逗号分隔的方法
2017/06/16 Javascript
Angular4.0中引入laydate.js日期插件的方法教程
2017/12/25 Javascript
Angular实现模版驱动表单的自定义校验功能(密码确认为例)
2018/05/17 Javascript
Three.js实现简单3D房间布局
2018/12/30 Javascript
利用Bootstrap Multiselect实现下拉框多选功能
2019/04/08 Javascript
ElementUI Tag组件实现多标签生成的方法示例
2019/07/08 Javascript
js实现磁性吸附的示例
2020/10/26 Javascript
Vue 的 v-model用法实例
2020/11/23 Vue.js
javascript实现点击小图显示大图
2020/11/29 Javascript
[41:56]Spirit vs Liquid Supermajor小组赛A组 BO3 第一场 6.2
2018/06/03 DOTA
零基础写python爬虫之urllib2使用指南
2014/11/05 Python
python寻找list中最大值、最小值并返回其所在位置的方法
2018/06/27 Python
Python urlopen()和urlretrieve()用法解析
2020/01/07 Python
pytorch方法测试——激活函数(ReLU)详解
2020/01/15 Python
详解CSS3原生支持div铺满浏览器的方法
2018/08/30 HTML / CSS
HTML5中判断用户是否正在浏览页面的方法
2014/05/03 HTML / CSS
上党课的心得体会
2014/09/02 职场文书
作风大整顿心得体会
2014/09/10 职场文书
党支部反对四风思想汇报
2014/10/10 职场文书
公司行政助理岗位职责
2015/04/11 职场文书
React forwardRef的使用方法及注意点
2021/06/13 Javascript