keras K.function获取某层的输出操作

时间:2021-05-22

如下所示:

from keras import backend as Kfrom keras.models import load_modelmodels = load_model('models.hdf5')image=r'image.png'images=cv2.imread(r'image.png')image_arr = process_image(image, (224, 224, 3))image_arr = np.expand_dims(image_arr, axis=0)layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])f1 = layer_1([image_arr])[0]

加载训练好并保存的网络模型

加载数据(图像),并将数据处理成array形式

指定输出层

将处理后的数据输入,然后获取输出

其中,K.function有两种不同的写法:

1. 获取名为layer_name的层的输出

layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])#指定输出层的名称

2. 获取第n层的输出

layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output])#指定输出层的序号(层号从0开始)

另外,需要注意的是,书写不规范会导致报错:

报错:

TypeError: inputs to a TensorFlow backend function should be a list or tuple

将该句:

f1 = layer_1(image_arr)[0]

修改为:

f1 = layer_1([image_arr])[0]

补充知识:keras.backend.function()

如下所示:

def function(inputs, outputs, updates=None, **kwargs): """Instantiates a Keras function. Arguments: inputs: List of placeholder tensors. outputs: List of output tensors. updates: List of update ops. **kwargs: Passed to `tf.Session.run`. Returns: Output values as Numpy arrays. Raises: ValueError: if invalid kwargs are passed in. """ if kwargs: for key in kwargs: if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and key not in tf_inspect.getargspec(Function.__init__)[0]): msg = ('Invalid argument "%s" passed to K.function with Tensorflow ' 'backend') % key raise ValueError(msg) return Function(inputs, outputs, updates=updates, **kwargs)

这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。

我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。

class Function(object): """Runs a computation graph. Arguments: inputs: Feed placeholders to the computation graph. outputs: Output tensors to fetch. updates: Additional update ops to be run at function call. name: a name to help users identify what this function does. """ def __init__(self, inputs, outputs, updates=None, name=None, **session_kwargs): updates = updates or [] if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` to a TensorFlow backend function ' 'should be a list or tuple.') if not isinstance(outputs, (list, tuple)): raise TypeError('`outputs` of a TensorFlow backend function ' 'should be a list or tuple.') if not isinstance(updates, (list, tuple)): raise TypeError('`updates` in a TensorFlow backend function ' 'should be a list or tuple.') self.inputs = list(inputs) self.outputs = list(outputs) with ops.control_dependencies(self.outputs): updates_ops = [] for update in updates: if isinstance(update, tuple): p, new_p = update updates_ops.append(state_ops.assign(p, new_p)) else: # assumed already an op updates_ops.append(update) self.updates_op = control_flow_ops.group(*updates_ops) self.name = name self.session_kwargs = session_kwargs def __call__(self, inputs): if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') feed_dict = {} for tensor, value in zip(self.inputs, inputs): if is_sparse(tensor): sparse_coo = value.tocoo() indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(sparse_coo.col, 1)), 1) value = (indices, sparse_coo.data, sparse_coo.shape) feed_dict[tensor] = value session = get_session() updated = session.run( self.outputs + [self.updates_op], feed_dict=feed_dict, **self.session_kwargs) return updated[:len(self.outputs)]

所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。

以上这篇keras K.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。

相关文章