解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中操作字符串之replace()方法的使用
May 19 Python
pycharm下查看python的变量类型和变量内容的方法
Jun 26 Python
通过python将大量文件按修改时间分类的方法
Oct 17 Python
python 构造三维全零数组的方法
Nov 12 Python
Django ManyToManyField 跨越中间表查询的方法
Dec 18 Python
python消费kafka数据批量插入到es的方法
Dec 27 Python
Python利用scapy实现ARP欺骗的方法
Jul 23 Python
Python object类中的特殊方法代码讲解
Mar 06 Python
详解Python中namedtuple的使用
Apr 27 Python
DataFrame 数据合并实现(merge,join,concat)
Jun 14 Python
详解python中的lambda与sorted函数
Sep 04 Python
python使用pywinauto驱动微信客户端实现公众号爬虫
May 19 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
Zend Framework入门知识点小结
2016/03/19 PHP
Thinkphp连表查询及数据导出方法示例
2016/10/15 PHP
记录一次排查PHP脚本执行卡住的问题
2016/12/27 PHP
PDO::query讲解
2019/01/29 PHP
ExtJs中简单的登录界面制作方法
2010/08/19 Javascript
深入理解JavaScript系列(19):求值策略(Evaluation strategy)详解
2015/03/05 Javascript
Node.js 异步编程之 Callback介绍(一)
2015/03/30 Javascript
js实现仿京东2级菜单效果(带延时功能)
2015/08/27 Javascript
Javascript中replace()小结
2015/09/30 Javascript
jQuery zTree加载树形菜单功能
2016/02/25 Javascript
jQuery EasyUI Pagination实现分页的常用方法
2016/05/21 Javascript
深入理解vue-loader如何使用
2017/06/06 Javascript
在Vue项目中使用jsencrypt.js对数据进行加密传输的方法
2019/04/17 Javascript
jQuery 动画与停止动画效果实例详解
2020/05/19 jQuery
echarts浮动显示单位的实现方法示例
2020/12/04 Javascript
详解Vue3.0 + TypeScript + Vite初体验
2021/02/22 Vue.js
教你安装python Django(图文)
2013/11/04 Python
酷! 程序员用Python带你玩转冲顶大会
2018/01/17 Python
Python爬虫框架Scrapy常用命令总结
2018/07/26 Python
Python中turtle库的使用实例
2019/09/09 Python
python网络编程之五子棋游戏
2020/05/14 Python
Python手动或自动协程操作方法解析
2020/06/22 Python
python中tkinter窗口位置\坐标\大小等实现示例
2020/07/09 Python
英国著名的药妆网站:Escentual
2016/07/29 全球购物
Shein英国:女性时尚网上商店
2019/04/10 全球购物
捷克街头、运动和滑板一站式商店:BoardStar.cz
2019/10/06 全球购物
法国在线药房:Shop Pharmacie
2019/11/26 全球购物
优秀的个人求职信范文
2014/05/09 职场文书
生物工程专业求职信
2014/09/03 职场文书
营销与策划实训报告
2014/11/05 职场文书
2014年保育员工作总结
2014/12/02 职场文书
大学生毕业评语
2014/12/31 职场文书
交通安全主题班会
2015/08/12 职场文书
环保主题班会教案
2015/08/13 职场文书
掌握一个领域知识,高效学习必备方法
2019/08/08 职场文书
Redis sentinel哨兵集群的实现步骤
2022/07/15 Redis