解决Keras 自定义层时遇到版本的问题

时间:2021-05-22

在2.2.0版本前,

from keras import backend as Kfrom keras.engine.topology import Layer class MyLayer(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(MyLayer, self).__init__(**kwargs) def build(self, input_shape): # 为该层创建一个可训练的权重 self.kernel = self.add_weight(name='kernel', shape=(input_shape[1], self.output_dim), initializer='uniform', trainable=True) super(MyLayer, self).build(input_shape) # 一定要在最后调用它 def call(self, x): return K.dot(x, self.kernel) def compute_output_shape(self, input_shape): return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as Kfrom keras.layers import Layer class MyLayer(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(MyLayer, self).__init__(**kwargs) def build(self, input_shape): # Create a trainable weight variable for this layer. self.kernel = self.add_weight(name='kernel', shape=(input_shape[1], self.output_dim), initializer='uniform', trainable=True) super(MyLayer, self).build(input_shape) # Be sure to call this at the end def call(self, x): return K.dot(x, self.kernel) def compute_output_shape(self, input_shape): return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)def mean_squared_error(y_true, y_pred): return K.mean(K.square(y_pred - y_true), axis=-1)def mean_absolute_error(y_true, y_pred): return K.mean(K.abs(y_pred - y_true), axis=-1)def mean_absolute_percentage_error(y_true, y_pred): diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None)) return 100. * K.mean(diff, axis=-1)def mean_squared_logarithmic_error(y_true, y_pred): first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.) second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.) return K.mean(K.square(first_log - second_log), axis=-1)def squared_hinge(y_true, y_pred): return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes): #由于涉及研究内容,详细代码不做公开 return loss#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)#其他有需要的参数也可以写在里面def total_loss(y_true,y_pred): git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45) return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')model.summary()#fc2层输出为特征last_layer = model.get_layer('fc2').output#获取特征feature = last_layer#softmax层输出为各类的预测值out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)#该模型有一个输入image_input,两个输出out,featurecustom_vgg_model = Model(inputs = image_input, outputs = [feature,out])custom_vgg_model.summary()#优化器,梯度下降sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征#categorical_crossentropy对应softmax层的损失函数#loss_weights两个损失函数的权重custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"}, loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd, metrics={'predictions': 'accuracy'})#这里使用dummy1,dummy2做演示,为0dummy1 = np.zeros((y_train.shape[0],4096))dummy2 = np.zeros((y_test.shape[0],4096))#模型的输入输出必须和model.fit()中x,y两个参数维度相同#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同#validation_data验证数据集也是如此,需要和输出层的维度相同hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes, epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。

相关文章