解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python备份文件的脚本
Aug 11 Python
举例详解Python中threading模块的几个常用方法
Jun 18 Python
在Django框架中编写Contact表单的教程
Jul 17 Python
python编程之requests在网络请求中添加cookies参数方法详解
Oct 25 Python
python读取LMDB中图像的方法
Jul 02 Python
浅谈pytorch和Numpy的区别以及相互转换方法
Jul 26 Python
Tensorflow 实现修改张量特定元素的值方法
Jul 30 Python
python输出决策树图形的例子
Aug 09 Python
Tensorflow限制CPU个数实例
Feb 06 Python
详解python如何引用包package
Jun 07 Python
python 实现一个简单的线性回归案例
Dec 17 Python
Pycharm制作搞怪弹窗的实现代码
Feb 19 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
PHP简洁函数小结
2011/08/12 PHP
linux实现php定时执行cron任务详解
2013/12/24 PHP
如何使用Gitblog和Markdown建自己的博客
2015/07/31 PHP
理解 JavaScript 预解析
2009/10/25 Javascript
写JQuery插件的基本知识
2013/11/25 Javascript
jquery通过visible来判断标签是否显示或隐藏
2014/05/08 Javascript
通过js为元素添加多项样式,浏览器全兼容写法
2014/08/30 Javascript
JavaScript删除指定子元素代码实例
2015/01/13 Javascript
JavaScript:Date类型全面解析
2016/05/19 Javascript
jQuery实现带延时功能的水平多级菜单效果【附demo源码下载】
2016/09/21 Javascript
Node.js中如何合并两个复杂对象详解
2016/12/31 Javascript
利用百度地图API获取当前位置信息的实例
2017/11/06 Javascript
vue单页应用加百度统计代码(亲测有效)
2018/01/31 Javascript
React学习笔记之高阶组件应用
2018/06/02 Javascript
使用nvm和nrm优化node.js工作流的方法
2019/01/17 Javascript
jQuery实现input输入框获取焦点与失去焦点时提示的消失与显示功能示例
2019/05/27 jQuery
微信小程序canvas分享海报功能
2019/10/31 Javascript
JavaScrip如果基于url实现图片下载
2020/07/03 Javascript
Python操作MongoDB详解及实例
2017/05/18 Python
python使用sklearn实现决策树的方法示例
2019/09/12 Python
线程安全及Python中的GIL原理分析
2019/10/29 Python
python用Configobj模块读取配置文件
2020/09/26 Python
python与idea的集成的实现
2020/11/20 Python
HTML5中的Scoped属性使用实例
2014/04/23 HTML / CSS
html5关于外链嵌入页面通信问题(postMessage解决跨域通信)
2020/07/20 HTML / CSS
英国皇家造币厂:The Royal Mint
2018/10/05 全球购物
介绍一下游标
2012/01/10 面试题
会计系毕业个人自荐信格式
2013/09/23 职场文书
幼儿园小班植树节活动方案
2014/03/04 职场文书
交通事故委托书范本(2篇)
2014/09/21 职场文书
向国旗敬礼活动小结
2014/09/27 职场文书
基层工作经历证明
2015/06/19 职场文书
2015中秋节晚会开场白
2015/07/30 职场文书
客户答谢会致辞
2015/07/30 职场文书
Java用自带的Image IO给图片添加水印
2021/06/15 Java/Android
世界十大儿童漫画书排名,法国国宝漫画排第五,第二是轰动日本连环
2022/03/18 欧美动漫