解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python利用hook技术破解https的实例代码
Mar 25 Python
python连接池实现示例程序
Nov 26 Python
Python functools模块学习总结
May 09 Python
python字符串的常用操作方法小结
May 21 Python
Python的Tornado框架实现异步非阻塞访问数据库的示例
Jun 30 Python
Python Tkinter模块实现时钟功能应用示例
Jul 23 Python
python+numpy+matplotalib实现梯度下降法
Aug 31 Python
django 中QuerySet特性功能详解
Jul 25 Python
使用python实现男神女神颜值打分系统(推荐)
Oct 31 Python
Python实现在Windows平台修改文件属性
Mar 05 Python
Python退出时强制运行一段代码的实现方法
Apr 29 Python
python中的3种定义类方法
Nov 27 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
2020年4月新番动漫目录 官方宣布4月播出的作品一览
2020/03/08 日漫
PHP安装攻略:常见问题解答(二)
2006/10/09 PHP
php 购物车实例(申精)
2009/05/11 PHP
php过滤html中的其他网站链接的方法(域名白名单功能)
2014/04/24 PHP
php通过ksort()函数给关联数组按照键排序的方法
2015/03/18 PHP
php文件扩展名判断及获取文件扩展名的N种方法
2015/09/12 PHP
求帮忙修改个php curl模拟post请求内容后并下载文件的解决思路
2015/09/20 PHP
PHP采用超长(超大)数字运算防止数字以科学计数法显示的方法
2016/04/01 PHP
PHP入门教程之表单与验证实例详解
2016/09/11 PHP
uploadify 3.0 详细使用说明
2012/06/18 Javascript
JQuery中基础过滤选择器用法实例分析
2015/05/18 Javascript
快速学习jQuery插件 Cookie插件使用方法
2015/12/01 Javascript
Javascript技术难点之apply,call与this之间的衔接
2015/12/04 Javascript
浅谈javascript中的加减时间
2016/07/12 Javascript
JS获取鼠标相对位置的方法
2016/09/20 Javascript
详解Node.js模板引擎Jade入门
2018/01/19 Javascript
python的多重继承的理解
2017/08/06 Python
基于python socketserver框架全面解析
2017/09/21 Python
scrapy spider的几种爬取方式实例代码
2018/01/25 Python
Python3多进程 multiprocessing 模块实例详解
2018/06/11 Python
Python  unittest单元测试框架的使用
2018/09/08 Python
Python实现实时数据采集新型冠状病毒数据实例
2020/02/04 Python
python3发送request请求及查看返回结果实例
2020/04/30 Python
Python selenium 加载并保存QQ群成员,去除其群主、管理员信息的示例代码
2020/05/28 Python
python json.dumps() json.dump()的区别详解
2020/07/14 Python
CSS3 flex布局之快速实现BorderLayout布局
2015/12/03 HTML / CSS
CSS3 filter(滤镜)实现网页灰色或者黑色模式的示例代码
2021/02/24 HTML / CSS
Book Depository亚太地区:一家领先的国际图书零售商
2019/05/05 全球购物
巴西一家专门从事家居和装饰的连锁店:Camicado
2019/08/14 全球购物
C++:局部变量能否和全局变量重名
2014/03/03 面试题
大学生实习思想汇报
2014/01/12 职场文书
大班下学期个人总结
2015/02/13 职场文书
大一学生个人总结
2015/02/15 职场文书
大学生学习十八届五中全会精神心得体会
2016/01/05 职场文书
5人制售《绝地求生》游戏外挂获利500多万元 被判刑
2022/03/31 其他游戏
Python面试不修改数组找出重复的数字
2022/05/20 Python