keras中的两种tensors

今天在用keras框架训练一个网络时出现了一个不知所云的错误,内容如下:

1
ValueError: Output tensors to a Model must be the output of a Keras `Layer` (thus holding past layer metadata).

我就奇了怪了,我所有网络都是用keras来搭建的,这让我很莫名其妙。仔细调查一番后,才发现keras其实是有两种tensor的,一种是keras tensor,一种是backend tensor(比如如果你像我一样是用TensorFlow作为backend的话,就是tensorflow tensor)。简单的对比是,keras tensor是backend tensor的子类,增添了一些keras模型的一些信息,如上面提到的past layer metadata。下面我们用代码简单验证一下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from keras.layers import Input, ZeroPadding2D
from keras.backend import spatial_2d_padding
from keras.models import Model

# 由于ZeroPadding2D是一个keras layer
# 调用这个函数会返回一个keras tensor
def pad_keras(x, pad):
return ZeroPadding2D(padding=(pad, pad), data_format='channels_last')(x)

# 由于spatial_2d_padding来自于keras.backend模块
# 调用这个函数会返回一个tensorflow tensor
def pad_tf(x, pad):
return spatial_2d_padding(x, padding=((pad,pad),(pad,pad)), data_format='channels_last')

if __name__=='__main__':
img = Input(shape=(2,2,3))
keras_out = pad_keras(img,2)
tf_out = pad_tf(img,2)

keras_model = Model(inputs=img, outputs=keras_out) # works fine
tf_model = Model(inputs=img, outputs=tf_out) # raise ValueError

如果你跑一下上面的代码,最后一行会引发一个如上文所述的ValueError。简单来说,任何backend里面的函数都会返回一个backend tensor。那么,有没有办法回避这个问题呢?废话,如果没有解决办法我说半天不就浪费狗生了吗?



解决办法是用keras.layers.Lambda函数把用到的backend函数包装起来。客官请看下面代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from keras.layers import Input, Lambda
from keras.backend import spatial_2d_padding
from keras.models import Model
import numpy as np

# 嗯,就是这样!
def pad_tf_correct_way(x, pad):
ret = Lambda(lambda x: \
spatial_2d_padding(x, padding=((pad,pad),(pad,pad))))(x)
return ret

if __name__=='__main__':
img = Input(shape=(2,2,3))
tf_out = pad_tf_correct_way(img, 2)
tf_model = Model(inputs=img, outputs=tf_out)

img_input = np.arange(12).reshape(1,2,2,3)
tf_model_out = tf_model.predict(img_input)

通过Lambda层,我们把原本spatial_2d_padding返回的backend tensor转变成了keras tensor,这样就可以放心大胆地使用keras.backend中的函数构建神经网络模型了。

值得注意的是,keras只有在构建网络的时候才会严格要求使用keras tensor,在写loss和metrics中用的目标函数是不需要这么做的。