基于tensorflow for循环 while循环案例


Posted in Python onJune 30, 2020

我就废话不多说了,大家还是直接看代码吧~

import tensorflow as tf

n1 = tf.constant(2)
n2 = tf.constant(3)
n3 = tf.constant(4)

def cond1(i, a, b):
 return i < n1

def cond2(i, a, b):
 return i < n2

def cond3(i, a, b):
 return i < n3

def body(i, a, b):
 return i + 1, b, a + b

i1, a1, b1 = tf.while_loop(cond1, body, (2, 1, 1))
i2, a2, b2 = tf.while_loop(cond2, body, (2, 1, 1))
i3, a3, b3 = tf.while_loop(cond3, body, (2, 1, 1))
sess = tf.Session()

print(sess.run(i1))
print(sess.run(a1))
print(sess.run(b1))
print("-")
print(sess.run(i2))
print(sess.run(a2))
print(sess.run(b2))
print("-")
print(sess.run(i3))
print(sess.run(a3))
print(sess.run(b3))

print结果:

2
1
1
-
3
1
2
-
4
2
3

可见body函数返回的三个变量又传给了body

补充知识:tensorflow在tf.while_loop循环(非一般循环)中使用操纵变量该怎么做

代码(操纵全局变量)

xiaojie=1
i=tf.constant(0,dtype=tf.int32)
batch_len=tf.constant(10,dtype=tf.int32)
loop_cond = lambda a,b: tf.less(a,batch_len)
#yy=tf.Print(batch_len,[batch_len],"batch_len:")
yy=tf.constant(0)
loop_vars=[i,yy]
def _recurrence(i,yy):
 c=tf.constant(2,dtype=tf.int32)
 x=tf.multiply(i,c)
 global xiaojie
 xiaojie=xiaojie+1
 print_info=tf.Print(x,[x],"x:")
 yy=yy+print_info
 i=tf.add(i,1)
# print (xiaojie)
 return i,yy
i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理
sess = tf.Session()
print (sess.run(i))
print (xiaojie)

输出的是10和2。

也就是xiaojie只被修改了一次。

这个时候,在_recurrence循环体中添加语句

print (xiaojie)

会输出2。而且只输出一次。具体为什么,最后总结的时候再解释。

代码(操纵类成员变量)class RNN_Model():

def __init__(self):
  self.xiaojie=1
 def test_RNN(self):
  i=tf.constant(0,dtype=tf.int32)
  batch_len=tf.constant(10,dtype=tf.int32)
  loop_cond = lambda a,b: tf.less(a,batch_len)
  #yy=tf.Print(batch_len,[batch_len],"batch_len:")
  yy=tf.constant(0)
  loop_vars=[i,yy]
  def _recurrence(i,yy):
   c=tf.constant(2,dtype=tf.int32)
   x=tf.multiply(i,c)
   self.xiaojie=self.xiaojie+1
   print_info=tf.Print(x,[x],"x:")
   yy=yy+print_info
   i=tf.add(i,1)

  print ("_recurrence:",self.xiaojie)
   return i,yy
  i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理
  sess = tf.Session()
  sess.run(yy)
  print (self.xiaojie)
if __name__ == "__main__":
 model = RNN_Model()#构建树,并且构建词典
 model.test_RNN()

输出是:

_recurrence: 2
10
2

tf.while_loop操纵全局变量和类成员变量总结

为什么_recurrence中定义的print操作只执行一次呢,这是因为_recurrence中的print相当于一种对代码的定义,直接在定义的过程中就执行了。所以,可以看到输出是在sess.run之前的。但是,定义的其它操作就是数据流图中的操作,需要在sess.run中执行。

就必须在sess.run中执行。但是,全局变量xiaojie也好,还是类成员变量xiaojie也好。其都不是图中的内容。因此,tf.while_loop执行的是tensorflow计算图中的循环,对于不是在计算图中的,就不会参与循环。注意:而且必须是与loop_vars中指定的变量存在数据依赖关系的tensor才可以!此外,即使是依赖关系,也必须是_recurrence循环体中return出的变量,才会真正的变化。比如,见下面的self.L。总之,想操纵变量,就要传入loop_vars!

如果对一个变量没有修改,就可以直接在循环中以操纵类成员变量或者全局变量的方式只读。

self.L与loop_vars中变量有依赖关系,但是并没有真正被修改。

