解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
分享15个最受欢迎的Python开源框架
Jul 13 Python
python中遍历文件的3个方法
Sep 02 Python
Python爬取国外天气预报网站的方法
Jul 10 Python
浅谈python抛出异常、自定义异常, 传递异常
Jun 20 Python
Python 制作糗事百科爬虫实例
Sep 22 Python
Python 多核并行计算的示例代码
Nov 07 Python
Python切片操作实例分析
Mar 16 Python
django2+uwsgi+nginx上线部署到服务器Ubuntu16.04
Jun 26 Python
Python list列表中删除多个重复元素操作示例
Feb 27 Python
Python Multiprocessing多进程 使用tqdm显示进度条的实现
Aug 13 Python
python 字典item与iteritems的区别详解
Apr 25 Python
Python Process创建进程的2种方法详解
Jan 25 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
ADODB类使用
2006/11/25 PHP
PHP实现Socket服务器的代码
2008/04/03 PHP
php动态生成版权所有信息的方法
2015/03/24 PHP
Yii框架参数化查询中IN查询只能查询一个的解决方法
2017/05/20 PHP
nodejs win7下安装方法
2012/05/24 NodeJs
javascript分页代码(当前页码居中)
2012/09/20 Javascript
jquery创建表格(自动增加表格)代码分享
2013/12/25 Javascript
node.js中的fs.appendFileSync方法使用说明
2014/12/17 Javascript
jQuery中element选择器用法实例
2014/12/29 Javascript
jquery制作图片时钟特效
2020/03/30 Javascript
非常棒的jQuery图片轮播效果
2016/04/17 Javascript
浅谈javascript中的加减时间
2016/07/12 Javascript
关于input全选反选恶心的异常情况
2016/07/24 Javascript
jquery实现图片列表鼠标移入微动
2016/12/01 Javascript
AngularJS中update两次出现$promise属性无法识别的解决方法
2017/01/05 Javascript
Java与JavaScript中判断两字符串是否相等的区别
2017/03/13 Javascript
vue中本地静态图片路径写法
2018/03/06 Javascript
基于PHP pthreads实现多线程代码实例
2020/06/24 Javascript
[38:51]2014 DOTA2国际邀请赛中国区预选赛 Orenda VS LGD-CDEC
2014/05/22 DOTA
Python2.5/2.6实用教程 入门基础篇
2009/11/29 Python
详解python 发送邮件实例代码
2016/12/22 Python
Python实现读取json文件到excel表
2017/11/18 Python
Python命名空间的本质和加载顺序
2018/12/17 Python
Python读取stdin方法实例
2019/05/24 Python
40个你可能不知道的Python技巧附代码
2020/01/29 Python
详解Python模块化编程与装饰器
2021/01/16 Python
意大利领先的奢侈品在线时装零售商:MCLABELS
2020/10/13 全球购物
英国最受欢迎的母婴精品品牌:JoJo Maman BéBé
2021/02/17 全球购物
介绍一下常见的木马种类
2014/11/15 面试题
自我鉴定四大框架
2014/01/17 职场文书
小学教师节活动方案
2014/01/31 职场文书
保密承诺书
2014/03/27 职场文书
质量保证书怎么写
2015/02/27 职场文书
golang http使用踩过的坑与填坑指南
2021/04/27 Golang
Django中的JWT身份验证的实现
2021/05/07 Python
JS 基本概念详细介绍
2021/10/16 Javascript