解决TensorFlow调用Keras库函数存在的问题


Posted in Python onJuly 06, 2020

tensorflow在1.4版本引入了keras,封装成库。现想将keras版本的GRU代码移植到TensorFlow中,看到TensorFlow中有Keras库,大喜,故将神经网络定义部分使用Keras的Function API方式进行定义,训练部分则使用TensorFlow来进行编写。一顿操作之后,运行,没有报错,不由得一喜。但是输出结果,发现,和预期的不一样。难道是欠拟合?故采用正弦波预测余弦来验证算法模型。

部分调用keras库代码如上图所示,用正弦波预测余弦波,出现如下现象:

def interface(_input):
  tmp = tf.keras.layers.Dense(10)(_input)
  vad_gru = tf.keras.layers.GRU(24, return_sequences=True)(tmp)
  denoise_output = tf.keras.layers.Dense(1)(vad_gru)
  return denoise_output

波形是断断续续的。而且最后不收敛。

解决TensorFlow调用Keras库函数存在的问题

运行N久。。。之后

基本断定是程序本身的问题,于是通过排查,发现应该是GRU的initial_state没有进行更新导致的。导致波形是断断续续的,没有学习到前一次网络的输出。于是,决定不使用Keras库实现一遍:

部分代码如下:

def interface(_input):
  tmp = tf.keras.layers.Dense(10)(_input)
  gru_cell = tf.nn.rnn_cell.GRUCell(vad_cell_size)
  with tf.name_scope('initial_state'):
    cell_init_state = gru_cell.zero_state(batch_size, dtype=tf.float32)
  cell_outputs, cell_final_state = tf.nn.dynamic_rnn(
    gru_cell, tmp, initial_state=cell_init_state, time_major=False)
  denoise_output = tf.keras.layers.Dense(1)(cell_outputs)
  return denoise_output, cell_init_state, cell_final_state

波形图如下(这才是GRU的正确打开方式啊~):

解决TensorFlow调用Keras库函数存在的问题

再回头看之前写的调用keras,既然知道了是initial_state没有更新,那么如何进行更新呢?

网上查找了大量的资料,说要加上

update_ops = []
for old_value, new_value in layers.updates:
  update_ops.append(tf.assign(old_value, new_value))

但是加上去没有效果,是我加错了还是其他的,大家欢迎指出来

以下是我做的一些尝试,就不一一详细说明了,大家看一下,具体不再展开,有问题大家交流一下,有解决方法的,能够分享出来,感激不尽~

def interface(_input):
  # input_layer = tf.keras.layers.Input([None, 1])
  # input_layer = tf.keras.layers.Input(batch_shape=(50, 20, 1))
  tmp = tf.keras.layers.Dense(10)(_input)
  # tmp = tf.keras.layers.Dense(24)(tmp)
 
  # with tf.variable_scope('vad_gru', reuse=tf.AUTO_REUSE):
  # vad_gru, final_state = tf.keras.layers.GRU(24, return_sequences=True, return_state=True, stateful=True)(tmp)
  # print(vad_gru)
  # _initial_state = vad_gru.zero_state(50, tf.float32)
  # tf.get_variable_scope().reuse_variables()
 
  # vad_gru = tf.contrib.
 
  # tmp = tf.reshape(tmp, [-1, TIME_STEPS, vad_cell_size])
  gru_cell = tf.nn.rnn_cell.GRUCell(vad_cell_size)
  # gru_cell = tf.keras.layers.GRUCell(self.vad_cell_size)
  with tf.name_scope('initial_state'):
    cell_init_state = gru_cell.zero_state(batch_size, dtype=tf.float32)
  cell_outputs, cell_final_state = tf.nn.dynamic_rnn(
    gru_cell, tmp, initial_state=cell_init_state, time_major=False)
  # print(cell_outputs.get_shape().as_list())
 
  # cell_outputs = tf.reshape(cell_outputs, [-1, vad_cell_size])
 
  denoise_output = tf.keras.layers.Dense(1)(cell_outputs)
  print(denoise_output.get_shape().as_list())
 
  # model = tf.keras.models.Model(input_layer, denoise_output)
  # update_ops = []
  # for old_value, new_value in model.layers[1].updates:
  #   update_ops.append(tf.assign(old_value, new_value))
 
  return denoise_output, cell_init_state, cell_final_state

补充知识:TensorFlow和Keras常用方法(避坑)

TensorFlow

在TensorFlow中,除法运算:

1.tensor除法会使结果的精度高一级,可能会导致后面计算类型不匹配,如float32 / float32 = float64。

