Keras Reshape: total size of the new array must be unchanged

I'm trying to use the Keras Reshape layer to reshape the output of a GloVe embedding from 4D, (?, 9, 20, 100), down to 3D, (?, 9, 2000). However, Reshape((9, 2000))(text_layer) fails with an error saying the total size of the new array must be unchanged, even though 9 * 20 * 100 = 9 * 2000. Any ideas why? My code is below.

text = Input(shape=(9, news_text.shape[1]), name='text')
text_layer = Embedding(
    embedding_matrix.shape[0],
    embedding_matrix.shape[1],
    weights=[embedding_matrix],
    input_length=news_text.shape[1]
)(text)
text_layer = Reshape((9, text_layer.shape[2] * text_layer.shape[3]))(text_layer)

1 answer

  • answered 2018-08-11 04:32 Barrendeitor

    Remove the input_length parameter from the Embedding layer.

    It is strange, and I don't know the reason, but the error is thrown whenever input_length is specified.

    Anyway, the Embedding layer receives its dimensions from the Input layer. The input_length parameter seems to have a very specific use: letting later layers, such as Flatten, know the dimensions of the tensor.

    In this case, the Embedding layer takes the shape of its output tensor from the input tensor and ignores input_length.

    (If you set an invalid value, no error is thrown until you add the next layer. Compare input_length with the resulting shape:)

    >>> inp = Input(shape=(9,20))
    >>> emb = Embedding(100,100, input_length=84) (inp)
    >>> emb
    <tf.Tensor 'embedding_5/embedding_lookup:0' shape=(?, 9, 20, 100) dtype=float32>
    >>> res = Reshape((9,2000)) (emb)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      ...
    

    However, the input_length parameter seems to conflict with the Reshape layer once that layer is added.
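
    One possible explanation, offered as an assumption rather than something confirmed above: the layer may record a static output shape derived from input_length, separately from the real tensor shape, and Reshape may check its target against that recorded shape. The sketch below is a plain-Python model of that idea, not Keras code; inferred_embedding_shape and reshape_is_valid are hypothetical helpers.

```python
from math import prod


def inferred_embedding_shape(input_shape, output_dim, input_length=None):
    """Hypothetical model of the static shape an Embedding layer records."""
    if input_length is None:
        # No override: append the embedding size to the input shape.
        return tuple(input_shape) + (output_dim,)
    # Assumed behaviour: input_length replaces all non-batch input
    # dimensions without being checked against the input's rank.
    if isinstance(input_length, (list, tuple)):
        lengths = tuple(input_length)
    else:
        lengths = (input_length,)
    return (input_shape[0],) + lengths + (output_dim,)


def reshape_is_valid(static_shape, target_shape):
    """Compare per-sample element counts, ignoring the batch dimension."""
    return prod(static_shape[1:]) == prod(target_shape)


input_shape = (None, 9, 20)  # same input as the session above

with_length = inferred_embedding_shape(input_shape, 100, input_length=84)
without_length = inferred_embedding_shape(input_shape, 100)

print(with_length, reshape_is_valid(with_length, (9, 2000)))
print(without_length, reshape_is_valid(without_length, (9, 2000)))
```

    Under this model, the recorded shape with input_length=84 holds 84 * 100 elements per sample, which does not match 9 * 2000, while the shape without input_length does match.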

    Finally:

    text = Input(shape=(9, news_text.shape[1]), name='text')
    text_layer = Embedding(
        embedding_matrix.shape[0],
        embedding_matrix.shape[1],
        weights=[embedding_matrix],
    )(text)
    text_layer = Reshape((9, text_layer.shape[2] * text_layer.shape[3]))(text_layer)
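
    If you prefer to build the Reshape target from plain integers rather than from tensor dimension objects, something like the following may help. It is a minimal sketch that assumes the static shape is available as a tuple of ints, for example from keras.backend.int_shape(text_layer); merge_last_two is a hypothetical helper name.

```python
def merge_last_two(static_shape):
    """Return a Reshape target that folds the last two axes into one.

    static_shape includes the batch dimension as its first entry,
    e.g. (None, 9, 20, 100).
    """
    *leading, second_last, last = static_shape[1:]  # drop the batch axis
    if second_last is None or last is None:
        raise ValueError("the last two dimensions must be known")
    return tuple(leading) + (second_last * last,)


target = merge_last_two((None, 9, 20, 100))
print(target)
```

    The resulting tuple could then be passed as Reshape(target)(text_layer).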