#IIII通过计算将非叶子节点的词向量也放入nodes_tensor中。
   iiii=tf.constant(0,dtype=tf.int32)
   loop____cond = lambda a,b,c,d,e: tf.less(a,self.sentence_length-1)#iiii的范围是0到sl-2。注意,不包括sl-1。这是因为只需要计算sentence_length-1次,就能构建出一颗树
   loop____vars=[iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint]
   def ____recurrence(iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint):#循环的目的是实现Greedy算法
    ###
    #Greedy的主要目标就是确立树结构。    
    ###  
    c1 = self.L[:,0:columnLinesOfL-1]#这段代码是从RvNN的matlab的源码中复制过来的,但是Matlab的下标是从1开始,并且Matlab中1:2就是1和2,而python中1:2表示的是1,不包括2,所以,有很大的不同。
    c2 = self.L[:,1:columnLinesOfL]
    c=tf.concat([c1,c2],axis=0)
    p=tf.tanh(tf.matmul(self.W1,c)+tf.tile(self.b1,[1,columnLinesOfL-1]))
    p_normalization=self.normalization(p)
    y=tf.tanh(tf.matmul(self.U,p_normalization)+tf.tile(self.bs,[1,columnLinesOfL-1]))#根据Matlab中的源码来的,即重构后,也有一个激活的过程。
    #将Y矩阵拆分成上下部分之后,再分别进行标准化。
    columnlines_y=columnLinesOfL-1
    (y1,y2)=self.split_by_row(y,columnlines_y)
    y1_normalization=self.normalization(y1)
    y2_normalization=self.normalization(y2)
    #论文中提出一种计算重构误差时要考虑的权重信息。具体见论文,这里暂时不实现。
    #这个权重是可以修改的。
    alpha_cat=1 
    bcat=1
    #计算重构误差矩阵
##    constant1=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0],[7.0,8.0,9.0]])
##    constant2=tf.constant([[1.0,2.0,3.0],[1.0,4.0,2.0],[1.0,6.0,1.0]])
##    constructionErrorMatrix=self.constructionError(constant1,constant2,alpha_cat,bcat)
    y1c1=tf.subtract(y1_normalization,c1)
    y2c2=tf.subtract(y2_normalization,c2)    
    constructionErrorMatrix=self.constructionError(y1c1,y2c2,alpha_cat,bcat)
################################################################################
    print_info=tf.Print(iiii,[iiii],"\niiii:")#专门为了调试用,输出相关信息。
    tfPrint=print_info+tfPrint
    print_info=tf.Print(columnLinesOfL,[columnLinesOfL],"\nbefore modify. columnLinesOfL:")#专门为了调试用,输出相关信息。
    tfPrint=print_info+tfPrint
    print_info=tf.Print(constructionErrorMatrix,[constructionErrorMatrix],"\nbefore modify. constructionErrorMatrix:",summarize=100)#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0])+tfPrint#一种不断输出tf.Print的方式,注意tf.Print的返回值。
################################################################################
    J_minpos=tf.to_int32(tf.argmin(constructionErrorMatrix))#如果不转换的话,下面调用delete_one_column中,会调用tf.slice,之后tf.slice的参数中的类型必须是一样的。
    J_min=constructionErrorMatrix[J_minpos]
    #一共要进行sl-1次循环。因为是从sl个叶子节点,两两结合sl-1次,才能形成一颗完整的树,而且是采用Greedy的方式。
    #所以,需要为下次循环做准备。
    #第一步,从该sentence的词向量矩阵中删除第J_minpos+1列,因为第J_minpos和第J_minpos+1列对应的单词要合并为一个新的节点,这里就是修改L
################################################################################
    print_info=tf.Print(self.L,[self.L[0]],"\nbefore modify. L row 0:",summarize=100)#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0][0])+tfPrint
    print_info=tf.Print(self.L,[tf.shape(self.L)],"\nbefore modify. L shape:")#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0][0])+tfPrint
################################################################################
    deleteColumnIndex=J_minpos+1
    self.L=self.delete_one_column(self.L,deleteColumnIndex,self.numlinesOfL,columnLinesOfL)
    columnLinesOfL=tf.subtract(columnLinesOfL,1) #列数减去1.
################################################################################
    print_info=tf.Print(deleteColumnIndex,[deleteColumnIndex],"\nbefore modify. deleteColumnIndex:")#专门为了调试用,输出相关信息。
    tfPrint=print_info+tfPrint
    print_info=tf.Print(self.L,[self.L[0]],"\nafter modify. L row 0:",summarize=100)#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0][0])+tfPrint
    
    print_info=tf.Print(self.L,[tf.shape(self.L)],"\nafter modify. L shape:")#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0][0])+tfPrint
    print_info=tf.Print(columnLinesOfL,[columnLinesOfL],"\nafter modify. columnLinesOfL:")#专门为了调试用,输出相关信息。
    tfPrint=print_info+tfPrint
