解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据结构之二叉树的统计与转换实例
Apr 29 Python
python编程实现随机生成多个椭圆实例代码
Jan 03 Python
python实现微信跳一跳辅助工具步骤详解
Jan 04 Python
Python subprocess模块详细解读
Jan 29 Python
对Python 3.5拼接列表的新语法详解
Nov 08 Python
python获取服务器响应cookie的实例
Dec 28 Python
wxpython布局的实现方法
Nov 01 Python
Pandas-Cookbook 时间戳处理方式
Dec 07 Python
Python urlopen()和urlretrieve()用法解析
Jan 07 Python
Tensorflow实现在训练好的模型上进行测试
Jan 20 Python
Python 格式化输出_String Formatting_控制小数点位数的实例详解
Feb 04 Python
python为Django项目上的每个应用程序创建不同的自定义404页面(最佳答案)
Mar 09 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
融入意大利的咖啡文化
2021/03/03 咖啡文化
《PHP编程最快明白》第八讲:php启发和小结
2010/11/01 PHP
smarty 缓存控制前的页面静态化原理
2013/03/15 PHP
PHP实例分享判断客户端是否使用代理服务器及其匿名级别
2014/06/04 PHP
ThinkPHP、ZF2、Yaf、Laravel框架路由大比拼
2015/03/25 PHP
php curl抓取网页的介绍和推广及使用CURL抓取淘宝页面集成方法
2015/11/30 PHP
FormValidate 表单验证功能代码更新并提供下载
2008/08/23 Javascript
jQuery Ajax文件上传(php)
2009/06/16 Javascript
jquery实现固定顶部导航效果(仿蘑菇街)
2013/03/21 Javascript
appendChild() 或 insertBefore()使用与区别介绍
2013/10/11 Javascript
addEventListener 的用法示例介绍
2014/05/07 Javascript
JavaScript中的函数重载深入理解
2014/08/04 Javascript
javascript 应用小技巧方法汇总
2015/07/05 Javascript
浅析JavaScript访问对象属性和方法及区别
2015/11/16 Javascript
JavaScript递归操作实例浅析
2016/10/31 Javascript
js实现上下左右弹框划出效果
2017/03/08 Javascript
React 高阶组件入门介绍
2018/01/11 Javascript
react router4+redux实现路由权限控制的方法
2018/05/03 Javascript
解决淘宝cnpm 安装后cnpm不是内部或外部命令的问题
2018/05/17 Javascript
浅谈Node.js 中间件模式
2018/06/12 Javascript
js动态设置select下拉菜单的默认选中项实例
2018/08/21 Javascript
详解json串反转义(消除反斜杠)
2019/08/12 Javascript
Angular 中使用 FineReport不显示报表直接打印预览
2019/08/21 Javascript
Vue中使用JsonView来展示Json树的实例代码
2020/11/16 Javascript
Vue3+elementui plus创建项目的方法
2020/12/01 Vue.js
Vue如何实现变量表达式选择器
2021/02/18 Vue.js
[48:20]OpTic vs Serenity 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
浅谈Python中的闭包
2015/07/08 Python
python中使用序列的方法
2015/08/03 Python
TensorFlow平台下Python实现神经网络
2018/03/10 Python
django传值给模板, 再用JS接收并进行操作的实例
2018/05/28 Python
Python实现从N个数中找到最大的K个数
2020/04/02 Python
基于CSS3实现的黑色个性导航菜单效果
2015/09/14 HTML / CSS
小班下学期评语
2014/05/04 职场文书
最新农村养殖致富:资金投入较低的创业项目有哪些?
2019/09/26 职场文书
python基础之匿名函数详解
2021/04/21 Python