keras K.function获取某层的输出操作


Posted in Python onJune 29, 2020

如下所示:

from keras import backend as K
from keras.models import load_model

models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image, (224, 224, 3))
image_arr = np.expand_dims(image_arr, axis=0)
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]

加载训练好并保存的网络模型

加载数据(图像),并将数据处理成array形式

指定输出层

将处理后的数据输入,然后获取输出

其中,K.function有两种不同的写法:

1. 获取名为layer_name的层的输出

layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output]) #指定输出层的名称

2. 获取第n层的输出

layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output]) #指定输出层的序号(层号从0开始)

另外,需要注意的是,书写不规范会导致报错:

报错:

TypeError: inputs to a TensorFlow backend function should be a list or tuple

将该句:

f1 = layer_1(image_arr)[0]

修改为:

f1 = layer_1([image_arr])[0]

补充知识:keras.backend.function()

如下所示:

def function(inputs, outputs, updates=None, **kwargs):
 """Instantiates a Keras function.
 Arguments:
   inputs: List of placeholder tensors.
   outputs: List of output tensors.
   updates: List of update ops.
   **kwargs: Passed to `tf.Session.run`.
 Returns:
   Output values as Numpy arrays.
 Raises:
   ValueError: if invalid kwargs are passed in.
 """
 if kwargs:
  for key in kwargs:
   if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
     key not in tf_inspect.getargspec(Function.__init__)[0]):
    msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
        'backend') % key
    raise ValueError(msg)
 return Function(inputs, outputs, updates=updates, **kwargs)

这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。

我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。

class Function(object):
 """Runs a computation graph.
 Arguments:
   inputs: Feed placeholders to the computation graph.
   outputs: Output tensors to fetch.
   updates: Additional update ops to be run at function call.
   name: a name to help users identify what this function does.
 """

 def __init__(self, inputs, outputs, updates=None, name=None,
        **session_kwargs):
  updates = updates or []
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` to a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(outputs, (list, tuple)):
   raise TypeError('`outputs` of a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(updates, (list, tuple)):
   raise TypeError('`updates` in a TensorFlow backend function '
           'should be a list or tuple.')
  self.inputs = list(inputs)
  self.outputs = list(outputs)
  with ops.control_dependencies(self.outputs):
   updates_ops = []
   for update in updates:
    if isinstance(update, tuple):
     p, new_p = update
     updates_ops.append(state_ops.assign(p, new_p))
    else:
     # assumed already an op
     updates_ops.append(update)
   self.updates_op = control_flow_ops.group(*updates_ops)
  self.name = name
  self.session_kwargs = session_kwargs

 def __call__(self, inputs):
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` should be a list or tuple.')
  feed_dict = {}
  for tensor, value in zip(self.inputs, inputs):
   if is_sparse(tensor):
    sparse_coo = value.tocoo()
    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                 np.expand_dims(sparse_coo.col, 1)), 1)
    value = (indices, sparse_coo.data, sparse_coo.shape)
   feed_dict[tensor] = value
  session = get_session()
  updated = session.run(
    self.outputs + [self.updates_op],
    feed_dict=feed_dict,
    **self.session_kwargs)
  return updated[:len(self.outputs)]

所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。

以上这篇keras K.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python之PyUnit单元测试实例
Oct 11 Python
使用Python编写一个简单的tic-tac-toe游戏的教程
Apr 16 Python
Python中用post、get方式提交数据的方法示例
Sep 22 Python
基于Python列表解析(列表推导式)
Jun 23 Python
Python将多个list合并为1个list的方法
Jun 27 Python
Python中出现IndentationError:unindent does not match any outer indentation level错误的解决方法
Apr 18 Python
Python实现合并两个有序链表的方法示例
Jan 31 Python
python自动发微信监控报警
Sep 06 Python
python中的Elasticsearch操作汇总
Oct 30 Python
Python编程快速上手——PDF文件操作案例分析
Feb 28 Python
Ubuntu 20.04安装Pycharm2020.2及锁定到任务栏的问题(小白级操作)
Oct 29 Python
python palywright库基本使用
Jan 21 Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
Python如何对XML 解析
Jun 28 #Python
keras 自定义loss层+接受输入实例
Jun 28 #Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
You might like
基于数据库的在线人数,日访问量等统计
2006/10/09 PHP
PHP实现的比较完善的购物车类
2014/12/02 PHP
PHP截取指定图片大小的方法
2014/12/10 PHP
ThinkPHP Where 条件中常用表达式示例(详解)
2017/03/31 PHP
PHP实现动态添加XML中数据的方法
2018/03/30 PHP
JS在IE和FireFox之间常用函数的区别小结
2010/03/12 Javascript
JS图片自动轮换效果实现思路附截图
2014/04/30 Javascript
JavaScript通过select动态更换图片的方法
2015/03/23 Javascript
微信小程序 开发工具快捷键整理
2016/10/31 Javascript
如何解决vue与传统jquery插件冲突
2017/03/20 Javascript
angular学习之从零搭建一个angular4.0项目
2017/07/10 Javascript
Vue中的Vux配置指南
2017/12/08 Javascript
JS中promise化微信小程序api
2018/04/12 Javascript
vue2.0的虚拟DOM渲染思路分析
2018/08/09 Javascript
jQuery实现侧边栏隐藏与显示的方法详解
2018/12/22 jQuery
js回调函数仿360开机
2019/12/26 Javascript
Pandas:Series和DataFrame删除指定轴上数据的方法
2018/11/10 Python
python实现两个经纬度点之间的距离和方位角的方法
2019/07/05 Python
对Django中static(静态)文件详解以及{% static %}标签的使用方法
2019/07/28 Python
基于python框架Scrapy爬取自己的博客内容过程详解
2019/08/05 Python
Python with关键字,上下文管理器,@contextmanager文件操作示例
2019/10/17 Python
python shutil文件操作工具使用实例分析
2019/12/25 Python
python matplotlib模块基本图形绘制方法小结【直线,曲线,直方图,饼图等】
2020/04/26 Python
捷克浴室和厨房设备购物网站:SIKO
2018/08/11 全球购物
英国珠宝和手表专家:Pleasance & Harper
2020/10/21 全球购物
小学生放飞梦想演讲稿
2014/08/26 职场文书
2015年父亲节活动总结
2015/02/12 职场文书
财务部岗位职责范本
2015/04/14 职场文书
审查起诉阶段律师意见书
2015/05/19 职场文书
2015年度招聘工作总结
2015/05/28 职场文书
未来,这5大方向都很适合创业
2019/07/22 职场文书
创业项目大全(适合在家创业的项目)
2019/08/15 职场文书
Mac M1安装mnmp (Mac+Nginx+MySQL+PHP) 开发环境
2021/03/29 PHP
HTML页面滚动时部分内容位置固定不滚动的实现
2021/04/14 HTML / CSS
MySQL 自动填充 create_time 和 update_time
2022/05/20 MySQL
Spring boot实现上传文件到本地服务器
2022/08/14 Java/Android