解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Linux上安装Python的Flask框架和创建第一个app实例的教程
Mar 30 Python
Python的Django框架中settings文件的部署建议
May 30 Python
Python黑魔法@property装饰器的使用技巧解析
Jun 16 Python
Python自定义函数实现求两个数最大公约数、最小公倍数示例
May 21 Python
Python PyAutoGUI模块控制鼠标和键盘实现自动化任务详解
Sep 04 Python
Linux下通过python获取本机ip方法示例
Sep 06 Python
Python搭建代理IP池实现存储IP的方法
Oct 27 Python
QT5 Designer 打不开的问题及解决方法
Aug 20 Python
python 实现围棋游戏(纯tkinter gui)
Nov 13 Python
详解Python中openpyxl模块基本用法
Feb 23 Python
python实现腾讯滑块验证码识别
Apr 27 Python
python数据库批量插入数据的实现(executemany的使用)
Apr 30 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
PHP简单实现断点续传下载的方法
2015/09/25 PHP
yii实现model添加默认值的方法(2种方法)
2016/01/06 PHP
php实现根据身份证获取精准年龄
2020/02/26 PHP
基于Jquery的淡入淡出的特效基础练习
2010/12/13 Javascript
jQuery EasyUI API 中文文档 - Draggable 可拖拽
2011/09/29 Javascript
php图像生成函数之间的区别分析
2012/12/06 Javascript
JS实现3D图片旋转展示效果代码
2015/09/22 Javascript
JavaScript代码因逗号不规范导致IE不兼容的问题
2016/02/25 Javascript
关于JSON与JSONP简单总结
2016/08/16 Javascript
angular ngClick阻止冒泡使用默认行为的方法
2016/11/03 Javascript
使用Ajax生成的Excel文件并下载的实例
2016/11/21 Javascript
Node.js学习之地址解析模块URL的使用详解
2017/09/28 Javascript
判断jQuery是否加载完成,没完成继续判断的解决方法
2017/12/06 jQuery
Vue中的异步组件函数实现代码
2018/07/20 Javascript
node.js 模块和其下载资源的镜像设置的方法
2018/09/06 Javascript
如何用RxJS实现Redux Form
2018/12/29 Javascript
详解vue-cli3多环境打包配置
2019/03/28 Javascript
解决Layui数据表格的宽高问题
2019/09/28 Javascript
jQuery操作元素的内容和样式完整实例分析
2020/01/10 jQuery
Vue简单封装axios之解决post请求后端接收不到参数问题
2020/02/16 Javascript
微信小程序实现列表的横向滑动方式
2020/07/15 Javascript
[00:34]TI7不朽珍藏III——纯金地穴编织者饰品展示
2017/07/15 DOTA
[02:55]2018DOTA2国际邀请赛勇士令状不朽珍藏Ⅲ饰品一览
2018/08/01 DOTA
[02:28]PWL开团时刻DAY3——Ink Ice与DeMonsTer之间的勾心斗角
2020/11/03 DOTA
Python中类的继承代码实例
2014/10/28 Python
在Python的web框架中配置app的教程
2015/04/30 Python
python实现发送邮件功能代码
2017/12/14 Python
解决python大批量读写.doc文件的问题
2018/05/08 Python
python pandas实现excel转为html格式的方法
2018/10/23 Python
python项目对接钉钉SDK的实现
2019/07/15 Python
First Aid Beauty官网:FAB急救面霜
2018/05/24 全球购物
四风自我剖析材料思想汇报
2014/10/01 职场文书
营销经理工作检讨书
2014/11/03 职场文书
2014年财政工作总结
2014/12/10 职场文书
离婚答辩状怎么写
2015/05/22 职场文书
SQL试题 使用窗口函数选出连续3天登录的用户
2022/04/24 Oracle