keras小技巧——获取某一个网络层的输出方式


Posted in Python onMay 23, 2020

前言:

keras默认提供了如何获取某一个层的某一个节点的输出,但是没有提供如何获取某一个层的输出的接口,所以有时候我们需要获取某一个层的输出,则需要自己编写代码,但是鉴于keras高层封装的特性,编写起来实际上很简单,本文提供两种常见的方法来实现,基于上一篇文章的模型和代码: keras自定义回调函数查看训练的loss和accuracy

一、模型加载以及各个层的信息查看

从前面的定义可知,参见上一篇文章,一共定义了8个网络层,定义如下:

model.add(Convolution2D(filters=6, kernel_size=(5, 5), padding='valid', input_shape=(img_rows, img_cols, 1), activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(filters=16, kernel_size=(5, 5), padding='valid', activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(n_classes, activation='softmax'))

这里每一个层都没有起名字,实际上最好给每一个层取一个名字,所以这里就使用索引来访问层,如下:

for index in range(8):
 layer=model.get_layer(index=index)
 # layer=model.layers[index] # 这样获取每一个层也是一样的
 print(model)
 
'''运行结果如下:
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
'''

当然由于 model.laters是一个列表,所以可以一次性打印出所有的层信息,即

print(model.layers) # 打印出所有的层

二、模型的加载

准备测试数据

# 训练参数
learning_rate = 0.001
epochs = 10
batch_size = 128
n_classes = 10
 
# 定义图像维度reshape
img_rows, img_cols = 28, 28
 
# 加载keras中的mnist数据集 分为60,000个训练集,10,000个测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
 
# 将图片转化为(samples,width,height,channels)的格式
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
 
# 将X_train, X_test的数据格式转为float32
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# 将X_train, X_test归一化0-1
x_train /= 255
x_test /= 255
 
# 输出0-9转换为ont-hot形式
y_train = np_utils.to_categorical(y_train, n_classes)
y_test = np_utils.to_categorical(y_test, n_classes)

模型的加载

model=keras.models.load_model('./models/lenet5_weight.h5')

注意事项:

keras的每一个层有一个input和output属性,但是它是只针对单节点的层而言的哦,否则就不需要我们再自己编写输出函数了,

如果一个层具有单个节点 (i.e. 如果它不是共享层), 你可以得到它的输入张量、输出张量、输入尺寸和输出尺寸:

layer.input
layer.output
layer.input_shape
layer.output_shape

如果层有多个节点 (参见: 层节点和共享层的概念), 您可以使用以下函数:

layer.get_input_at(node_index)
layer.get_output_at(node_index)
layer.get_input_shape_at(node_index)
layer.get_output_shape_at(node_index)

三、获取某一个层的输出的方法定义

3.1 第一种实现方法

def get_output_function(model,output_layer_index):
 '''
 model: 要保存的模型
 output_layer_index:要获取的那一个层的索引
 '''
 vector_funcrion=K.function([model.layers[0].input],[model.layers[output_layer_index].output])
 def inner(input_data):
  vector=vector_funcrion([input_data])[0]
  return vector
 
 return inner
 
# 现在仅仅测试一张图片
#选择一张图片,选择第一张
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
get_feature=get_output_function(model,6) # 该函数的返回值依然是一个函数哦,获取第6层输出
 
feature=get_feature(x) # 相当于调用 定义在里面的inner函数
print(feature)
'''运行结果为
[[-0.99986297 -0.9988328 -0.9273474 0.9101525 -0.9054705 -0.95798373
 0.9911243 0.78576803 0.99676156 0.39356467 -0.9724135 -0.74534595
 0.8527011 -0.9968267 -0.9420816 -0.32765102 -0.41667578 0.99942905
 0.92333794 0.7565034 -0.38416263 -0.994241 0.3781617 0.9621943
 0.9443946 0.9671554 -0.01000021 -0.9984282 -0.96650964 -0.9925837
 -0.48193568 -0.9749565 -0.79769516 0.9651831 0.9678705 -0.9444472
 0.9405674 0.97538495 -0.12366439 -0.9973782 0.05803521 0.9159217
 -0.9627071 0.99898154 0.99429387 -0.985909 0.5787794 -0.9789403
 -0.94316894 0.9999644 0.9156823 0.46314353 -0.01582102 0.98359734
 0.5586145 -0.97360635 0.99058044 0.9995654 -0.9800733 0.99942625
 0.8786553 -0.9992093 0.99916387 -0.5141877 0.99970615 0.28427476
 0.86589384 0.7649907 -0.9986046 0.9999706 -0.9892468 0.99854743
 -0.86872625 -0.9997323 0.98981035 -0.87805724 -0.9999373 -0.7842255
 -0.97456616 -0.97237325 -0.729563 0.98718935 0.9992022 -0.5294769 ]]
'''

但是上面的实现方法似乎不是很简单,还有更加简单的方法,思想来源与keras中,可以将整个模型model也当成是层layer来处理,实现如下面。

3.2 第二种实现方法

import keras
import numpy as np
from keras.datasets import mnist
from keras.models import Model
 
model=keras.models.load_model('./models/lenet5_weight.h5')
 
#选择一张图片,选择第一张
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
# 将模型作为一个层,输出第7层的输出
layer_model = Model(inputs=model.input, outputs=model.layers[6].output)
 
feature=layer_model.predict(x)
 
print(feature)
'''运行结果为:
[[-0.99986297 -0.9988328 -0.9273474 0.9101525 -0.9054705 -0.95798373
 0.9911243 0.78576803 0.99676156 0.39356467 -0.9724135 -0.74534595
 0.8527011 -0.9968267 -0.9420816 -0.32765102 -0.41667578 0.99942905
 0.92333794 0.7565034 -0.38416263 -0.994241 0.3781617 0.9621943
 0.9443946 0.9671554 -0.01000021 -0.9984282 -0.96650964 -0.9925837
 -0.48193568 -0.9749565 -0.79769516 0.9651831 0.9678705 -0.9444472
 0.9405674 0.97538495 -0.12366439 -0.9973782 0.05803521 0.9159217
 -0.9627071 0.99898154 0.99429387 -0.985909 0.5787794 -0.9789403
 -0.94316894 0.9999644 0.9156823 0.46314353 -0.01582102 0.98359734
 0.5586145 -0.97360635 0.99058044 0.9995654 -0.9800733 0.99942625
 0.8786553 -0.9992093 0.99916387 -0.5141877 0.99970615 0.28427476
 0.86589384 0.7649907 -0.9986046 0.9999706 -0.9892468 0.99854743
 -0.86872625 -0.9997323 0.98981035 -0.87805724 -0.9999373 -0.7842255
 -0.97456616 -0.97237325 -0.729563 0.98718935 0.9992022 -0.5294769 ]]
'''

可见和上面的结果是一样的,

总结:

由于keras的层与模型之间实际上的转化关系,所以提供了非常灵活的输出方法,推荐使用第二种方法获得某一个层的输出。总结为以下几个主要的步骤(四步走):

import keras
import numpy as np
from keras.datasets import mnist
from keras.models import Model
 
# 第一步:准备输入数据
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
# 第二步:加载已经训练的模型
model=keras.models.load_model('./models/lenet5_weight.h5')
 
# 第三步:将模型作为一个层,输出第7层的输出
layer_model = Model(inputs=model.input, outputs=model.layers[6].output)
 
# 第四步:调用新建的“曾模型”的predict方法,得到模型的输出
feature=layer_model.predict(x)
 
print(feature)

以上这篇keras小技巧——获取某一个网络层的输出方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中dictionary items()系列函数的用法实例
Aug 21 Python
Python入门篇之函数
Oct 20 Python
Python简单进程锁代码实例
Apr 27 Python
分享6个隐藏的python功能
Dec 07 Python
在pycharm中python切换解释器失败的解决方法
Oct 29 Python
selenium+python自动化测试之页面元素定位
Jan 23 Python
Python常用爬虫代码总结方便查询
Feb 25 Python
搞清楚 Python traceback的具体使用方法
May 13 Python
python实现多线程端口扫描
Aug 31 Python
一文解决django 2.2与mysql兼容性问题
Jul 15 Python
matplotlib grid()设置网格线外观的实现
Feb 22 Python
python元组拆包实现方法
Feb 28 Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
Python接口测试文件上传实例解析
May 22 #Python
计算Python Numpy向量之间的欧氏距离实例
May 22 #Python
python numpy矩阵信息说明,shape,size,dtype
May 22 #Python
You might like
php strlen mb_strlen计算中英文混排字符串长度
2009/07/10 PHP
php实现的简单检验登陆类
2015/06/18 PHP
4种PHP异步执行的常用方式
2015/12/24 PHP
jQuery 常见学习网站与参考书
2009/11/09 Javascript
js DOM 元素ID就是全局变量
2012/09/20 Javascript
jquery可见性过滤选择器使用示例
2013/06/24 Javascript
js简单实现根据身份证号码识别性别年龄生日
2013/11/29 Javascript
浅析Cookie中的Path与domain
2013/12/18 Javascript
在jquery中的ajax方法怎样通过JSONP进行远程调用
2014/04/04 Javascript
JQuery中$(document)是什么意思有什么作用
2014/07/21 Javascript
60个很实用的jQuery代码开发技巧收集
2014/12/15 Javascript
浅析Node.js的Stream模块中的Readable对象
2015/07/29 Javascript
BootStrap的Datepicker控件使用心得分享
2016/05/25 Javascript
一个仿微博登陆邮箱提示框js开发案例
2016/07/28 Javascript
标准的js无缝滚动效果
2016/08/30 Javascript
js 博客内容进度插件详解
2017/02/19 Javascript
基于JavaScript实现图片剪切效果
2017/03/07 Javascript
javascript实现圣旨卷轴展开效果(代码分享)
2017/03/23 Javascript
细说webpack源码之compile流程-rules参数处理技巧(1)
2017/12/26 Javascript
JavaScript遍历数组和对象的元素简单操作示例
2019/07/09 Javascript
高效jQuery选择器的5个技巧实例分析
2019/11/26 jQuery
微信小程序使用自定义组件导航实现当前页面高亮
2020/01/02 Javascript
Python实现配置文件备份的方法
2015/07/30 Python
基于Python的接口测试框架实例
2016/11/04 Python
详解Django项目中模板标签及模板的继承与引用(网站中快速布置广告)
2019/03/27 Python
jupyter notebook 中输出pyecharts图实例
2020/04/23 Python
Python3查找列表中重复元素的个数的3种方法详解
2020/02/13 Python
python opencv实现图片缺陷检测(讲解直方图以及相关系数对比法)
2020/04/07 Python
Jupyter notebook无法导入第三方模块的解决方式
2020/04/15 Python
django和flask哪个值得研究学习
2020/07/31 Python
HTML5 Canvas 起步(1) - 基本概念
2009/05/12 HTML / CSS
设计师家具购买和委托在线市场:Viyet
2016/11/16 全球购物
哥伦比亚加拿大官网:Columbia Sportswear Canada
2020/09/07 全球购物
学校三八妇女节活动情况总结
2014/03/09 职场文书
导游词之丹东鸭绿江
2019/10/24 职场文书
源码解读Spring-Integration执行过程
2021/06/11 Java/Android