解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python远程桌面协议RDPY安装使用介绍
Apr 15 Python
Python中每次处理一个字符的5种方法
May 21 Python
python实现查找两个字符串中相同字符并输出的方法
Jul 11 Python
python爬虫 正则表达式使用技巧及爬取个人博客的实例讲解
Oct 20 Python
解决python 未发现数据源名称并且未指定默认驱动程序的问题
Dec 07 Python
Python使用itchat模块实现简单的微信控制电脑功能示例
Aug 26 Python
opencv3/C++图像像素操作详解
Dec 10 Python
Python StringIO如何在内存中读写str
Jan 07 Python
Python的PIL库中getpixel方法的使用
Apr 09 Python
python 如何快速复制序列
Sep 07 Python
python Gabor滤波器讲解
Oct 26 Python
总结Python连接CS2000的详细步骤
Jun 23 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
通过html表格发电子邮件
2006/10/09 PHP
PHP在线生成二维码代码(google api)
2013/06/03 PHP
PHP CLI模式下的多进程应用分析
2013/06/03 PHP
php+html5基于websocket实现聊天室的方法
2015/07/17 PHP
php无限级评论嵌套实现代码
2018/04/18 PHP
PHP获取本周所有日期或者最近七天所有日期的方法
2018/06/20 PHP
简单介绍JavaScript中字符串创建的基本方法
2015/07/07 Javascript
详解JavaScript中jQuery和Ajax以及JSONP的联合使用
2015/08/13 Javascript
jQuery的选择器中的通配符[id^='code']或[name^='code']及jquery选择器总结
2015/12/24 Javascript
js实现数字递增特效【仿支付宝我的财富】
2017/05/05 Javascript
解决jquery appaend元素中id绑定事件失效的问题
2017/09/12 jQuery
vue实现a标签点击高亮方法
2018/03/17 Javascript
图文详解vue框架安装步骤
2019/02/12 Javascript
uni-app 支持多端第三方地图定位的方法
2020/01/03 Javascript
原生js实现五子棋游戏
2020/05/28 Javascript
[43:32]2014 DOTA2华西杯精英邀请赛 5 25 LGD VS NewBee第一场
2014/05/26 DOTA
Python常见文件操作的函数示例代码
2011/11/15 Python
python之virtualenv的简单使用方法(必看篇)
2017/11/25 Python
python3实现公众号每日定时发送日报和图片
2018/02/24 Python
详解Python读取yaml文件多层菜单
2019/03/23 Python
解决matplotlib.pyplot在Jupyter notebook中不显示图像问题
2020/04/22 Python
python中yield的用法详解
2021/01/13 Python
HTML5中canvas中的beginPath()和closePath()的重要性
2018/08/24 HTML / CSS
ALDI奥乐齐官方海外旗舰店:德国百年超市
2017/12/27 全球购物
捷克领先的户外服装及配件市场零售商:ALPINE PRO
2018/01/09 全球购物
zooplus波兰:在线宠物店
2019/07/21 全球购物
攀岩、滑雪、徒步旅行装备:Black Diamond Equipment
2019/08/16 全球购物
Etam德国:内衣精品店
2019/08/25 全球购物
印度电子产品购物网站:Vijay Sales
2021/02/16 全球购物
简历自荐信
2013/12/02 职场文书
网上祭先烈心得体会
2014/09/01 职场文书
2014坚持党风廉政建设思想汇报
2014/09/18 职场文书
汶川大地震感悟
2015/08/10 职场文书
给原生html中添加水印遮罩层的实现示例
2021/04/02 Javascript
MySQL查看表和清空表的常用命令总结
2021/05/26 MySQL
一篇文章弄懂Python关键字、标识符和变量
2021/07/15 Python