解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python抓取京东商城手机列表url实例代码
Dec 18 Python
Python中不同进制的语法及转换方法分析
Jul 27 Python
python和shell获取文本内容的方法
Jun 05 Python
python构建基础的爬虫教学
Dec 23 Python
Python enumerate函数功能与用法示例
Mar 01 Python
PyQtGraph在pyqt中的应用及安装过程
Aug 04 Python
python使用Matplotlib改变坐标轴的默认位置
Oct 18 Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 Python
基于nexus3配置Python仓库过程详解
Jun 15 Python
python打开音乐文件的实例方法
Jul 21 Python
神经网络训练采用gpu设置的方式
Mar 03 Python
用Python将库打包发布到pypi
Apr 13 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
星际争霸 Starcraft 编年史
2020/03/14 星际争霸
phplock(php进程锁) v1.0 beta1
2009/11/24 PHP
thinkphp文件处理类Dir.class.php的用法分析
2014/12/08 PHP
Thinkphp连表查询及数据导出方法示例
2016/10/15 PHP
PHP实现类似于C语言的文件读取及解析功能
2017/09/01 PHP
Thinkphp5+plupload实现的图片上传功能示例【支持实时预览】
2019/05/08 PHP
csdn 博客中实现运行代码功能实现
2009/08/29 Javascript
jQuery提示插件alertify使用指南
2015/04/21 Javascript
Bootstrap+jfinal实现省市级联下拉菜单
2016/05/30 Javascript
基于cssSlidy.js插件实现响应式手机图片轮播效果
2016/08/30 Javascript
jquery实现文本框的禁用和启用
2016/12/07 Javascript
微信小程序之picker日期和时间选择器
2017/02/09 Javascript
重新理解JavaScript的六种继承方式
2017/03/24 Javascript
详解微信第三方小程序代开发
2017/06/23 Javascript
vuejs使用$emit和$on进行组件之间的传值的示例
2017/10/04 Javascript
AngularJS 教程及实例代码
2017/10/23 Javascript
Dropify.js图片宽高自适应的方法
2017/11/27 Javascript
jQuery实现鼠标点击处心形漂浮的炫酷效果示例
2018/04/12 jQuery
p5.js实现故宫橘猫赏秋图动画
2019/10/23 Javascript
[05:31]DOTA2英雄梦之声_第04期_光之守卫
2014/06/23 DOTA
[02:52]2014DOTA2西雅图国际邀请赛 CIS战队巡礼
2014/07/07 DOTA
Python3处理文件中每个词的方法
2015/05/22 Python
Python解析并读取PDF文件内容的方法
2018/05/08 Python
python用线性回归预测股票价格的实现代码
2019/09/04 Python
python 实现从高分辨图像上抠取图像块
2020/01/02 Python
Django ForeignKey与数据库的FOREIGN KEY约束详解
2020/05/20 Python
谈一谈HTML5本地存储技术
2016/03/02 HTML / CSS
松下电器美国官方商店:Panasonic美国
2016/10/14 全球购物
香港连卡佛百货官网:Lane Crawford
2019/09/04 全球购物
思想汇报范文
2013/11/04 职场文书
博士生入学考试推荐信
2013/11/17 职场文书
会计实习生自我鉴定
2013/12/12 职场文书
商务英语专业求职信
2014/06/26 职场文书
最美孝心少年事迹材料
2014/08/15 职场文书
干部四风问题整改措施思想汇报
2014/10/13 职场文书
出纳年终工作总结2014
2014/12/05 职场文书