基于tensorflow for循环 while循环案例


Posted in Python onJune 30, 2020

我就废话不多说了,大家还是直接看代码吧~

import tensorflow as tf

n1 = tf.constant(2)
n2 = tf.constant(3)
n3 = tf.constant(4)

def cond1(i, a, b):
 return i < n1

def cond2(i, a, b):
 return i < n2

def cond3(i, a, b):
 return i < n3

def body(i, a, b):
 return i + 1, b, a + b

i1, a1, b1 = tf.while_loop(cond1, body, (2, 1, 1))
i2, a2, b2 = tf.while_loop(cond2, body, (2, 1, 1))
i3, a3, b3 = tf.while_loop(cond3, body, (2, 1, 1))
sess = tf.Session()

print(sess.run(i1))
print(sess.run(a1))
print(sess.run(b1))
print("-")
print(sess.run(i2))
print(sess.run(a2))
print(sess.run(b2))
print("-")
print(sess.run(i3))
print(sess.run(a3))
print(sess.run(b3))

print结果:

2
1
1
-
3
1
2
-
4
2
3

可见body函数返回的三个变量又传给了body

补充知识:tensorflow在tf.while_loop循环(非一般循环)中使用操纵变量该怎么做

代码(操纵全局变量)

xiaojie=1
i=tf.constant(0,dtype=tf.int32)
batch_len=tf.constant(10,dtype=tf.int32)
loop_cond = lambda a,b: tf.less(a,batch_len)
#yy=tf.Print(batch_len,[batch_len],"batch_len:")
yy=tf.constant(0)
loop_vars=[i,yy]
def _recurrence(i,yy):
 c=tf.constant(2,dtype=tf.int32)
 x=tf.multiply(i,c)
 global xiaojie
 xiaojie=xiaojie+1
 print_info=tf.Print(x,[x],"x:")
 yy=yy+print_info
 i=tf.add(i,1)
# print (xiaojie)
 return i,yy
i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理
sess = tf.Session()
print (sess.run(i))
print (xiaojie)

输出的是10和2。

也就是xiaojie只被修改了一次。

这个时候,在_recurrence循环体中添加语句

print (xiaojie)

会输出2。而且只输出一次。具体为什么,最后总结的时候再解释。

代码(操纵类成员变量)class RNN_Model():

def __init__(self):
  self.xiaojie=1
 def test_RNN(self):
  i=tf.constant(0,dtype=tf.int32)
  batch_len=tf.constant(10,dtype=tf.int32)
  loop_cond = lambda a,b: tf.less(a,batch_len)
  #yy=tf.Print(batch_len,[batch_len],"batch_len:")
  yy=tf.constant(0)
  loop_vars=[i,yy]
  def _recurrence(i,yy):
   c=tf.constant(2,dtype=tf.int32)
   x=tf.multiply(i,c)
   self.xiaojie=self.xiaojie+1
   print_info=tf.Print(x,[x],"x:")
   yy=yy+print_info
   i=tf.add(i,1)

  print ("_recurrence:",self.xiaojie)
   return i,yy
  i,yy=tf.while_loop(loop_cond,_recurrence,loop_vars,parallel_iterations=1)#可以批处理
  sess = tf.Session()
  sess.run(yy)
  print (self.xiaojie)
if __name__ == "__main__":
 model = RNN_Model()#构建树,并且构建词典
 model.test_RNN()

输出是:

_recurrence: 2
10
2

tf.while_loop操纵全局变量和类成员变量总结

为什么_recurrence中定义的print操作只执行一次呢,这是因为_recurrence中的print相当于一种对代码的定义,直接在定义的过程中就执行了。所以,可以看到输出是在sess.run之前的。但是,定义的其它操作就是数据流图中的操作,需要在sess.run中执行。

