keras K.function获取某层的输出操作


Posted in Python onJune 29, 2020

如下所示:

from keras import backend as K
from keras.models import load_model

models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image, (224, 224, 3))
image_arr = np.expand_dims(image_arr, axis=0)
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]

加载训练好并保存的网络模型

加载数据(图像),并将数据处理成array形式

指定输出层

将处理后的数据输入,然后获取输出

其中,K.function有两种不同的写法:

1. 获取名为layer_name的层的输出

layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output]) #指定输出层的名称

2. 获取第n层的输出

layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output]) #指定输出层的序号(层号从0开始)

另外,需要注意的是,书写不规范会导致报错:

报错:

TypeError: inputs to a TensorFlow backend function should be a list or tuple

将该句:

f1 = layer_1(image_arr)[0]

修改为:

f1 = layer_1([image_arr])[0]

补充知识:keras.backend.function()

如下所示:

def function(inputs, outputs, updates=None, **kwargs):
 """Instantiates a Keras function.
 Arguments:
   inputs: List of placeholder tensors.
   outputs: List of output tensors.
   updates: List of update ops.
   **kwargs: Passed to `tf.Session.run`.
 Returns:
   Output values as Numpy arrays.
 Raises:
   ValueError: if invalid kwargs are passed in.
 """
 if kwargs:
  for key in kwargs:
   if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
     key not in tf_inspect.getargspec(Function.__init__)[0]):
    msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
        'backend') % key
    raise ValueError(msg)
 return Function(inputs, outputs, updates=updates, **kwargs)

这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。

我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。

class Function(object):
 """Runs a computation graph.
 Arguments:
   inputs: Feed placeholders to the computation graph.
   outputs: Output tensors to fetch.
   updates: Additional update ops to be run at function call.
   name: a name to help users identify what this function does.
 """

 def __init__(self, inputs, outputs, updates=None, name=None,
        **session_kwargs):
  updates = updates or []
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` to a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(outputs, (list, tuple)):
   raise TypeError('`outputs` of a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(updates, (list, tuple)):
   raise TypeError('`updates` in a TensorFlow backend function '
           'should be a list or tuple.')
  self.inputs = list(inputs)
  self.outputs = list(outputs)
  with ops.control_dependencies(self.outputs):
   updates_ops = []
   for update in updates:
    if isinstance(update, tuple):
     p, new_p = update
     updates_ops.append(state_ops.assign(p, new_p))
    else:
     # assumed already an op
     updates_ops.append(update)
   self.updates_op = control_flow_ops.group(*updates_ops)
  self.name = name
  self.session_kwargs = session_kwargs

 def __call__(self, inputs):
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` should be a list or tuple.')
  feed_dict = {}
  for tensor, value in zip(self.inputs, inputs):
   if is_sparse(tensor):
    sparse_coo = value.tocoo()
    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                 np.expand_dims(sparse_coo.col, 1)), 1)
    value = (indices, sparse_coo.data, sparse_coo.shape)
   feed_dict[tensor] = value
  session = get_session()
  updated = session.run(
    self.outputs + [self.updates_op],
    feed_dict=feed_dict,
    **self.session_kwargs)
  return updated[:len(self.outputs)]

所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。

以上这篇keras K.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux系统使用python监测系统负载脚本分享
Jan 15 Python
Python实现根据IP地址和子网掩码算出网段的方法
Jul 30 Python
Python打包可执行文件的方法详解
Sep 19 Python
关于python pyqt5安装失败问题的解决方法
Aug 08 Python
Anaconda多环境多版本python配置操作方法
Sep 12 Python
Python3.4实现从HTTP代理网站批量获取代理并筛选的方法示例
Sep 26 Python
pandas每次多Sheet写入文件的方法
Dec 10 Python
Python中模块(Module)和包(Package)的区别详解
Aug 07 Python
Python GUI自动化实现绕过验证码登录
Jan 10 Python
python 双循环遍历list 变量判断代码
May 04 Python
python speech模块的使用方法
Sep 09 Python
浅谈Python中对象是如何被调用的
Apr 06 Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
Python如何对XML 解析
Jun 28 #Python
keras 自定义loss层+接受输入实例
Jun 28 #Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
You might like
file_get_contents("php://input", "r")实例介绍
2013/07/01 PHP
总结PHP中初始化空数组的最佳方法
2019/02/13 PHP
alixixi runcode.asp的代码不错的应用
2007/08/08 Javascript
JS的千分位算法实现思路
2013/07/31 Javascript
JQuery表格拖动调整列宽效果(自己动手写的)
2014/09/01 Javascript
jquery判断输入密码两次是否相等
2020/04/22 Javascript
ztree获取选中节点时不能进入可视区域出现BUG如何解决
2015/12/03 Javascript
AngularJS实现元素显示和隐藏的几个案例
2015/12/09 Javascript
基于JS实现导航条之调用网页助手小精灵的方法
2016/06/17 Javascript
js实现无缝滚动图
2017/02/22 Javascript
vue axios整合使用全攻略
2018/05/24 Javascript
vue点击input弹出带搜索键盘并监听该元素的方法
2018/08/25 Javascript
基于JavaScript实现每日签到打卡轨迹功能
2018/11/29 Javascript
[05:07]DOTA2英雄梦之声_第14期_暗影恶魔
2014/06/20 DOTA
python 统计数组中元素出现次数并进行排序的实例
2018/07/02 Python
flask入门之表单的实现
2018/07/18 Python
Python中的字符串切片(截取字符串)的详解
2019/05/15 Python
Python中使用gflags实例及原理解析
2019/12/13 Python
python入门之基础语法学习笔记
2020/02/08 Python
台湾旅游网站:雄狮旅游网
2017/08/16 全球购物
意大利消费电子产品购物网站:SLG Store
2019/12/26 全球购物
几道Web/Ajax的面试题
2016/11/05 面试题
面料业务员岗位职责
2013/12/26 职场文书
应届大学生简历中的自我评价
2014/01/15 职场文书
求职自荐信怎么写
2014/03/06 职场文书
保研推荐信
2014/05/09 职场文书
超市理货员岗位职责
2014/07/04 职场文书
幼儿园八一建军节活动方案
2014/08/27 职场文书
学校政风行风评议心得体会
2014/10/21 职场文书
结婚通知短信大全
2015/04/17 职场文书
实施意见格式范本
2015/06/05 职场文书
战友聚会致辞
2015/07/28 职场文书
如何利用opencv判断两张图片是否相同详解
2021/07/07 Python
Nginx配置文件详解以及优化建议指南
2021/09/15 Servers
SpringBoot连接MySQL获取数据写后端接口的操作方法
2021/11/02 MySQL
使用python绘制横竖条形图
2022/04/21 Python