################################################################################
    
    #第二步,将新的词向量赋值给第J_minpos列
    columnTensor=p_normalization[:,J_minpos]
    new_column_tensor=tf.expand_dims(columnTensor,1)
    self.L=self.modify_one_column(self.L,new_column_tensor,J_minpos,self.numlinesOfL,columnLinesOfL)
    #第三步,同时将新的非叶子节点的词向量存入nodes_tensor
    modified_index_tensor=tf.to_int32(tf.add(iiii,self.sentence_length))
    nodes_tensor=self.modify_one_column(nodes_tensor,new_column_tensor,modified_index_tensor,self.numlines_tensor,self.numcolunms_tensor)
    #第四步:记录合并节点的最小损失,存入node_tensors_cost_tensor
    J_min_tensor=tf.expand_dims(tf.expand_dims(J_min,0),1)
    node_tensors_cost_tensor=self.modify_one_column(node_tensors_cost_tensor,J_min_tensor,iiii,self.numlines_tensor2,self.numcolunms_tensor2)
    ####进入下一次循环
    iiii=tf.add(iiii,1)
    print_info=tf.Print(J_minpos,[J_minpos,J_minpos+1],"node:")#专门为了调试用,输出相关信息。
    tfPrint=tfPrint+print_info
#    columnLinesOfL=tf.subtract(columnLinesOfL,1) #在上面的循环体中已经执行了,没有必要再执行。
    return iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint
   iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint=tf.while_loop(loop____cond,____recurrence,loop____vars,parallel_iterations=1)
   pass

上述代码是Greedy算法,递归构建神经网络树结构。

但是程序出错了,后来不断的调试,才发现self.L虽然跟循环loop____vars中的变量有依赖关系,也就是在tf.while_loop进行循环的时候,也可以输出它的值。

但是,它每一次都无法真正意义上对self.L进行修改。会发现,每一次循环结束之后,进入下一次循环时,self.L仍然没有变化。

执行结果如下:

before modify. columnLinesOfL:[31]
iiii:[0]

after modify. columnLinesOfL:[30]

before modify. L shape:[300 31]

before modify. L row 0:[0.126693 -0.013654 -0.166731 -0.13703 -0.261395 0.11459 0.016001 0.016001 0.144603 0.05588 0.171787 0.016001 1.064545 0.144603 0.130615 -0.13703 -0.261395 1.064545 -0.261395 0.144603 0.036626 1.064545 0.188871 0.201198 0.05588 0.203795 0.201198 0.03536 0.089345 0.083778 0.103635]
node:[0][1]

before modify. constructionErrorMatrix:[3.0431733686706206 11.391056715427794 19.652819956115856 13.713453313903868 11.625973829805879 12.827533320819564 9.7513513723204746 13.009151292890811 13.896089243289065 10.649829109971648 9.45239374745086 15.704486086921641 18.274065790781862 12.447866299915024 15.302996103637689 13.713453313903868 14.295549844738751 13.779406175789358 11.625212314259059 16.340507223201449 19.095964364689717 15.10149194936319 11.989443162329437 13.436654650354058 11.120373311110505 12.39345317975002 13.568052800712424 10.998430341124633 8.3223909323599869 6.8896857405641851]

after modify. L shape:[300 30]

after modify. L row 0:[0.126693 -0.166731 -0.13703 -0.261395 0.11459 0.016001 0.016001 0.144603 0.05588 0.171787 0.016001 1.064545 0.144603 0.130615 -0.13703 -0.261395 1.064545 -0.261395 0.144603 0.036626 1.064545 0.188871 0.201198 0.05588 0.203795 0.201198 0.03536 0.089345 0.083778 0.103635]

before modify. deleteColumnIndex:[1]

before modify. columnLinesOfL:[30]

iiii:[1]

before modify. L shape:[300 31]

after modify. columnLinesOfL:[29]

before modify. L row 0:[0.126693 -0.013654 -0.166731 -0.13703 -0.261395 0.11459 0.016001 0.016001 0.144603 0.05588 0.171787 0.016001 1.064545 0.144603 0.130615 -0.13703 -0.261395 1.064545 -0.261395 0.144603 0.036626 1.064545 0.188871 0.201198 0.05588 0.203795 0.201198 0.03536 0.089345 0.083778 0.103635]

before modify. deleteColumnIndex:[1]
node:[0][1]

before modify. constructionErrorMatrix:[3.0431733686706206 11.391056715427794 19.652819956115856 13.713453313903868 11.625973829805879 12.827533320819564 9.7513513723204746 13.009151292890811 13.896089243289065 10.649829109971648 9.45239374745086 15.704486086921641 18.274065790781862 12.447866299915024 15.302996103637689 13.713453313903868 14.295549844738751 13.779406175789358 11.625212314259059 16.340507223201449 19.095964364689717 15.10149194936319 11.989443162329437 13.436654650354058 11.120373311110505 12.39345317975002 13.568052800712424 10.998430341124633 8.3223909323599869]

after modify. L shape:[300 29]