就必须在sess.run中执行。但是,全局变量xiaojie也好,还是类成员变量xiaojie也好。其都不是图中的内容。因此,tf.while_loop执行的是tensorflow计算图中的循环,对于不是在计算图中的,就不会参与循环。注意:而且必须是与loop_vars中指定的变量存在数据依赖关系的tensor才可以!此外,即使是依赖关系,也必须是_recurrence循环体中return出的变量,才会真正的变化。比如,见下面的self.L。总之,想操纵变量,就要传入loop_vars!

如果对一个变量没有修改,就可以直接在循环中以操纵类成员变量或者全局变量的方式只读。

self.L与loop_vars中变量有依赖关系,但是并没有真正被修改。

#IIII通过计算将非叶子节点的词向量也放入nodes_tensor中。
   iiii=tf.constant(0,dtype=tf.int32)
   loop____cond = lambda a,b,c,d,e: tf.less(a,self.sentence_length-1)#iiii的范围是0到sl-2。注意,不包括sl-1。这是因为只需要计算sentence_length-1次,就能构建出一颗树
   loop____vars=[iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint]
   def ____recurrence(iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint):#循环的目的是实现Greedy算法
    ###
    #Greedy的主要目标就是确立树结构。    
    ###  
    c1 = self.L[:,0:columnLinesOfL-1]#这段代码是从RvNN的matlab的源码中复制过来的,但是Matlab的下标是从1开始,并且Matlab中1:2就是1和2,而python中1:2表示的是1,不包括2,所以,有很大的不同。
    c2 = self.L[:,1:columnLinesOfL]
    c=tf.concat([c1,c2],axis=0)
    p=tf.tanh(tf.matmul(self.W1,c)+tf.tile(self.b1,[1,columnLinesOfL-1]))
    p_normalization=self.normalization(p)
    y=tf.tanh(tf.matmul(self.U,p_normalization)+tf.tile(self.bs,[1,columnLinesOfL-1]))#根据Matlab中的源码来的,即重构后,也有一个激活的过程。
    #将Y矩阵拆分成上下部分之后,再分别进行标准化。
    columnlines_y=columnLinesOfL-1
    (y1,y2)=self.split_by_row(y,columnlines_y)
    y1_normalization=self.normalization(y1)
    y2_normalization=self.normalization(y2)
    #论文中提出一种计算重构误差时要考虑的权重信息。具体见论文,这里暂时不实现。
    #这个权重是可以修改的。
    alpha_cat=1 
    bcat=1
    #计算重构误差矩阵
##    constant1=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0],[7.0,8.0,9.0]])
##    constant2=tf.constant([[1.0,2.0,3.0],[1.0,4.0,2.0],[1.0,6.0,1.0]])
##    constructionErrorMatrix=self.constructionError(constant1,constant2,alpha_cat,bcat)
    y1c1=tf.subtract(y1_normalization,c1)
    y2c2=tf.subtract(y2_normalization,c2)    
    constructionErrorMatrix=self.constructionError(y1c1,y2c2,alpha_cat,bcat)
################################################################################
    print_info=tf.Print(iiii,[iiii],"\niiii:")#专门为了调试用,输出相关信息。
    tfPrint=print_info+tfPrint
    print_info=tf.Print(columnLinesOfL,[columnLinesOfL],"\nbefore modify. columnLinesOfL:")#专门为了调试用,输出相关信息。
    tfPrint=print_info+tfPrint
    print_info=tf.Print(constructionErrorMatrix,[constructionErrorMatrix],"\nbefore modify. constructionErrorMatrix:",summarize=100)#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0])+tfPrint#一种不断输出tf.Print的方式,注意tf.Print的返回值。
################################################################################
    J_minpos=tf.to_int32(tf.argmin(constructionErrorMatrix))#如果不转换的话,下面调用delete_one_column中,会调用tf.slice,之后tf.slice的参数中的类型必须是一样的。
    J_min=constructionErrorMatrix[J_minpos]
    #一共要进行sl-1次循环。因为是从sl个叶子节点,两两结合sl-1次,才能形成一颗完整的树,而且是采用Greedy的方式。
    #所以,需要为下次循环做准备。
    #第一步,从该sentence的词向量矩阵中删除第J_minpos+1列,因为第J_minpos和第J_minpos+1列对应的单词要合并为一个新的节点,这里就是修改L
