Posted in Python onJune 16, 2020
在2.2.0版本前,
from keras import backend as K from keras.engine.topology import Layer class MyLayer(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(MyLayer, self).__init__(**kwargs) def build(self, input_shape): # 为该层创建一个可训练的权重 self.kernel = self.add_weight(name='kernel', shape=(input_shape[1], self.output_dim), initializer='uniform', trainable=True) super(MyLayer, self).build(input_shape) # 一定要在最后调用它 def call(self, x): return K.dot(x, self.kernel) def compute_output_shape(self, input_shape): return (input_shape[0], self.output_dim)
2.2.0 版本时:
from keras import backend as K from keras.layers import Layer class MyLayer(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(MyLayer, self).__init__(**kwargs) def build(self, input_shape): # Create a trainable weight variable for this layer. self.kernel = self.add_weight(name='kernel', shape=(input_shape[1], self.output_dim), initializer='uniform', trainable=True) super(MyLayer, self).build(input_shape) # Be sure to call this at the end def call(self, x): return K.dot(x, self.kernel) def compute_output_shape(self, input_shape): return (input_shape[0], self.output_dim)
如果你遇到:
<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'
不妨试试另一种引入!
补充知识:Keras自定义损失函数在场景分类的使用
在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)
#losses.py #y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值) def mean_squared_error(y_true, y_pred): return K.mean(K.square(y_pred - y_true), axis=-1) def mean_absolute_error(y_true, y_pred): return K.mean(K.abs(y_pred - y_true), axis=-1) def mean_absolute_percentage_error(y_true, y_pred): diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None)) return 100. * K.mean(diff, axis=-1) def mean_squared_logarithmic_error(y_true, y_pred): first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.) second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.) return K.mean(K.square(first_log - second_log), axis=-1) def squared_hinge(y_true, y_pred): return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)
这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。
那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。
def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes): #由于涉及研究内容,详细代码不做公开 return loss #total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出 #( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值) #其他有需要的参数也可以写在里面 def total_loss(y_true,y_pred): git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45) return git_loss
自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。
#这里使用vgg16模型 model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet') model.summary() #fc2层输出为特征 last_layer = model.get_layer('fc2').output #获取特征 feature = last_layer #softmax层输出为各类的预测值 out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer) #该模型有一个输入image_input,两个输出out,feature custom_vgg_model = Model(inputs = image_input, outputs = [feature,out]) custom_vgg_model.summary() #优化器,梯度下降 sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True) #这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征 #categorical_crossentropy对应softmax层的损失函数 #loss_weights两个损失函数的权重 custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"}, loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd, metrics={'predictions': 'accuracy'}) #这里使用dummy1,dummy2做演示,为0 dummy1 = np.zeros((y_train.shape[0],4096)) dummy2 = np.zeros((y_test.shape[0],4096)) #模型的输入输出必须和model.fit()中x,y两个参数维度相同 #dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同 #validation_data验证数据集也是如此,需要和输出层的维度相同 hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes, epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))
写到这里差不多就可以了,不够详细,以后再做补充。
以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。
解决Keras 自定义层时遇到版本的问题
- Author -
orDream声明:登载此文出于传递更多信息之目的,并不意味着赞同其观点或证实其描述。
Reply on: @reply_date@
@reply_contents@