解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python列表推导式的使用方法
Nov 21 Python
Python中实现结构相似的函数调用方法
Mar 10 Python
Python常用小技巧总结
Jun 01 Python
Windows环境下python环境安装使用图文教程
Mar 13 Python
详解python3中用HTMLTestRunner.py报ImportError: No module named 'StringIO'如何解决
Aug 27 Python
Python制作词云图代码实例
Sep 09 Python
在python中创建指定大小的多维数组方式
Nov 28 Python
django model 条件过滤 queryset.filter(**condtions)用法详解
May 20 Python
python安装及变量名介绍详解
Dec 12 Python
python如何获取网络数据
Apr 11 Python
Matlab求解数组中的最大值及它所在的具体位置
Apr 16 Python
Python使用华为API为图像设置多个锚点标签
Apr 12 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
PHP下通过exec获得计算机的唯一标识[CPU,网卡 MAC地址]
2011/06/09 PHP
PHP crc32()函数讲解
2019/02/14 PHP
jQuery 判断页面元素是否存在的代码
2009/08/14 Javascript
javascript 获取元素位置的快速方法 getBoundingClientRect()
2009/11/26 Javascript
学习并汇集javascript匿名函数
2010/11/25 Javascript
让复选框只能选择一项的方法
2013/10/08 Javascript
JavaScript中实现继承的三种方式和实例
2015/01/29 Javascript
IE中鼠标经过option触发mouseout的解决方法
2015/01/29 Javascript
如何利用Promises编写更优雅的JavaScript代码
2016/05/17 Javascript
基于JavaScript实现无缝滚动效果
2017/07/21 Javascript
Vue 2.5 Level E 发布了: 新功能特性一览
2017/10/24 Javascript
Node.js进阶之核心模块https入门
2018/05/23 Javascript
浅谈vue首屏加载优化
2018/06/28 Javascript
详解es6超好用的语法糖Decorator
2018/08/01 Javascript
vue通过滚动行为实现从列表到详情,返回列表原位置的方法
2018/08/31 Javascript
Node.js 使用axios读写influxDB的方法示例
2018/10/26 Javascript
vue点击标签切换选中及互相排斥操作
2020/07/17 Javascript
[03:37]2016完美“圣”典 风云人物:Mikasa专访
2016/12/07 DOTA
跟老齐学Python之Python文档
2014/10/10 Python
在Python中使用__slots__方法的详细教程
2015/04/28 Python
python实现根据ip地址反向查找主机名称的方法
2015/04/29 Python
python使用matplotlib绘制柱状图教程
2017/02/08 Python
关于Python 3中print函数的换行详解
2017/08/08 Python
python通过socket实现多个连接并实现ssh功能详解
2017/11/08 Python
分享8个非常流行的 Python 可视化工具包
2019/06/05 Python
用django设置session过期时间的方法解析
2019/08/05 Python
利用python在大量数据文件下删除某一行的例子
2019/08/21 Python
python excel转换csv代码实例
2019/08/26 Python
英国第一的滑雪服装和装备零售商:Snow+Rock
2020/02/01 全球购物
COSETTE官网:奢华,每天
2020/03/22 全球购物
大学生求职信范文
2014/05/24 职场文书
毕业生银行实习自我鉴定
2014/10/14 职场文书
雨中的树观后感
2015/06/03 职场文书
小学体育跳绳课教学反思
2016/02/16 职场文书
CSS3 实现NES游戏机的示例代码
2021/04/21 HTML / CSS
html+css实现文字折叠特效实例
2021/06/02 HTML / CSS