浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点


Posted in Python onJune 08, 2020

batch很好理解,就是batch size。注意在一个epoch中最后一个batch大小可能小于等于batch size

dataset.repeat就是俗称epoch,但在tf中与dataset.shuffle的使用顺序可能会导致个epoch的混合

dataset.shuffle就是说维持一个buffer size 大小的 shuffle buffer,图中所需的每个样本从shuffle buffer中获取,取得一个样本后,就从源数据集中加入一个样本到shuffle buffer中。

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
#源数据集
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

# 通过shuffle batch后取得的样本
[[ 0.4236548  0.64589411]
 [ 0.60276338 0.54488318]
 [ 0.43758721 0.891773 ]
 [ 0.5488135  0.71518937]]
[[ 0.96366276 0.38344152]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.79172504 0.52889492]]
[[ 0.07103606 0.0871293 ]
 [ 0.97861834 0.79915856]
 [ 0.77815675 0.87001215]] #最后一个batch样本个数为3
[[ 0.60276338 0.54488318]
 [ 0.5488135  0.71518937]
 [ 0.43758721 0.891773 ]
 [ 0.79172504 0.52889492]]
[[ 0.4236548  0.64589411]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215]
 [ 0.96366276 0.38344152]
 [ 0.97861834 0.79915856]] #最后一个batch样本个数为3

1、按照shuffle中设置的buffer size,首先从源数据集取得三个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3、shuffle buffer不足三个样本,从源数据集提取一个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5、如此反复。这就意味中如果shuffle 的buffer size=1,数据集不打乱。如果shuffle 的buffer size=数据集样本数量,随机打乱整个数据集

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

注意如果repeat在shuffle之前使用:

官方说repeat在shuffle之前使用能提高性能,但模糊了数据样本的epoch关系

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()

with tf.Session() as sess:
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))
  print(sess.run(el))

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.56804456 0.92559664]
 [ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152]
 [ 0.43758721 0.891773 ]
 [ 0.43758721 0.891773 ]
 [ 0.77815675 0.87001215]]
[[ 0.79172504 0.52889492]  #出现相同样本出现在同一个batch中
 [ 0.79172504 0.52889492]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]]
[[ 0.07103606 0.0871293 ]
 [ 0.4236548  0.64589411]
 [ 0.96366276 0.38344152]
 [ 0.5488135  0.71518937]]
[[ 0.97861834 0.79915856]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.56804456 0.92559664]]
[[ 0.0202184  0.83261985]
 [ 0.97861834 0.79915856]]     #可以看到最后个batch为2,而前面都是4

使用案例:

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
  print('Parsing', filenames)
  def decode_libsvm(line):
    #columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
    #features = dict(zip(CSV_COLUMNS, columns))
    #labels = features.pop(LABEL_COLUMN)
    columns = tf.string_split([line], ' ')
    labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
    splits = tf.string_split(columns.values[1:], ':')
    id_vals = tf.reshape(splits.values,splits.dense_shape)
    feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)
    feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
    feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
    #feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
    #for i in range(splits.dense_shape.eval()[0]):
    #  feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
    #  feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
    #return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

  # Extract lines from input files using the Dataset API, can pass one filename or filename list
  dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000)  # multi-thread pre-process then prefetch

  # Randomizes input using a window of 256 elements (read into memory)
  if perform_shuffle:
    dataset = dataset.shuffle(buffer_size=256)

  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size) # Batch size to use

  #return dataset.make_one_shot_iterator()
  iterator = dataset.make_one_shot_iterator()
  batch_features, batch_labels = iterator.get_next()
  #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
  return batch_features, batch_labels

到此这篇关于浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点的文章就介绍到这了,更多相关tensorflow中dataset.shuffle和dataset.batch dataset.repeat内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木! 

Python 相关文章推荐
Eclipse + Python 的安装与配置流程
Mar 05 Python
pytyon 带有重复的全排列
Aug 13 Python
python代码制作configure文件示例
Jul 28 Python
零基础写python爬虫之爬虫的定义及URL构成
Nov 04 Python
在Python 字典中一键对应多个值的实例
Feb 03 Python
python 字符串追加实例
Jul 20 Python
python 利用jinja2模板生成html代码实例
Oct 10 Python
使用python实现kNN分类算法
Oct 16 Python
PyTorch加载预训练模型实例(pretrained)
Jan 17 Python
Python3爬虫里关于Splash负载均衡配置详解
Jul 10 Python
浅析Python 条件控制语句
Jul 15 Python
Python 使用生成器代替线程的方法
Aug 04 Python
Python3通过chmod修改目录或文件权限的方法示例
Jun 08 #Python
win10下python3.8的PIL库安装过程
Jun 08 #Python
python rolling regression. 使用 Python 实现滚动回归操作
Jun 08 #Python
Python selenium爬虫实现定时任务过程解析
Jun 08 #Python
python:HDF和CSV存储优劣对比分析
Jun 08 #Python
Python实现一个简单的毕业生信息管理系统的示例代码
Jun 08 #Python
Python while true实现爬虫定时任务
Jun 08 #Python
You might like
PHP 采集获取指定网址的内容
2010/01/05 PHP
php获取汉字首字母的函数
2013/11/07 PHP
PHP的运行机制与原理(底层)
2015/11/16 PHP
浅谈Laravel中的一个后期静态绑定
2017/08/11 PHP
[原创]PHP实现生成vcf vcard文件功能类定义与使用方法详解【附demo源码下载】
2017/09/02 PHP
php表单习惯用的正则表达式
2017/10/11 PHP
js left,right,mid函数
2008/06/10 Javascript
javascript实现滑动解锁功能
2014/12/31 Javascript
jQuery找出网页上最高元素的方法
2015/03/20 Javascript
JavaScript显示表单内元素数量的方法
2015/04/02 Javascript
JavaScript中字符串分割函数split用法实例
2015/04/07 Javascript
JavaScript实现将xml转换成html table表格的方法
2015/04/17 Javascript
jQuery实现仿百度帖吧头部固定导航效果
2015/08/07 Javascript
浏览器兼容性问题大汇总
2015/12/17 Javascript
原生javascript+css3编写的3D魔方动画旋扭特效
2016/03/14 Javascript
jQuery Ajax实现跨域请求
2017/01/21 Javascript
JS对象是否拥有某属性如何判断
2017/02/03 Javascript
jQuery Masonry瀑布流布局神器使用详解
2017/05/25 jQuery
使用React手写一个对话框或模态框的方法示例
2019/04/25 Javascript
浅谈Vuex的this.$store.commit和在Vue项目中引用公共方法
2020/07/24 Javascript
Python中的MongoDB基本操作:连接、查询实例
2015/02/13 Python
Python sys.argv用法实例
2015/05/28 Python
Python 爬虫之超链接 url中含有中文出错及解决办法
2017/08/03 Python
Python解析多帧dicom数据详解
2020/01/13 Python
如何基于python实现归一化处理
2020/01/20 Python
基于Python绘制美观动态圆环图、饼图
2020/06/03 Python
基于python模拟TCP3次握手连接及发送数据
2020/11/06 Python
房地产销售员的自我评价分享
2013/12/04 职场文书
员工拾金不昧表扬信
2014/01/09 职场文书
美食节目策划方案
2014/05/31 职场文书
计算机专业求职信
2014/06/02 职场文书
舞蹈教育学专业自荐信
2014/06/15 职场文书
防邪知识进家庭活动方案
2014/08/26 职场文书
房屋认购协议书
2015/01/29 职场文书
升职自荐信范文
2015/03/27 职场文书
2015最新民情日记范文
2015/06/26 职场文书