tensorflow中next_batch的具体使用


Posted in Python onFebruary 02, 2018

本文介绍了tensorflow中next_batch的具体使用,分享给大家,具体如下:

此处给出了几种不同的next_batch方法,该文章只是做出代码片段的解释,以备以后查看:

def next_batch(self, batch_size, fake_data=False):
  """Return the next `batch_size` examples from this data set."""
  if fake_data:
   fake_image = [1] * 784
   if self.one_hot:
    fake_label = [1] + [0] * 9
   else:
    fake_label = 0
   return [fake_image for _ in xrange(batch_size)], [
     fake_label for _ in xrange(batch_size)
   ]
  start = self._index_in_epoch
  self._index_in_epoch += batch_size
  if self._index_in_epoch > self._num_examples: # epoch中的句子下标是否大于所有语料的个数,如果为True,开始新一轮的遍历
   # Finished epoch
   self._epochs_completed += 1
   # Shuffle the data
   perm = numpy.arange(self._num_examples) # arange函数用于创建等差数组
   numpy.random.shuffle(perm) # 打乱
   self._images = self._images[perm]
   self._labels = self._labels[perm]
   # Start next epoch
   start = 0
   self._index_in_epoch = batch_size
   assert batch_size <= self._num_examples
  end = self._index_in_epoch
  return self._images[start:end], self._labels[start:end]

该段代码摘自mnist.py文件,从代码第12行start = self._index_in_epoch开始解释,_index_in_epoch-1是上一次batch个图片中最后一张图片的下边,这次epoch第一张图片的下标是从 _index_in_epoch开始,最后一张图片的下标是_index_in_epoch+batch, 如果 _index_in_epoch 大于语料中图片的个数,表示这个epoch是不合适的,就算是完成了语料的一遍的遍历,所以应该对图片洗牌然后开始新一轮的语料组成batch开始

def ptb_iterator(raw_data, batch_size, num_steps):
 """Iterate on the raw PTB data.

 This generates batch_size pointers into the raw PTB data, and allows
 minibatch iteration along these pointers.

 Args:
  raw_data: one of the raw data outputs from ptb_raw_data.
  batch_size: int, the batch size.
  num_steps: int, the number of unrolls.

 Yields:
  Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
  The second element of the tuple is the same data time-shifted to the
  right by one.

 Raises:
  ValueError: if batch_size or num_steps are too high.
 """
 raw_data = np.array(raw_data, dtype=np.int32)

 data_len = len(raw_data)
 batch_len = data_len // batch_size #有多少个batch
 data = np.zeros([batch_size, batch_len], dtype=np.int32) # batch_len 有多少个单词
 for i in range(batch_size): # batch_size 有多少个batch
  data[i] = raw_data[batch_len * i:batch_len * (i + 1)]

 epoch_size = (batch_len - 1) // num_steps # batch_len 是指一个batch中有多少个句子
 #epoch_size = ((len(data) // model.batch_size) - 1) // model.num_steps # // 表示整数除法
 if epoch_size == 0:
  raise ValueError("epoch_size == 0, decrease batch_size or num_steps")

 for i in range(epoch_size):
  x = data[:, i*num_steps:(i+1)*num_steps]
  y = data[:, i*num_steps+1:(i+1)*num_steps+1]
  yield (x, y)

第三种方式:

def next(self, batch_size):
    """ Return a batch of data. When dataset end is reached, start over.
    """
    if self.batch_id == len(self.data):
      self.batch_id = 0
    batch_data = (self.data[self.batch_id:min(self.batch_id +
                         batch_size, len(self.data))])
    batch_labels = (self.labels[self.batch_id:min(self.batch_id +
                         batch_size, len(self.data))])
    batch_seqlen = (self.seqlen[self.batch_id:min(self.batch_id +
                         batch_size, len(self.data))])
    self.batch_id = min(self.batch_id + batch_size, len(self.data))
    return batch_data, batch_labels, batch_seqlen

第四种方式:

def batch_iter(sourceData, batch_size, num_epochs, shuffle=True):
  data = np.array(sourceData) # 将sourceData转换为array存储
  data_size = len(sourceData)
  num_batches_per_epoch = int(len(sourceData) / batch_size) + 1
  for epoch in range(num_epochs):
    # Shuffle the data at each epoch
    if shuffle:
      shuffle_indices = np.random.permutation(np.arange(data_size))
      shuffled_data = sourceData[shuffle_indices]
    else:
      shuffled_data = sourceData

    for batch_num in range(num_batches_per_epoch):
      start_index = batch_num * batch_size
      end_index = min((batch_num + 1) * batch_size, data_size)

      yield shuffled_data[start_index:end_index]

