解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现partial改变方法默认参数
Aug 18 Python
python回调函数用法实例分析
May 09 Python
Python算法输出1-9数组形成的结果为100的所有运算式
Nov 03 Python
python中plot实现即时数据动态显示方法
Jun 22 Python
基于Python3.6+splinter实现自动抢火车票
Sep 25 Python
解决PyCharm不运行脚本,而是运行单元测试的问题
Jan 17 Python
Python使用matplotlib绘制Logistic曲线操作示例
Nov 28 Python
python数据预处理方式 :数据降维
Feb 24 Python
浅谈Python线程的同步互斥与死锁
Mar 22 Python
python中可以声明变量类型吗
Jun 18 Python
python 求两个向量的顺时针夹角操作
Mar 04 Python
python自动化之如何利用allure生成测试报告
May 02 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
Destoon旺旺无法正常显示,点击提示“会员名不存在”的解决办法
2014/06/21 PHP
php从文件夹随机读取文件的方法
2015/06/01 PHP
PHP实现的获取文件mimes类型工具类示例
2018/04/08 PHP
extjs 学习笔记(三) 最基本的grid
2009/10/15 Javascript
JavaScript 对象的属性和方法4种不同的类型
2010/03/19 Javascript
javascript 实用的文字链提示框效果
2010/06/30 Javascript
基于jQuery的让非HTML5浏览器支持placeholder属性的代码
2011/05/24 Javascript
浅谈javascript中的作用域
2012/04/07 Javascript
nodejs读取memcache示例分享
2014/01/02 NodeJs
键盘上一张下一张兼容IE/google/firefox等浏览器
2014/01/28 Javascript
node.js中的buffer.length方法使用说明
2014/12/14 Javascript
JavaScript实现非常简单实用的下拉菜单效果
2015/08/27 Javascript
JS条形码(一维码)插件JsBarcode用法详解【编码类型、参数、属性】
2017/04/19 Javascript
完美解决axios在ie下的兼容性问题
2018/03/05 Javascript
vue地区选择组件教程详解
2018/05/04 Javascript
JS用最简单的方法实现四舍五入
2019/08/27 Javascript
Vue中的循环及修改差值表达式的方法
2019/08/29 Javascript
[54:30]Liquid vs Newbee 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/16 DOTA
Python3实现抓取javascript动态生成的html网页功能示例
2017/08/22 Python
Python的SimpleHTTPServer模块用处及使用方法简介
2018/01/22 Python
python使用json序列化datetime类型实例解析
2018/02/11 Python
在python中pandas读文件,有中文字符的方法
2018/12/12 Python
Python 中Django验证码功能的实现代码
2019/06/20 Python
Python实现TCP探测目标服务路由轨迹的原理与方法详解
2019/09/04 Python
pytorch 获取tensor维度信息示例
2020/01/03 Python
python 回溯法模板详解
2020/02/26 Python
python爬取youtube视频的示例代码
2021/03/03 Python
CSS3 实现弹跳的小球动画
2020/10/26 HTML / CSS
马来西亚与新加坡长途巴士售票网站:BusOnlineTicket.com
2018/11/05 全球购物
介绍下Lucene建立索引的过程
2016/03/02 面试题
正规的求职信范文分享
2013/12/11 职场文书
迎新晚会主持词
2014/03/24 职场文书
歌唱比赛策划方案
2014/06/06 职场文书
励志演讲稿500字
2014/08/21 职场文书
活动经费申请报告
2015/05/15 职场文书
巾帼建功标兵先进事迹材料
2016/02/29 职场文书