after modify. L row 0:[0.126693 -0.166731 -0.13703 -0.261395 0.11459 0.016001 0.016001 0.144603 0.05588 0.171787 0.016001 1.064545 0.144603 0.130615 -0.13703 -0.261395 1.064545 -0.261395 0.144603 0.036626 1.064545 0.188871 0.201198 0.05588 0.203795 0.201198 0.03536 0.089345 0.083778]

before modify. columnLinesOfL:[29]

iiii:[2]

后面那个after modify时L shape为[300 29]的原因是:执行

self.L=self.modify_one_column(self.L,new_column_tensor,J_minpos,self.numlinesOfL,columnLinesOfL)

时,columnLinesOfL是循环loop____vars中的变量,因此会随着每次循环发生变化,我写的modify_one_column见我的博文“修改tensor张量矩阵的某一列”。它决定了

修改后tensor的维度。

但是,无论如何,每一次循环,都是

before modify. L shape:[300 31]

说明self.L在循环体中虽然被修改了。但是下次循环又会被重置为初始值。

以上这篇基于tensorflow for循环 while循环案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教你用Type Hint提高Python程序开发效率
Aug 08 Python
200行自定义python异步非阻塞Web框架
Mar 15 Python
python机器学习理论与实战(一)K近邻法
Jan 28 Python
Python实现使用卷积提取图片轮廓功能示例
May 12 Python
详谈Pandas中iloc和loc以及ix的区别
Jun 08 Python
django如何连接已存在数据的数据库
Aug 14 Python
解决python中遇到字典里key值为None的情况,取不出来的问题
Oct 17 Python
Python实现的旋转数组功能算法示例
Feb 23 Python
python实现AES加密和解密
Mar 27 Python
python 类的继承 实例方法.静态方法.类方法的代码解析
Aug 23 Python
Python 实现顺序高斯消元法示例
Dec 09 Python
python实现简单的购物程序代码实例
Mar 03 Python
解析Tensorflow之MNIST的使用
Jun 30 #Python
Tensorflow tensor 数学运算和逻辑运算方式
Jun 30 #Python
Python requests模块安装及使用教程图解
Jun 30 #Python
在Tensorflow中实现leakyRelu操作详解(高效)
Jun 30 #Python
TensorFlow-gpu和opencv安装详细教程
Jun 30 #Python
tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
Jun 30 #Python
python 最简单的实现适配器设计模式的示例
Jun 30 #Python
You might like
图书管理程序(三)
2006/10/09 PHP
PHP执行zip与rar解压缩方法实现代码
2010/12/05 PHP
深入php define()函数以及defined()函数的用法详解
2013/06/05 PHP
php定义一个参数带有默认值的函数实例分析
2015/03/16 PHP
PHP Echo字符串的连接格式
2016/03/07 PHP
Laravel 模型关联基础教程详解
2019/09/17 PHP
PHP下载文件函数与用法示例
2019/09/27 PHP
Thinkphp5.0 框架Model模型简单用法分析
2019/10/11 PHP
推荐10个超棒的jQuery工具提示插件
2011/10/11 Javascript
基于jQuery实现下拉收缩(展开与折叠)特效
2012/12/25 Javascript
js 单击式的下拉菜单效果实例
2013/08/13 Javascript
中止javascript执行的方法
2014/02/14 Javascript
为jquery的ajaxfileupload增加附加参数的方法
2014/03/04 Javascript
window.open不被拦截的简单实现代码(推荐)
2016/08/04 Javascript
jQuery无刷新上传之uploadify简单代码
2017/01/17 Javascript
js实现图片加载淡入淡出效果
2017/04/07 Javascript
浅谈vuejs实现数据驱动视图原理
2018/02/23 Javascript
用node撸一个监测复联4开售短信提醒的实现代码
2019/04/10 Javascript
利用Python如何生成随机密码
2016/04/20 Python
详解python中的文件与目录操作
2017/07/11 Python
解决pycharm无法识别本地site-packages的问题
2018/10/13 Python
在python中以相同顺序shuffle两个list的方法
2018/12/13 Python
python实现中文文本分句的例子
2019/07/15 Python
django做form表单的数据验证过程详解
2019/07/26 Python
Python处理mysql特殊字符的问题
2020/03/02 Python
html5自动播放mov格式视频的实例代码
2020/01/14 HTML / CSS
canvas绘制太极图的实现示例
2020/04/29 HTML / CSS
下面代码从性能上考虑,有什么问题
2015/04/03 面试题
如何掌握自荐信格式呢
2013/11/19 职场文书
区域总监的岗位职责
2013/11/21 职场文书
投标保密承诺书
2014/05/19 职场文书
幼儿园运动会口号
2014/06/07 职场文书
毕业生找工作自荐书
2014/06/30 职场文书
报案材料怎么写
2015/05/25 职场文书
三八节祝酒词
2015/08/11 职场文书
2016年寒假见闻
2015/10/10 职场文书