################################################################################
    print_info=tf.Print(self.L,[self.L[0]],"\nbefore modify. L row 0:",summarize=100)#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0][0])+tfPrint
    print_info=tf.Print(self.L,[tf.shape(self.L)],"\nbefore modify. L shape:")#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0][0])+tfPrint
################################################################################
    deleteColumnIndex=J_minpos+1
    self.L=self.delete_one_column(self.L,deleteColumnIndex,self.numlinesOfL,columnLinesOfL)
    columnLinesOfL=tf.subtract(columnLinesOfL,1) #列数减去1.
################################################################################
    print_info=tf.Print(deleteColumnIndex,[deleteColumnIndex],"\nbefore modify. deleteColumnIndex:")#专门为了调试用,输出相关信息。
    tfPrint=print_info+tfPrint
    print_info=tf.Print(self.L,[self.L[0]],"\nafter modify. L row 0:",summarize=100)#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0][0])+tfPrint
    
    print_info=tf.Print(self.L,[tf.shape(self.L)],"\nafter modify. L shape:")#专门为了调试用,输出相关信息。
    tfPrint=tf.to_int32(print_info[0][0])+tfPrint
    print_info=tf.Print(columnLinesOfL,[columnLinesOfL],"\nafter modify. columnLinesOfL:")#专门为了调试用,输出相关信息。
    tfPrint=print_info+tfPrint
################################################################################
    
    #第二步,将新的词向量赋值给第J_minpos列
    columnTensor=p_normalization[:,J_minpos]
    new_column_tensor=tf.expand_dims(columnTensor,1)
    self.L=self.modify_one_column(self.L,new_column_tensor,J_minpos,self.numlinesOfL,columnLinesOfL)
    #第三步,同时将新的非叶子节点的词向量存入nodes_tensor
    modified_index_tensor=tf.to_int32(tf.add(iiii,self.sentence_length))
    nodes_tensor=self.modify_one_column(nodes_tensor,new_column_tensor,modified_index_tensor,self.numlines_tensor,self.numcolunms_tensor)
    #第四步:记录合并节点的最小损失,存入node_tensors_cost_tensor
    J_min_tensor=tf.expand_dims(tf.expand_dims(J_min,0),1)
    node_tensors_cost_tensor=self.modify_one_column(node_tensors_cost_tensor,J_min_tensor,iiii,self.numlines_tensor2,self.numcolunms_tensor2)
    ####进入下一次循环
    iiii=tf.add(iiii,1)
    print_info=tf.Print(J_minpos,[J_minpos,J_minpos+1],"node:")#专门为了调试用,输出相关信息。
    tfPrint=tfPrint+print_info
#    columnLinesOfL=tf.subtract(columnLinesOfL,1) #在上面的循环体中已经执行了,没有必要再执行。
    return iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint
   iiii,columnLinesOfL,node_tensors_cost_tensor,nodes_tensor,tfPrint=tf.while_loop(loop____cond,____recurrence,loop____vars,parallel_iterations=1)
   pass

上述代码是Greedy算法,递归构建神经网络树结构。

但是程序出错了,后来不断的调试,才发现self.L虽然跟循环loop____vars中的变量有依赖关系,也就是在tf.while_loop进行循环的时候,也可以输出它的值。

但是,它每一次都无法真正意义上对self.L进行修改。会发现,每一次循环结束之后,进入下一次循环时,self.L仍然没有变化。

执行结果如下:

before modify. columnLinesOfL:[31]
iiii:[0]

after modify. columnLinesOfL:[30]

before modify. L shape:[300 31]

