解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3使用PyQt5制作简单的画板/手写板实例
Oct 19 Python
python使用 HTMLTestRunner.py生成测试报告
Oct 20 Python
python中文分词教程之前向最大正向匹配算法详解
Nov 02 Python
Python实现的堆排序算法原理与用法实例分析
Nov 22 Python
Python使用folium excel绘制point
Jan 03 Python
Python2.7实现多进程下开发多线程示例
May 31 Python
pygame实现打字游戏
Feb 19 Python
简单了解为什么python函数后有多个括号
Dec 19 Python
python中使用paramiko模块并实现远程连接服务器执行上传下载功能
Feb 29 Python
使用pymysql查询数据库,把结果保存为列表并获取指定元素下标实例
May 15 Python
Python 如何定义匿名或内联函数
Aug 01 Python
浅谈python数据类型及其操作
May 25 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
自动生成文章摘要的代码[PHP 版本]
2007/03/20 PHP
php根据操作系统转换文件名大小写的方法
2014/02/24 PHP
浅析iis7.5安装配置php环境
2015/05/10 PHP
php中bind_param()函数用法分析
2017/03/28 PHP
js 绑定带参数的事件以及手动触发事件
2010/04/27 Javascript
将两个div左右并列显示并实现点击标题切换内容
2013/10/22 Javascript
js子页面获取父页面数据示例
2014/05/15 Javascript
jquery插件qrcode在线生成二维码
2015/04/26 Javascript
js实现一个链接打开两个链接地址的方法
2015/05/12 Javascript
JavaScript获取并更改input标签name属性的方法
2015/07/02 Javascript
jquery选择器简述
2015/08/31 Javascript
jQuery+ajax实现文章点赞功能的方法
2015/12/31 Javascript
javascript实现dom元素可拖动
2016/03/21 Javascript
详解为Angular.js内置$http服务添加拦截器的方法
2016/12/20 Javascript
vuejs开发组件分享之H5图片上传、压缩及拍照旋转的问题处理
2017/03/06 Javascript
JS对象的深度克隆方法示例
2017/03/16 Javascript
JavaScript中最常用的10种代码简写技巧总结
2017/06/28 Javascript
手把手教你使用vue-cli脚手架(图文解析)
2017/11/08 Javascript
vue微信分享 vue实现当前页面分享其他页面
2017/12/02 Javascript
基于vue如何发布一个npm包的方法步骤
2019/05/15 Javascript
简单了解Javscript中兄弟ifream的方法调用
2019/06/17 Javascript
快速解决layui弹窗按enter键不停弹窗的问题
2019/09/18 Javascript
机器学习10大经典算法详解
2017/12/07 Python
python数字图像处理之骨架提取与分水岭算法
2018/04/27 Python
python write无法写入文件的解决方法
2019/01/23 Python
Python @property使用方法解析
2019/09/17 Python
python路径的写法及目录的获取方式
2019/12/26 Python
Keras - GPU ID 和显存占用设定步骤
2020/06/22 Python
Canvas图片分割效果的实现
2019/07/29 HTML / CSS
DJI大疆德国官方商城:大疆无人机
2018/09/01 全球购物
犹他州最古老的体育用品公司:Al’s
2020/12/18 全球购物
应届毕业生通用的自荐书范文
2014/02/07 职场文书
党校学习自我鉴定
2014/02/24 职场文书
整脏治乱工作简报
2015/07/21 职场文书
《梅花魂》教学反思
2016/02/18 职场文书
总结一些Java常用的加密算法
2021/06/11 Java/Android