keras小技巧——获取某一个网络层的输出方式


Posted in Python onMay 23, 2020

前言:

keras默认提供了如何获取某一个层的某一个节点的输出,但是没有提供如何获取某一个层的输出的接口,所以有时候我们需要获取某一个层的输出,则需要自己编写代码,但是鉴于keras高层封装的特性,编写起来实际上很简单,本文提供两种常见的方法来实现,基于上一篇文章的模型和代码: keras自定义回调函数查看训练的loss和accuracy

一、模型加载以及各个层的信息查看

从前面的定义可知,参见上一篇文章,一共定义了8个网络层,定义如下:

model.add(Convolution2D(filters=6, kernel_size=(5, 5), padding='valid', input_shape=(img_rows, img_cols, 1), activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(filters=16, kernel_size=(5, 5), padding='valid', activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(n_classes, activation='softmax'))

这里每一个层都没有起名字,实际上最好给每一个层取一个名字,所以这里就使用索引来访问层,如下:

for index in range(8):
 layer=model.get_layer(index=index)
 # layer=model.layers[index] # 这样获取每一个层也是一样的
 print(model)
 
'''运行结果如下:
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
'''

当然由于 model.laters是一个列表,所以可以一次性打印出所有的层信息,即

print(model.layers) # 打印出所有的层

二、模型的加载

准备测试数据

# 训练参数
learning_rate = 0.001
epochs = 10
batch_size = 128
n_classes = 10
 
# 定义图像维度reshape
img_rows, img_cols = 28, 28
 
# 加载keras中的mnist数据集 分为60,000个训练集,10,000个测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
 
# 将图片转化为(samples,width,height,channels)的格式
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
 
# 将X_train, X_test的数据格式转为float32
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# 将X_train, X_test归一化0-1
x_train /= 255
x_test /= 255
 
# 输出0-9转换为ont-hot形式
y_train = np_utils.to_categorical(y_train, n_classes)
y_test = np_utils.to_categorical(y_test, n_classes)

模型的加载

model=keras.models.load_model('./models/lenet5_weight.h5')

注意事项:

keras的每一个层有一个input和output属性,但是它是只针对单节点的层而言的哦,否则就不需要我们再自己编写输出函数了,

如果一个层具有单个节点 (i.e. 如果它不是共享层), 你可以得到它的输入张量、输出张量、输入尺寸和输出尺寸:

layer.input
layer.output
layer.input_shape
layer.output_shape

如果层有多个节点 (参见: 层节点和共享层的概念), 您可以使用以下函数:

layer.get_input_at(node_index)
layer.get_output_at(node_index)
layer.get_input_shape_at(node_index)
layer.get_output_shape_at(node_index)

三、获取某一个层的输出的方法定义

3.1 第一种实现方法

def get_output_function(model,output_layer_index):
 '''
 model: 要保存的模型
 output_layer_index:要获取的那一个层的索引
 '''
 vector_funcrion=K.function([model.layers[0].input],[model.layers[output_layer_index].output])
 def inner(input_data):
  vector=vector_funcrion([input_data])[0]
  return vector
 
 return inner
 
# 现在仅仅测试一张图片
#选择一张图片,选择第一张
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
get_feature=get_output_function(model,6) # 该函数的返回值依然是一个函数哦,获取第6层输出
 
feature=get_feature(x) # 相当于调用 定义在里面的inner函数
print(feature)
'''运行结果为
[[-0.99986297 -0.9988328 -0.9273474 0.9101525 -0.9054705 -0.95798373
 0.9911243 0.78576803 0.99676156 0.39356467 -0.9724135 -0.74534595
 0.8527011 -0.9968267 -0.9420816 -0.32765102 -0.41667578 0.99942905
 0.92333794 0.7565034 -0.38416263 -0.994241 0.3781617 0.9621943
 0.9443946 0.9671554 -0.01000021 -0.9984282 -0.96650964 -0.9925837
 -0.48193568 -0.9749565 -0.79769516 0.9651831 0.9678705 -0.9444472
 0.9405674 0.97538495 -0.12366439 -0.9973782 0.05803521 0.9159217
 -0.9627071 0.99898154 0.99429387 -0.985909 0.5787794 -0.9789403
 -0.94316894 0.9999644 0.9156823 0.46314353 -0.01582102 0.98359734
 0.5586145 -0.97360635 0.99058044 0.9995654 -0.9800733 0.99942625
 0.8786553 -0.9992093 0.99916387 -0.5141877 0.99970615 0.28427476
 0.86589384 0.7649907 -0.9986046 0.9999706 -0.9892468 0.99854743
 -0.86872625 -0.9997323 0.98981035 -0.87805724 -0.9999373 -0.7842255
 -0.97456616 -0.97237325 -0.729563 0.98718935 0.9992022 -0.5294769 ]]
'''

但是上面的实现方法似乎不是很简单,还有更加简单的方法,思想来源与keras中,可以将整个模型model也当成是层layer来处理,实现如下面。

3.2 第二种实现方法

import keras
import numpy as np
from keras.datasets import mnist
from keras.models import Model
 
model=keras.models.load_model('./models/lenet5_weight.h5')
 
#选择一张图片,选择第一张
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
# 将模型作为一个层,输出第7层的输出
layer_model = Model(inputs=model.input, outputs=model.layers[6].output)
 
feature=layer_model.predict(x)
 
print(feature)
'''运行结果为:
[[-0.99986297 -0.9988328 -0.9273474 0.9101525 -0.9054705 -0.95798373
 0.9911243 0.78576803 0.99676156 0.39356467 -0.9724135 -0.74534595
 0.8527011 -0.9968267 -0.9420816 -0.32765102 -0.41667578 0.99942905
 0.92333794 0.7565034 -0.38416263 -0.994241 0.3781617 0.9621943
 0.9443946 0.9671554 -0.01000021 -0.9984282 -0.96650964 -0.9925837
 -0.48193568 -0.9749565 -0.79769516 0.9651831 0.9678705 -0.9444472
 0.9405674 0.97538495 -0.12366439 -0.9973782 0.05803521 0.9159217
 -0.9627071 0.99898154 0.99429387 -0.985909 0.5787794 -0.9789403
 -0.94316894 0.9999644 0.9156823 0.46314353 -0.01582102 0.98359734
 0.5586145 -0.97360635 0.99058044 0.9995654 -0.9800733 0.99942625
 0.8786553 -0.9992093 0.99916387 -0.5141877 0.99970615 0.28427476
 0.86589384 0.7649907 -0.9986046 0.9999706 -0.9892468 0.99854743
 -0.86872625 -0.9997323 0.98981035 -0.87805724 -0.9999373 -0.7842255
 -0.97456616 -0.97237325 -0.729563 0.98718935 0.9992022 -0.5294769 ]]
'''

可见和上面的结果是一样的,

总结:

由于keras的层与模型之间实际上的转化关系,所以提供了非常灵活的输出方法,推荐使用第二种方法获得某一个层的输出。总结为以下几个主要的步骤(四步走):

import keras
import numpy as np
from keras.datasets import mnist
from keras.models import Model
 
# 第一步:准备输入数据
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
# 第二步:加载已经训练的模型
model=keras.models.load_model('./models/lenet5_weight.h5')
 
# 第三步:将模型作为一个层,输出第7层的输出
layer_model = Model(inputs=model.input, outputs=model.layers[6].output)
 
# 第四步:调用新建的“曾模型”的predict方法,得到模型的输出
feature=layer_model.predict(x)
 
print(feature)

以上这篇keras小技巧——获取某一个网络层的输出方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3新特性函数注释Function Annotations用法分析
Jul 28 Python
python如何获取服务器硬件信息
May 11 Python
详解python里使用正则表达式的全匹配功能
Oct 19 Python
Python画柱状统计图操作示例【基于matplotlib库】
Jul 04 Python
在Pycharm中自动添加时间日期作者等信息的方法
Jan 16 Python
用python实现刷点击率的示例代码
Feb 21 Python
Django使用AJAX调用自己写的API接口的方法
Mar 06 Python
Python代码实现删除一个list里面重复元素的方法
Apr 02 Python
pyqt实现.ui文件批量转换为对应.py文件脚本
Jun 19 Python
python pandas移动窗口函数rolling的用法
Feb 29 Python
Python 批量读取文件中指定字符的实现
Mar 06 Python
pycharm如何设置官方中文(如何汉化)
Dec 29 Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
Python接口测试文件上传实例解析
May 22 #Python
计算Python Numpy向量之间的欧氏距离实例
May 22 #Python
python numpy矩阵信息说明,shape,size,dtype
May 22 #Python
You might like
php &amp;&amp; 逻辑与运算符使用说明
2010/03/04 PHP
使用bcompiler对PHP文件进行加密的代码
2010/08/29 PHP
PHP CURL模拟登录新浪微博抓取页面内容 基于EaglePHP框架开发
2012/01/16 PHP
ThinkPHP多语言支持与多模板支持概述
2014/08/22 PHP
基于ThinkPHP实现批量删除
2015/12/18 PHP
CodeIgniter框架常见用法工作总结
2017/03/16 PHP
PHP封装的page分页类定义与用法完整示例
2018/12/24 PHP
动态调用CSS文件的JS代码
2010/07/29 Javascript
Jquery 表格合并的问题分享
2011/09/17 Javascript
JavaScript实现的日期控件具体代码
2013/11/18 Javascript
jQuery选择器简明总结(含用法实例,一目了然)
2014/04/25 Javascript
PHP PDO操作总结
2014/11/17 Javascript
jQuery类选择器用法实例
2014/12/23 Javascript
jQuery实现产品对比功能附源码下载
2016/08/09 Javascript
Javascript实现代码折叠功能
2016/08/25 Javascript
js防阻塞加载的实现方法
2016/09/09 Javascript
解析NodeJs的调试方法
2016/12/11 NodeJs
通过jsonp获取json数据实现AJAX跨域请求
2017/01/22 Javascript
angularjs使用directive实现分页组件的示例
2017/02/07 Javascript
mui back 返回刷新页面的实例
2017/12/06 Javascript
使用ngrok+express解决本地环境中微信接口调试问题
2018/02/26 Javascript
mpvue全局引入sass文件的方法步骤
2019/03/06 Javascript
javascript实现手动点赞效果
2019/04/09 Javascript
Layui table field初始化加载时进行隐藏的方法
2019/09/19 Javascript
jquery实现轮播图特效
2020/04/12 jQuery
vue中使用腾讯云Im的示例
2020/10/23 Javascript
Python的Django框架中的select_related函数对QuerySet 查询的优化
2015/04/01 Python
用Python中的字典来处理索引统计的方法
2015/05/05 Python
django接入新浪微博OAuth的方法
2015/06/29 Python
Pandas 重塑(stack)和轴向旋转(pivot)的实现
2019/07/22 Python
numpy 声明空数组详解
2019/12/05 Python
python中matplotlib实现随鼠标滑动自动标注代码
2020/04/23 Python
全网最细 Python 格式化输出用法讲解(推荐)
2021/01/18 Python
新加坡时尚网上购物:Zalora新加坡
2016/07/26 全球购物
考博自荐信
2013/10/25 职场文书
详解Go语言运用广度优先搜索走迷宫
2021/06/23 Python