before modify. L row 0:[0.126693 -0.013654 -0.166731 -0.13703 -0.261395 0.11459 0.016001 0.016001 0.144603 0.05588 0.171787 0.016001 1.064545 0.144603 0.130615 -0.13703 -0.261395 1.064545 -0.261395 0.144603 0.036626 1.064545 0.188871 0.201198 0.05588 0.203795 0.201198 0.03536 0.089345 0.083778 0.103635]
node:[0][1]

before modify. constructionErrorMatrix:[3.0431733686706206 11.391056715427794 19.652819956115856 13.713453313903868 11.625973829805879 12.827533320819564 9.7513513723204746 13.009151292890811 13.896089243289065 10.649829109971648 9.45239374745086 15.704486086921641 18.274065790781862 12.447866299915024 15.302996103637689 13.713453313903868 14.295549844738751 13.779406175789358 11.625212314259059 16.340507223201449 19.095964364689717 15.10149194936319 11.989443162329437 13.436654650354058 11.120373311110505 12.39345317975002 13.568052800712424 10.998430341124633 8.3223909323599869 6.8896857405641851]

after modify. L shape:[300 30]

after modify. L row 0:[0.126693 -0.166731 -0.13703 -0.261395 0.11459 0.016001 0.016001 0.144603 0.05588 0.171787 0.016001 1.064545 0.144603 0.130615 -0.13703 -0.261395 1.064545 -0.261395 0.144603 0.036626 1.064545 0.188871 0.201198 0.05588 0.203795 0.201198 0.03536 0.089345 0.083778 0.103635]

before modify. deleteColumnIndex:[1]

before modify. columnLinesOfL:[30]

iiii:[1]

before modify. L shape:[300 31]

after modify. columnLinesOfL:[29]

before modify. L row 0:[0.126693 -0.013654 -0.166731 -0.13703 -0.261395 0.11459 0.016001 0.016001 0.144603 0.05588 0.171787 0.016001 1.064545 0.144603 0.130615 -0.13703 -0.261395 1.064545 -0.261395 0.144603 0.036626 1.064545 0.188871 0.201198 0.05588 0.203795 0.201198 0.03536 0.089345 0.083778 0.103635]

before modify. deleteColumnIndex:[1]
node:[0][1]

before modify. constructionErrorMatrix:[3.0431733686706206 11.391056715427794 19.652819956115856 13.713453313903868 11.625973829805879 12.827533320819564 9.7513513723204746 13.009151292890811 13.896089243289065 10.649829109971648 9.45239374745086 15.704486086921641 18.274065790781862 12.447866299915024 15.302996103637689 13.713453313903868 14.295549844738751 13.779406175789358 11.625212314259059 16.340507223201449 19.095964364689717 15.10149194936319 11.989443162329437 13.436654650354058 11.120373311110505 12.39345317975002 13.568052800712424 10.998430341124633 8.3223909323599869]

after modify. L shape:[300 29]

after modify. L row 0:[0.126693 -0.166731 -0.13703 -0.261395 0.11459 0.016001 0.016001 0.144603 0.05588 0.171787 0.016001 1.064545 0.144603 0.130615 -0.13703 -0.261395 1.064545 -0.261395 0.144603 0.036626 1.064545 0.188871 0.201198 0.05588 0.203795 0.201198 0.03536 0.089345 0.083778]

before modify. columnLinesOfL:[29]

iiii:[2]

后面那个after modify时L shape为[300 29]的原因是:执行

self.L=self.modify_one_column(self.L,new_column_tensor,J_minpos,self.numlinesOfL,columnLinesOfL)

时,columnLinesOfL是循环loop____vars中的变量,因此会随着每次循环发生变化,我写的modify_one_column见我的博文“修改tensor张量矩阵的某一列”。它决定了

修改后tensor的维度。

但是,无论如何,每一次循环,都是

before modify. L shape:[300 31]

说明self.L在循环体中虽然被修改了。但是下次循环又会被重置为初始值。

