keras K.function获取某层的输出操作


Posted in Python onJune 29, 2020

如下所示:

from keras import backend as K
from keras.models import load_model

models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image, (224, 224, 3))
image_arr = np.expand_dims(image_arr, axis=0)
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]

加载训练好并保存的网络模型

加载数据(图像),并将数据处理成array形式

指定输出层

将处理后的数据输入,然后获取输出

其中,K.function有两种不同的写法:

1. 获取名为layer_name的层的输出

layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output]) #指定输出层的名称

2. 获取第n层的输出

layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output]) #指定输出层的序号(层号从0开始)

另外,需要注意的是,书写不规范会导致报错:

报错:

TypeError: inputs to a TensorFlow backend function should be a list or tuple

将该句:

f1 = layer_1(image_arr)[0]

修改为:

f1 = layer_1([image_arr])[0]

补充知识:keras.backend.function()

如下所示:

def function(inputs, outputs, updates=None, **kwargs):
 """Instantiates a Keras function.
 Arguments:
   inputs: List of placeholder tensors.
   outputs: List of output tensors.
   updates: List of update ops.
   **kwargs: Passed to `tf.Session.run`.
 Returns:
   Output values as Numpy arrays.
 Raises:
   ValueError: if invalid kwargs are passed in.
 """
 if kwargs:
  for key in kwargs:
   if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
     key not in tf_inspect.getargspec(Function.__init__)[0]):
    msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
        'backend') % key
    raise ValueError(msg)
 return Function(inputs, outputs, updates=updates, **kwargs)

这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。

我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。

class Function(object):
 """Runs a computation graph.
 Arguments:
   inputs: Feed placeholders to the computation graph.
   outputs: Output tensors to fetch.
   updates: Additional update ops to be run at function call.
   name: a name to help users identify what this function does.
 """

 def __init__(self, inputs, outputs, updates=None, name=None,
        **session_kwargs):
  updates = updates or []
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` to a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(outputs, (list, tuple)):
   raise TypeError('`outputs` of a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(updates, (list, tuple)):
   raise TypeError('`updates` in a TensorFlow backend function '
           'should be a list or tuple.')
  self.inputs = list(inputs)
  self.outputs = list(outputs)
  with ops.control_dependencies(self.outputs):
   updates_ops = []
   for update in updates:
    if isinstance(update, tuple):
     p, new_p = update
     updates_ops.append(state_ops.assign(p, new_p))
    else:
     # assumed already an op
     updates_ops.append(update)
   self.updates_op = control_flow_ops.group(*updates_ops)
  self.name = name
  self.session_kwargs = session_kwargs

 def __call__(self, inputs):
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` should be a list or tuple.')
  feed_dict = {}
  for tensor, value in zip(self.inputs, inputs):
   if is_sparse(tensor):
    sparse_coo = value.tocoo()
    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                 np.expand_dims(sparse_coo.col, 1)), 1)
    value = (indices, sparse_coo.data, sparse_coo.shape)
   feed_dict[tensor] = value
  session = get_session()
  updated = session.run(
    self.outputs + [self.updates_op],
    feed_dict=feed_dict,
    **self.session_kwargs)
  return updated[:len(self.outputs)]

所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。

以上这篇keras K.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的进程分支fork和exec详解
Apr 11 Python
python学习 流程控制语句详解
Jun 01 Python
Python实现购物程序思路及代码
Jul 24 Python
使用python判断jpeg图片的完整性实例
Jun 10 Python
python根据多个文件名批量查找文件
Aug 13 Python
QML使用Python的函数过程解析
Sep 26 Python
python主线程与子线程的结束顺序实例解析
Dec 17 Python
python文件操作seek()偏移量,读取指正到指定位置操作
Jul 05 Python
详解Python直接赋值,深拷贝和浅拷贝
Jul 09 Python
Python unittest discover批量执行代码实例
Sep 08 Python
Python list去重且保持原顺序不变的方法
Apr 03 Python
python APScheduler执行定时任务介绍
Apr 19 Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
Python如何对XML 解析
Jun 28 #Python
keras 自定义loss层+接受输入实例
Jun 28 #Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
You might like
在php中取得image按钮传递的name值
2006/10/09 PHP
用 PHP5 轻松解析 XML
2006/12/04 PHP
一个PHP模板,主要想体现一下思路
2006/12/25 PHP
php下关于中英数字混排的字符串分割问题
2010/04/06 PHP
支持中文字母数字、自定义字体php验证码代码
2012/02/27 PHP
php根据数据id自动生成编号的实现方法
2016/10/16 PHP
javascript重复绑定事件造成的后果说明
2013/03/02 Javascript
解析JSON对象与字符串之间的相互转换
2013/12/18 Javascript
AngularJS 最常用的功能汇总
2016/02/17 Javascript
JavaScript数据存储 Cookie篇
2016/07/02 Javascript
jQuery实现简单的回到顶部totop功能示例
2017/10/16 jQuery
jquery.tagsinput.js实现记录checkbox勾选的顺序
2019/09/21 jQuery
React.js组件实现拖拽排序组件功能过程解析
2020/04/27 Javascript
[02:12]探秘2016国际邀请赛中国区预选赛选手房间
2016/06/25 DOTA
python实现根据用户输入从电影网站获取影片信息的方法
2015/04/07 Python
Python的Django框架中消息通知的计数器实现教程
2016/06/13 Python
Python实现图片转字符画的示例代码
2017/08/21 Python
Python排序搜索基本算法之希尔排序实例分析
2017/12/09 Python
Centos 升级到python3后pip 无法使用的解决方法
2018/06/12 Python
用Python实现数据的透视表的方法
2018/11/16 Python
python路径的写法及目录的获取方式
2019/12/26 Python
pyspark给dataframe增加新的一列的实现示例
2020/04/24 Python
keras中的卷积层&池化层的用法
2020/05/22 Python
Python如何执行精确的浮点数运算
2020/07/31 Python
Python实现自动签到脚本功能
2020/08/20 Python
python识别验证码的思路及解决方案
2020/09/13 Python
CSS3动画之流彩文字效果+图片模糊效果+边框伸展效果实现代码合集
2017/08/18 HTML / CSS
What is the purpose of Void class? Void类的作用是什么?
2016/10/31 面试题
shell变量的作用空间是什么
2013/08/17 面试题
写求职信要注意什么问题
2014/04/12 职场文书
大型演出策划方案
2014/05/28 职场文书
4s店活动策划方案
2014/08/25 职场文书
学校政风行风自查自纠报告
2014/10/21 职场文书
解放思想大讨论活动总结
2015/05/09 职场文书
观看《信仰》心得体会
2016/01/15 职场文书
分析Python list操作为什么会错误
2021/11/17 Python