解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
web.py在模板中输出美元符号的方法
Aug 26 Python
详解Python中的__getitem__方法与slice对象的切片操作
Jun 27 Python
Python3的urllib.parse常用函数小结(urlencode,quote,quote_plus,unquote,unquote_plus等)
Sep 18 Python
python3大文件解压和基本操作
Dec 15 Python
Python3.遍历某文件夹提取特定文件名的实例
Apr 26 Python
python实现归并排序算法
Nov 22 Python
python批量获取html内body内容的实例
Jan 02 Python
python为QT程序添加图标的方法详解
Mar 09 Python
keras中epoch,batch,loss,val_loss用法说明
Jul 02 Python
pycharm 添加解释器的方法步骤
Aug 31 Python
Python通过Schema实现数据验证方式
Nov 12 Python
Python time库的时间时钟处理
May 02 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
使用PHP socke 向指定页面提交数据
2008/07/23 PHP
使用php判断服务器是否支持Gzip压缩功能
2013/09/24 PHP
php使用gzip压缩传输js和css文件的方法
2015/07/29 PHP
laravel中Redis队列监听中断的分析
2020/09/14 PHP
Prototype Selector对象学习
2009/07/23 Javascript
jquery.fileEveryWhere.js 一个跨浏览器的file显示插件
2011/10/24 Javascript
Javascript 实现的数独解题算法网页实例
2013/10/15 Javascript
简介JavaScript中的getSeconds()方法的使用
2015/06/10 Javascript
JavaScript编程中的Promise使用大全
2015/07/28 Javascript
javascript断点调试心得分享
2016/04/23 Javascript
js选择器全面解析
2016/06/27 Javascript
JavaScript实现form表单的多文件上传
2020/03/27 Javascript
ES6学习教程之模板字符串详解
2017/10/09 Javascript
浅谈手写node可读流之流动模式
2018/06/01 Javascript
JS实现简单的星期格式转换功能示例
2018/07/23 Javascript
jquery实现动态添加附件功能
2018/10/23 jQuery
vue移动端屏幕适配详解
2019/04/30 Javascript
vue控制多行文字展开收起的实现示例
2019/10/11 Javascript
python实现从字典中删除元素的方法
2015/05/04 Python
Python采用Django制作简易的知乎日报API
2016/08/03 Python
tensorflow实现简单的卷积神经网络
2018/05/24 Python
pandas dataframe的合并实现(append, merge, concat)
2019/06/24 Python
django多个APP的urls设置方法(views重复问题解决)
2019/07/19 Python
Python字符串格式化输出代码实例
2019/11/22 Python
在python tkinter界面中添加按钮的实例
2020/03/04 Python
Python基于pandas绘制散点图矩阵代码实例
2020/06/04 Python
Python ConfigParser模块的使用示例
2020/10/12 Python
10张动图学会python循环与递归问题
2021/02/06 Python
CSS3属性使网站设计增强同时不消弱可用性
2009/08/29 HTML / CSS
利用CSS3参考手册和CSS3代码生成工具加速来学习网页制
2012/07/11 HTML / CSS
银行优秀员工事迹
2014/02/06 职场文书
环保倡议书
2014/04/14 职场文书
2014财务人员自我评价范文
2014/09/21 职场文书
2014年教研组工作总结
2014/11/26 职场文书
2016秋季运动会前导词
2015/11/25 职场文书
年中了,该如何写好个人述职报告?
2019/07/02 职场文书