解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中使用正则表达式的连接符示例代码
Oct 10 Python
Python进阶学习之特殊方法实例详析
Dec 01 Python
Python实现的json文件读取及中文乱码显示问题解决方法
Aug 06 Python
使用Selenium破解新浪微博的四宫格验证码
Oct 19 Python
python内置数据类型之列表操作
Nov 12 Python
使用pycharm设置控制台不换行的操作方法
Jan 19 Python
给你一面国旗 教你用python画中国国旗
Sep 24 Python
wxPython实现分隔窗口
Nov 19 Python
python判断无向图环是否存在的示例
Nov 22 Python
使用Python脚本从文件读取数据代码实例
Jan 19 Python
pycharm 实现光标快速移动到括号外或行尾的操作
Feb 05 Python
python绘制云雨图raincloud plot
Aug 05 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
PHP5函数小全(分享)
2013/06/06 PHP
ThinkPHP模板自定义标签使用方法
2014/06/26 PHP
php判断对象是派生自哪个类的方法
2015/06/20 PHP
php微信开发之批量生成带参数的二维码
2016/06/26 PHP
cakephp常见知识点汇总
2017/02/24 PHP
PHP实现向关联数组指定的Key之前插入元素的方法
2017/06/06 PHP
thinkphp5框架实现的自定义扩展类操作示例
2019/05/16 PHP
关于laravel框架中的常用目录路径函数
2019/10/23 PHP
分析 JavaScript 中令人困惑的变量赋值
2007/08/13 Javascript
javascript 操作cookies及正确使用cookies的属性
2009/10/15 Javascript
JQuery 选择器、过滤器介绍
2011/02/14 Javascript
用JQuery在网页中实现分隔条功能的代码
2012/08/09 Javascript
js带前后翻页的图片切换效果代码分享
2015/09/08 Javascript
14 个折磨人的 JavaScript 面试题
2016/08/08 Javascript
基于jQuery实现歌词滚动版音乐播放器的代码
2016/09/17 Javascript
vue2.0+koa2+mongodb实现注册登录
2018/04/10 Javascript
vscode下的vue文件格式化问题
2018/11/28 Javascript
jquery获取file表单选择文件的路径、名字、大小、类型
2019/01/18 jQuery
JavaScript 作用域scope简单汇总
2019/10/23 Javascript
JS控制GIF图片的停止与显示
2019/10/24 Javascript
JS如何实现动态添加的元素绑定事件
2019/11/12 Javascript
js canvas实现五子棋小游戏
2021/01/22 Javascript
通过pycharm使用git的步骤(图文详解)
2019/06/13 Python
python-django中的APPEND_SLASH实现方法
2019/06/21 Python
解决Django删除migrations文件夹中的文件后出现的异常问题
2019/08/31 Python
Win系统PyQt5安装和使用教程
2019/12/25 Python
Python函数基本使用原理详解
2020/03/19 Python
html5启动原生APP总结
2020/07/03 HTML / CSS
使用HTML5做的导航条详细步骤
2020/10/19 HTML / CSS
新闻记者实习自我鉴定
2013/09/19 职场文书
电气技术员岗位职责
2013/11/19 职场文书
大学自主招生推荐信
2014/05/10 职场文书
师德师风的心得体会
2014/09/02 职场文书
消防安全培训工作总结
2015/10/23 职场文书
Nginx服务器如何设置url链接
2021/03/31 Servers
Python趣味实战之手把手教你实现举牌小人生成器
2021/06/07 Python