Notes on dataset.shuffle, dataset.batch, and dataset.repeat in TensorFlow


Posted in Python on June 08, 2020

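batch is straightforward: its argument is the batch size. Note that the last batch of an epoch may be smaller than the batch size; it is full only when the number of samples divides evenly.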
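To make the short final batch concrete, here is a minimal pure-Python sketch of batching. It does not call TensorFlow: the `batch` helper is a hypothetical stand-in, and its `drop_remainder` flag is only named after the argument that newer versions of `tf.data.Dataset.batch` accept. The count of 11 samples matches the examples used later in this article.

```python
def batch(items, batch_size, drop_remainder=False):
    """Group items into lists of batch_size; the last list may be shorter."""
    batches = []
    for start in range(0, len(items), batch_size):
        chunk = items[start:start + batch_size]
        if drop_remainder and len(chunk) < batch_size:
            break  # discard the short final batch
        batches.append(chunk)
    return batches


samples = list(range(11))  # 11 samples, as in this article's examples
sizes = [len(b) for b in batch(samples, 4)]
print(sizes)  # [4, 4, 3]
print([len(b) for b in batch(samples, 4, drop_remainder=True)])  # [4, 4]
```

Dropping the remainder trades a few samples per epoch for batches of a fixed size, which some models require.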

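dataset.repeat corresponds to what is commonly called the number of epochs. In TensorFlow, however, its position relative to dataset.shuffle matters: the wrong order can mix samples from different epochs.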

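dataset.shuffle maintains a shuffle buffer that holds buffer size samples. Every sample the graph needs is drawn at random from this buffer, and each time a sample is taken, one sample from the source dataset is added to the buffer to fill the freed slot.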
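The buffer mechanism described above can be sketched in plain Python. This is a simplified model written for illustration, not TensorFlow's implementation; `shuffle_buffer` and its parameters are hypothetical names, and integers stand in for samples.

```python
import random


def shuffle_buffer(source, buffer_size, rng):
    """Yield items from source in an order randomized through a fixed-size buffer."""
    source = iter(source)
    buffer = []
    # Fill the buffer with the first buffer_size items.
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    # Emit one random buffered item, then refill the freed slot from the source.
    for item in source:
        index = rng.randrange(len(buffer))
        yield buffer[index]
        buffer[index] = item
    # Source exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


rng = random.Random(0)
shuffled = list(shuffle_buffer(range(11), buffer_size=3, rng=rng))
print(shuffled)
```

Every sample appears exactly once in the output, and the first emitted sample always comes from the first three source samples, because those are the only ones in the buffer at that point.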

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
print(x)
print()
# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
  # pull batches until the dataset is exhausted; a fixed number of
  # sess.run calls would raise OutOfRangeError after the sixth batch
  try:
    while True:
      print(sess.run(el))
  except tf.errors.OutOfRangeError:
    pass
# Source dataset
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

# Samples obtained after shuffle and batch
[[ 0.4236548  0.64589411]
 [ 0.60276338 0.54488318]
 [ 0.43758721 0.891773 ]
 [ 0.5488135  0.71518937]]
[[ 0.96366276 0.38344152]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.79172504 0.52889492]]
[[ 0.07103606 0.0871293 ]
 [ 0.97861834 0.79915856]
 [ 0.77815675 0.87001215]] # the last batch of the epoch has 3 samples
[[ 0.60276338 0.54488318]
 [ 0.5488135  0.71518937]
 [ 0.43758721 0.891773 ]
 [ 0.79172504 0.52889492]]
[[ 0.4236548  0.64589411]
 [ 0.56804456 0.92559664]
 [ 0.0202184  0.83261985]
 [ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215]
 [ 0.96366276 0.38344152]
 [ 0.97861834 0.79915856]] # the last batch of the epoch has 3 samples

1. Following the buffer size passed to shuffle (3 here), the first three samples are read from the source dataset:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2. One sample is drawn at random from the buffer and placed into the batch:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3. The shuffle buffer now holds fewer than three samples, so one more sample is read from the source dataset:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4. Another sample is drawn at random from the buffer and placed into the batch:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
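5. This repeats until the source is exhausted. It follows that with a shuffle buffer size of 1 the dataset is not shuffled at all, while a buffer size equal to (or larger than) the number of samples shuffles the entire dataset uniformly.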
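The two extremes in step 5, and the cases in between, can be checked with a small pure-Python model of the buffer. This is a sketch under stated assumptions rather than TensorFlow's code: `shuffle_buffer` is a hypothetical helper and integers stand in for samples. The middle check relies on one observation: when output position p is emitted, only source items 0 through p + buffer_size - 1 have been read, so item i cannot appear before position i - (buffer_size - 1).

```python
import random


def shuffle_buffer(source, buffer_size, rng):
    """Simplified model of a shuffle buffer (not TensorFlow's actual code)."""
    buffer = []
    for item in source:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]      # emit a random buffered item
        buffer[index] = item     # refill the freed slot
    rng.shuffle(buffer)
    yield from buffer            # drain the remainder


n = 11
rng = random.Random(0)

# buffer_size=1: the only candidate is always the oldest item, so order is kept.
assert list(shuffle_buffer(range(n), 1, rng)) == list(range(n))

# buffer_size=3: item i never appears before output position i - 2.
for _ in range(1000):
    out = list(shuffle_buffer(range(n), 3, rng))
    assert all(pos >= item - 2 for pos, item in enumerate(out))

# buffer_size=n: any item can come first.
firsts = {next(shuffle_buffer(range(n), n, rng)) for _ in range(1000)}
print(sorted(firsts))
```

A small buffer therefore only shuffles locally: a sample near the end of the source can never reach the first batch unless the buffer is large enough to hold it that early.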

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
print(x)
print()
# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)  # buffer size 1: the order is left unchanged
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
  # pull batches until the dataset is exhausted
  try:
    while True:
      print(sess.run(el))
  except tf.errors.OutOfRangeError:
    pass

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]
[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]]
[[ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

Note what happens when repeat is applied before shuffle:

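According to the official documentation, applying repeat before shuffle gives better performance, but it blurs the epoch boundaries: samples from different epochs can end up mixed together.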
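The difference between the two orderings can be illustrated with a pure-Python model. This is a sketch, not TensorFlow code: `shuffle_buffer`, `shuffle_then_repeat`, and `repeat_then_shuffle` are hypothetical helpers, integers stand in for samples, and the model assumes the data is reshuffled on every pass.

```python
import random
from collections import Counter


def shuffle_buffer(source, buffer_size, rng):
    """Simplified model of a shuffle buffer (not TensorFlow's actual code)."""
    buffer = []
    for item in source:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    rng.shuffle(buffer)
    yield from buffer


def shuffle_then_repeat(data, buffer_size, epochs, rng):
    """Shuffle each pass over the data separately, then chain the passes."""
    out = []
    for _ in range(epochs):
        out.extend(shuffle_buffer(data, buffer_size, rng))
    return out


def repeat_then_shuffle(data, buffer_size, epochs, rng):
    """Concatenate the passes first, then shuffle the combined stream."""
    return list(shuffle_buffer(data * epochs, buffer_size, rng))


data = list(range(11))
rng = random.Random(0)
a = shuffle_then_repeat(data, 11, 2, rng)
b = repeat_then_shuffle(data, 11, 2, rng)

# shuffle -> repeat: every block of 11 outputs is one full pass over the data.
print(sorted(a[:11]) == data, sorted(a[11:]) == data)  # True True
# repeat -> shuffle: each sample still appears twice overall ...
print(Counter(b) == Counter(data * 2))  # True
# ... but the first 11 outputs need not be 11 distinct samples.
print(len(set(b[:11])))
```

In the second ordering the buffer refills from the second pass while samples from the first pass are still waiting in it, which is how the same sample can show up twice before others have been seen once.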

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
print(x)
print()
# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)  # repeat first, then shuffle
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
  # pull batches until the dataset is exhausted
  try:
    while True:
      print(sess.run(el))
  except tf.errors.OutOfRangeError:
    pass

[[ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]
 [ 0.43758721 0.891773 ]
 [ 0.96366276 0.38344152]
 [ 0.79172504 0.52889492]
 [ 0.56804456 0.92559664]
 [ 0.07103606 0.0871293 ]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.97861834 0.79915856]]

[[ 0.56804456 0.92559664]
 [ 0.5488135  0.71518937]
 [ 0.60276338 0.54488318]
 [ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152]
 [ 0.43758721 0.891773 ]
 [ 0.43758721 0.891773 ]
 [ 0.77815675 0.87001215]]
[[ 0.79172504 0.52889492]  # the same sample appears twice in one batch
 [ 0.79172504 0.52889492]
 [ 0.60276338 0.54488318]
 [ 0.4236548  0.64589411]]
[[ 0.07103606 0.0871293 ]
 [ 0.4236548  0.64589411]
 [ 0.96366276 0.38344152]
 [ 0.5488135  0.71518937]]
[[ 0.97861834 0.79915856]
 [ 0.0202184  0.83261985]
 [ 0.77815675 0.87001215]
 [ 0.56804456 0.92559664]]
[[ 0.0202184  0.83261985]
 [ 0.97861834 0.79915856]]     # the last batch has 2 samples, while all earlier ones have 4
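The batch sizes seen in the outputs above follow from simple arithmetic on where repeat sits relative to batch. The sketch below is plain Python with a hypothetical `batch_sizes` helper; it only counts samples and does not model shuffling.

```python
def batch_sizes(num_samples, batch_size):
    """Sizes of the batches produced from num_samples consecutive samples."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


num_samples, batch_size, epochs = 11, 4, 2

# batch -> repeat: each epoch is batched on its own, then replayed.
batch_then_repeat = batch_sizes(num_samples, batch_size) * epochs
# repeat -> batch: the epochs are concatenated first, then batched as one stream.
repeat_then_batch = batch_sizes(num_samples * epochs, batch_size)

print(batch_then_repeat)  # [4, 4, 3, 4, 4, 3]
print(repeat_then_batch)  # [4, 4, 4, 4, 4, 2]
```

With repeat before batch there is only one short batch at the very end, but at least one batch spans the boundary between two epochs.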

Usage example:

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
  print('Parsing', filenames)
  def decode_libsvm(line):
    #columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
    #features = dict(zip(CSV_COLUMNS, columns))
    #labels = features.pop(LABEL_COLUMN)
    columns = tf.string_split([line], ' ')
    labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
    splits = tf.string_split(columns.values[1:], ':')
    id_vals = tf.reshape(splits.values,splits.dense_shape)
    feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)
    feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
    feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
    #feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
    #for i in range(splits.dense_shape.eval()[0]):
    #  feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
    #  feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
    #return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

  # Extract lines from input files using the Dataset API, can pass one filename or filename list
  dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000)  # multi-thread pre-process then prefetch

  # Randomizes input using a window of 256 elements (read into memory)
  if perform_shuffle:
    dataset = dataset.shuffle(buffer_size=256)

  # Call repeat after shuffle, rather than before, to keep separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size) # Batch size to use

  #return dataset.make_one_shot_iterator()
  iterator = dataset.make_one_shot_iterator()
  batch_features, batch_labels = iterator.get_next()
  #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
  return batch_features, batch_labels
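To show what decode_libsvm extracts from each text line, here is a pure-Python sketch of the same parsing. It is an illustration only: `parse_libsvm_line` is a hypothetical helper, the sample line is made up, and it returns flat Python lists, whereas the TensorFlow version above returns tensors with one row per feature.

```python
def parse_libsvm_line(line):
    """Split 'label id:val id:val ...' into feature ids, feature values, and a label."""
    fields = line.split(' ')
    label = float(fields[0])  # the first field is the label
    feat_ids, feat_vals = [], []
    for pair in fields[1:]:
        feat_id, feat_val = pair.split(':')  # each remaining field is id:value
        feat_ids.append(int(feat_id))
        feat_vals.append(float(feat_val))
    return {"feat_ids": feat_ids, "feat_vals": feat_vals}, label


# Made-up example line in libsvm format.
features, label = parse_libsvm_line("1 3:0.5 7:1.0 12:0.25")
print(features)  # {'feat_ids': [3, 7, 12], 'feat_vals': [0.5, 1.0, 0.25]}
print(label)     # 1.0
```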