迭代器的用法,具体学习Python迭代器的用法

另外需要注意的是,前三种方式只是所有语料遍历一次,而最后一种方法是,所有语料遍历了num_epochs次

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
朴素贝叶斯算法的python实现方法
Nov 18 Python
Python写的服务监控程序实例
Jan 31 Python
10款最好的Web开发的 Python 框架
Mar 18 Python
Python使用QQ邮箱发送Email的方法实例
Feb 09 Python
浅谈Python2.6和Python3.0中八进制数字表示的区别
Apr 28 Python
Python各类图像库的图片读写方式总结(推荐)
Feb 23 Python
使用Python对微信好友进行数据分析
Jun 27 Python
浅谈Python traceback的优雅处理
Aug 31 Python
python实现集中式的病毒扫描功能详解
Jul 09 Python
Python Websocket服务端通信的使用示例
Feb 25 Python
keras多显卡训练方式
Jun 10 Python
使用python求解迷宫问题的三种实现方法
Mar 17 Python
Python输出各行命令详解
Feb 01 #Python
Python输出由1,2,3,4组成的互不相同且无重复的三位数
Feb 01 #Python
Python实现的视频播放器功能完整示例
Feb 01 #Python
Python线性回归实战分析
Feb 01 #Python
Python使用matplotlib简单绘图示例
Feb 01 #Python
Python解决抛小球问题 求小球下落经历的距离之和示例
Feb 01 #Python
Python 判断 有向图 是否有环的实例讲解
Feb 01 #Python
You might like
暴雪前总裁遗憾:没尽早追赶Dota 取消星际争霸幽灵
2020/03/08 星际争霸
咖啡的植物学知识
2021/03/03 咖啡文化
php加密解密函数authcode的用法详细解析
2013/10/28 PHP
php支付宝手机网页支付类实例
2015/03/04 PHP
深入理解php printf() 输出格式化的字符串
2016/05/23 PHP
Jquery显示、隐藏元素以及添加删除样式
2013/08/09 Javascript
ECMAScript 5严格模式(Strict Mode)介绍
2015/03/02 Javascript
JavaScript中指定函数名称的相关方法
2015/06/04 Javascript
JavaScript中pop()方法的使用教程
2015/06/09 Javascript
javaScript实现可缩放的显示区效果代码
2015/10/26 Javascript
flag和jq on 的绑定多个对象和方法(必看)
2017/02/27 Javascript
JS实现的简单拖拽功能示例
2017/03/13 Javascript
简单谈谈原生js的math对象
2017/06/27 Javascript
JQuery用$.ajax或$.getJSON跨域获取JSON数据的实现代码
2017/09/23 jQuery
JS简单实现滑动加载数据的方法示例
2017/10/18 Javascript
自定义PC微信扫码登录样式写法
2017/12/12 Javascript
详解vue 数组和对象渲染问题
2018/09/21 Javascript
Vue 请求传公共参数的操作
2020/07/31 Javascript
[52:05]EG vs OG 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/18 DOTA
使用Python的Twisted框架构建非阻塞下载程序的实例教程
2016/05/25 Python
python3新特性函数注释Function Annotations用法分析
2016/07/28 Python
Python单例模式实例详解
2017/03/01 Python
matplotlib.pyplot绘图显示控制方法
2019/01/15 Python
解决ROC曲线画出来只有一个点的问题
2020/02/28 Python
详解pandas赋值失败问题解决
2020/11/29 Python
联想美国官方商城:Lenovo美国
2017/06/19 全球购物
德国孕妇装和婴童服装网上商店:bellybutton
2018/04/12 全球购物
PHP笔试题
2012/02/22 面试题
可口可乐广告词
2014/03/20 职场文书
无房证明范本
2014/09/17 职场文书
高中生第一学年自我鉴定2015
2014/09/28 职场文书
暑期实践个人总结
2015/03/06 职场文书
总账会计岗位职责
2015/04/02 职场文书
使用canvas实现雪花飘动效果的示例代码
2021/03/30 HTML / CSS
Nginx域名转发使用场景代码实例
2021/03/31 Servers
详解MySQL中的pid与socket
2021/06/15 MySQL