以上这篇基于tensorflow for循环 while循环案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3安装Scrapy的方法步骤
Nov 23 Python
Python自动化运维之IP地址处理模块详解
Dec 10 Python
Pyinstaller将py打包成exe的实例
Mar 31 Python
Win10下python 2.7.13 安装配置方法图文教程
Sep 18 Python
python绘制热力图heatmap
Mar 23 Python
Python 给屏幕打印信息加上颜色的实现方法
Apr 24 Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 Python
初次部署django+gunicorn+nginx的方法步骤
Sep 11 Python
Python socket聊天脚本代码实例
Jan 02 Python
解决pymysql cursor.fetchall() 获取不到数据的问题
May 15 Python
matlab、python中矩阵的互相导入导出方式
Jun 01 Python
python通过opencv调用摄像头操作实例分析
Jun 07 Python
解析Tensorflow之MNIST的使用
Jun 30 #Python
Tensorflow tensor 数学运算和逻辑运算方式
Jun 30 #Python
Python requests模块安装及使用教程图解
Jun 30 #Python
在Tensorflow中实现leakyRelu操作详解(高效)
Jun 30 #Python
TensorFlow-gpu和opencv安装详细教程
Jun 30 #Python
tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
Jun 30 #Python
python 最简单的实现适配器设计模式的示例
Jun 30 #Python
You might like
PHP中static关键字原理的学习研究分析
2011/07/18 PHP
php 文件上传类代码
2011/08/06 PHP
php empty() 检查一个变量是否为空
2011/11/10 PHP
深入分析PHP优化及注意事项
2016/07/04 PHP
PHP Header用于页面跳转时的几个注意事项
2016/10/21 PHP
PHP读取文件的常见几种方法
2016/11/03 PHP
JavaScript中void(0)的具体含义解释
2007/02/27 Javascript
dojo 之基础篇(三)之向服务器发送数据
2007/03/24 Javascript
Ext.FormPanel 提交和 Ext.Ajax.request 异步提交函数的区别
2009/11/12 Javascript
node.js中的fs.writeFile方法使用说明
2014/12/14 Javascript
javascript 使用正则test( )第一次是 true,第二次是false
2017/02/22 Javascript
node.js入门学习之url模块
2017/02/25 Javascript
js实现日期显示的一些操作(实例讲解)
2017/07/27 Javascript
vue2实现可复用的轮播图carousel组件详解
2017/11/27 Javascript
简单谈谈CommonsChunkPlugin抽取公共模块
2017/12/31 Javascript
微信小程序实现横向增长表格的方法
2018/07/24 Javascript
vue-swiper的使用教程
2018/08/30 Javascript
详谈js的变量提升以及使用方法
2018/10/06 Javascript
vue 表单之通过v-model绑定单选按钮radio
2019/05/13 Javascript
[04:10]2016国际邀请赛中国区预选赛第二日TOP10精彩集锦
2016/06/28 DOTA
Python sys.argv用法实例
2015/05/28 Python
python DataFrame获取行数、列数、索引及第几行第几列的值方法
2018/04/08 Python
python使用rpc框架gRPC的方法
2018/08/24 Python
Python创建字典的八种方式
2019/02/27 Python
使用Python制作一个打字训练小工具
2019/10/01 Python
Django实现文件上传和下载功能
2019/10/06 Python
用css3实现转换过渡和动画效果
2020/03/13 HTML / CSS
文科教师毕业的自我评价
2014/01/16 职场文书
2014年小学植树节活动方案
2014/03/02 职场文书
2014基层党员干部学习全国两会心得体会
2014/03/17 职场文书
干部鉴定材料
2014/05/18 职场文书
汽车运用工程专业求职信
2014/06/18 职场文书
园林专业毕业生自荐信
2014/07/04 职场文书
JavaScript+HTML实现学生信息管理系统
2021/04/20 Javascript
nginx请求限制配置方法
2021/07/09 Servers
Rust 连接 PostgreSQL 数据库的详细过程
2022/01/22 PostgreSQL