2.除法需要分子分母同类型,否则报错。

产生类似错误提示如下:

-1.TypeError: x and y must have the same dtype, got tf.float32 != tf.int32

-2.TypeError: Input ‘y' of ‘Mul' Op has type float32 that does not match type float64 of argument ‘x'.

-3.ValueError: Tensor conversion requested dtype float64 for Tensor with dtype float32: ‘Tensor(“Sum:0”, shape=(), dtype=float32)'

-4.ValueError: Incompatible type conversion requested to type ‘int32' for variable of type ‘float32_ref'

解决办法:

tf.cast(a, tf.float32) # 转换成同类型即可

tf.boolean_mask

K.gather

K.argmax

K.max

以上这篇解决TensorFlow调用Keras库函数存在的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
c++生成dll使用python调用dll的方法
Jan 20 Python
python实现定时播放mp3
Mar 29 Python
python开发中range()函数用法实例分析
Nov 12 Python
深入理解python对json的操作总结
Jan 05 Python
Python表示矩阵的方法分析
May 26 Python
python爬虫之百度API调用方法
Jun 11 Python
简单实现python聊天程序
Apr 01 Python
python3+PyQt5使用数据库表视图
Apr 24 Python
Pytorch抽取网络层的Feature Map(Vgg)实例
Aug 20 Python
Python递归实现打印多重列表代码
Feb 27 Python
Python virtualenv虚拟环境实现过程解析
Apr 18 Python
python 基于wx实现音乐播放
Nov 24 Python
python else语句在循环中的运用详解
Jul 06 #Python
Keras模型转成tensorflow的.pb操作
Jul 06 #Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
Python3基于print打印带颜色字符串
Jul 06 #Python
python判断是空的实例分享
Jul 06 #Python
python三引号如何输入
Jul 06 #Python
You might like
打造计数器DIY三步曲(中)
2006/10/09 PHP
PHP 程序员的调试技术小结
2009/11/15 PHP
php开启安全模式后禁用的函数集合
2011/06/26 PHP
php.ini 配置文件的深入解析
2013/06/17 PHP
PHP实现的分页类定义与用法示例
2017/07/05 PHP
asp批量修改记录的代码
2008/06/25 Javascript
Jquery cookie操作代码
2010/03/14 Javascript
web前端开发也需要日志
2010/12/09 Javascript
JQuery扩展插件Validate 1 基本使用方法并打包下载
2011/09/05 Javascript
javascript与jquery中跳出循环的区别总结
2013/11/04 Javascript
js中的eventType事件及其浏览器支持性介绍
2013/11/29 Javascript
javascript学习笔记(六)数据类型和JSON格式
2014/10/08 Javascript
Javascript添加监听与删除监听用法详解
2014/12/19 Javascript
js实现拖拽效果
2015/02/12 Javascript
自己动手写的javascript前端等待控件
2015/10/30 Javascript
非常实用的12个jquery代码片段
2015/11/02 Javascript
jquery实现加载进度条提示效果
2015/11/23 Javascript
jquery中ajax处理跨域的三大方式
2016/01/05 Javascript
jQuery中$.each()函数的用法引申实例
2016/05/12 Javascript
jQuery实现的多张图无缝滚动效果【测试可用】
2016/09/12 Javascript
Node.js开发第三方微信公众平台
2017/06/05 Javascript
js中apply和Math.max()函数的问题及区别介绍
2018/03/27 Javascript
谈谈我在vue-cli3中用预渲染遇到的坑
2020/04/22 Javascript
[01:02]2014 DOTA2国际邀请赛中国区预选赛 现场抢先看
2014/05/22 DOTA
[33:23]Secret vs Serenity 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
python 函数传参之传值还是传引用的分析
2017/09/07 Python
python 实现selenium断言和验证的方法
2019/02/13 Python
pyqt 实现在Widgets中显示图片和文字的方法
2019/06/13 Python
通过字符串导入 Python 模块的方法详解
2019/10/27 Python
Fairyseason:为个人和批发商提供女装和配件
2017/03/01 全球购物
Banana Republic英国官网:香蕉共和国,GAP集团旗下偏贵族风
2018/04/24 全球购物
能源工程专业应届生求职信
2014/03/01 职场文书
村抢险救灾方案
2014/05/09 职场文书
小学生红领巾广播稿
2015/08/19 职场文书
只需要100行Python代码就可以实现的贪吃蛇小游戏
2021/05/27 Python
新的CSS 伪类函数 :is() 和 :where()示例详解
2022/08/05 HTML / CSS