keras K.function获取某层的输出操作


Posted in Python onJune 29, 2020

如下所示:

from keras import backend as K
from keras.models import load_model

models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image, (224, 224, 3))
image_arr = np.expand_dims(image_arr, axis=0)
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]

加载训练好并保存的网络模型

加载数据(图像),并将数据处理成array形式

指定输出层

将处理后的数据输入,然后获取输出

其中,K.function有两种不同的写法:

1. 获取名为layer_name的层的输出

layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output]) #指定输出层的名称

2. 获取第n层的输出

layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output]) #指定输出层的序号(层号从0开始)

另外,需要注意的是,书写不规范会导致报错:

报错:

TypeError: inputs to a TensorFlow backend function should be a list or tuple

将该句:

f1 = layer_1(image_arr)[0]

修改为:

f1 = layer_1([image_arr])[0]

补充知识:keras.backend.function()

如下所示:

def function(inputs, outputs, updates=None, **kwargs):
 """Instantiates a Keras function.
 Arguments:
   inputs: List of placeholder tensors.
   outputs: List of output tensors.
   updates: List of update ops.
   **kwargs: Passed to `tf.Session.run`.
 Returns:
   Output values as Numpy arrays.
 Raises:
   ValueError: if invalid kwargs are passed in.
 """
 if kwargs:
  for key in kwargs:
   if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
     key not in tf_inspect.getargspec(Function.__init__)[0]):
    msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
        'backend') % key
    raise ValueError(msg)
 return Function(inputs, outputs, updates=updates, **kwargs)

这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。

我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。

class Function(object):
 """Runs a computation graph.
 Arguments:
   inputs: Feed placeholders to the computation graph.
   outputs: Output tensors to fetch.
   updates: Additional update ops to be run at function call.
   name: a name to help users identify what this function does.
 """

 def __init__(self, inputs, outputs, updates=None, name=None,
        **session_kwargs):
  updates = updates or []
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` to a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(outputs, (list, tuple)):
   raise TypeError('`outputs` of a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(updates, (list, tuple)):
   raise TypeError('`updates` in a TensorFlow backend function '
           'should be a list or tuple.')
  self.inputs = list(inputs)
  self.outputs = list(outputs)
  with ops.control_dependencies(self.outputs):
   updates_ops = []
   for update in updates:
    if isinstance(update, tuple):
     p, new_p = update
     updates_ops.append(state_ops.assign(p, new_p))
    else:
     # assumed already an op
     updates_ops.append(update)
   self.updates_op = control_flow_ops.group(*updates_ops)
  self.name = name
  self.session_kwargs = session_kwargs

 def __call__(self, inputs):
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` should be a list or tuple.')
  feed_dict = {}
  for tensor, value in zip(self.inputs, inputs):
   if is_sparse(tensor):
    sparse_coo = value.tocoo()
    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                 np.expand_dims(sparse_coo.col, 1)), 1)
    value = (indices, sparse_coo.data, sparse_coo.shape)
   feed_dict[tensor] = value
  session = get_session()
  updated = session.run(
    self.outputs + [self.updates_op],
    feed_dict=feed_dict,
    **self.session_kwargs)
  return updated[:len(self.outputs)]

所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。

以上这篇keras K.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python简单调用MySQL存储过程并获得返回值的方法
Jul 20 Python
python中range()与xrange()用法分析
Sep 21 Python
VSCode下配置python调试运行环境的方法
Apr 06 Python
Pandas 数据处理,数据清洗详解
Jul 10 Python
python+pyqt5实现KFC点餐收银系统
Jan 24 Python
Python之NumPy(axis=0 与axis=1)区分详解
May 27 Python
python多线程下信号处理程序示例
May 31 Python
Python实现微信机器人的方法
Sep 06 Python
Python 中的 import 机制之实现远程导入模块
Oct 29 Python
Pytorch之finetune使用详解
Jan 18 Python
基于Tensorflow:CPU性能分析
Feb 10 Python
Python基于Faker假数据构造库
Nov 30 Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
Python如何对XML 解析
Jun 28 #Python
keras 自定义loss层+接受输入实例
Jun 28 #Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
You might like
php 禁止页面缓存输出
2009/01/07 PHP
PHP函数篇详解十进制、二进制、八进制和十六进制转换函数说明
2011/12/05 PHP
Zend Framework中的简单工厂模式 图文
2012/07/10 PHP
PHP中基于ts与nts版本- vc6和vc9编译版本的区别详解
2013/04/26 PHP
typecho插件编写教程(一):Hello World
2015/05/28 PHP
屏蔽F1~F12的快捷键的js函数
2010/05/06 Javascript
JQuery实现倒计时按钮的实现代码
2012/03/23 Javascript
jQuery 开发者应该注意的9个错误
2012/05/03 Javascript
Jquery实现弹出层分享微博插件具备动画效果
2013/04/03 Javascript
jquery简单实现滚动条下拉DIV固定在头部不动
2013/11/25 Javascript
手机号码,密码正则验证
2014/09/04 Javascript
js仿土豆网带缩略图的焦点图片切换效果实现方法
2015/02/23 Javascript
jQuery实现按钮的点击 全选/反选 单选框/复选框 文本框 表单验证
2015/06/25 Javascript
jQuery代码实现发展历程时间轴特效
2015/07/30 Javascript
使用JQuery实现Ctrl+Enter提交表单的方法
2015/10/22 Javascript
详解javascript中原始数据类型Null和Undefined
2015/12/17 Javascript
js数组去重的hash方法
2016/12/22 Javascript
使用vue-cli创建项目的图文教程(新手入门篇)
2018/05/02 Javascript
深入解析Python中的lambda表达式的用法
2015/08/28 Python
python实现csv格式文件转为asc格式文件的方法
2018/03/23 Python
python操作excel的包(openpyxl、xlsxwriter)
2018/06/11 Python
对python 读取线的shp文件实例详解
2018/12/22 Python
Python简单基础小程序的实例代码
2019/04/28 Python
python 实现性别识别
2020/11/21 Python
html5-Canvas可以在web中绘制各种图形
2012/12/26 HTML / CSS
荷兰多品牌网上鞋店:Stoute Schoenen
2017/08/24 全球购物
VLAN和VPN有什么区别?分别实现在OSI的第几层?
2014/12/23 面试题
护士感人事迹
2014/05/01 职场文书
品牌服务方案
2014/06/03 职场文书
企业财务人员岗位职责
2015/04/14 职场文书
公司职员入党自传书
2015/06/26 职场文书
2015年国庆节演讲稿范文
2015/07/30 职场文书
学雷锋主题班会教案
2015/08/13 职场文书
CDPR谈《巫师》新作用虚幻5原因 称不会为Epic独占
2022/04/06 其他游戏
Nginx配置之禁止指定IP访问
2022/05/02 Servers
Linux中一对多配置日志服务器的详细步骤
2022/07/23 Servers