How to build a layer that merges two images, in an order determined by a third parameter

I want a layer that takes three tensors as input: two (n,m,k) tensors and one (1) tensor, i.e. a single number. The output should be a (n,m,2k) tensor whose first k channels are one image and whose remaining k channels are the other. The catch is that the order in which they are merged (image one before image two, or vice versa) should be determined by whether the third input is greater than 0 or not.

By my thinking, this is an entirely static layer without any trainable parameters, so I tried doing the ordering choice with a Lambda layer, as follows:

def image_scrambler(inp): #inp = [im1, im2, aux_input]
    im1, im2, aux_input = inp[0],inp[1],inp[2]
    assert aux_input==1 or aux_input==0
    if aux_input==0:
        return [im1, im2]
    else:
        return [im2,im1]
paired_images = Lambda(image_scrambler)([image_input, decoder, aux_input])

This does not work, because it protests that the layer is dynamic and needs to be built with dynamic=True. When I try that, I get a RecursionError as follows:

---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
<ipython-input-15-a40adb50e97d> in <module>
      7         return [im2,im1]
      8 aux_input = Input(shape=(1))
----> 9 paired_images = Lambda(image_scrambler,dynamic=True)([image_input, decoder, aux_input])

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in __call__(self, inputs, *args, **kwargs)
    791             # TODO(fchollet): consider py_func as an alternative, which
    792             # would enable us to run the underlying graph if needed.
--> 793             outputs = self._symbolic_call(inputs)
    794 
    795           if outputs is None:

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in _symbolic_call(self, inputs)
   2126   def _symbolic_call(self, inputs):
   2127     input_shapes = nest.map_structure(lambda x: x.shape, inputs)
-> 2128     output_shapes = self.compute_output_shape(input_shapes)
   2129 
   2130     def _make_placeholder_like(shape):

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\utils\tf_utils.py in wrapper(instance, input_shape)
    304     if input_shape is not None:
    305       input_shape = convert_shapes(input_shape, to_tuples=True)
--> 306     output_shape = fn(instance, input_shape)
    307     # Return shapes from `fn` as TensorShapes.
    308     if output_shape is not None:

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\layers\core.py in compute_output_shape(self, input_shape)
    808       with context.eager_mode():
    809         try:
--> 810           return super(Lambda, self).compute_output_shape(input_shape)
    811         except NotImplementedError:
    812           raise NotImplementedError(

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in compute_output_shape(self, input_shape)
    552           try:
    553             if self._expects_training_arg:
--> 554               outputs = self(inputs, training=False)
    555             else:
    556               outputs = self(inputs)

... last 5 frames repeated, from the frame below ...

c:\users\vilhelm\appdata\local\programs\python\python36\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py in __call__(self, inputs, *args, **kwargs)
    791             # TODO(fchollet): consider py_func as an alternative, which
    792             # would enable us to run the underlying graph if needed.
--> 793             outputs = self._symbolic_call(inputs)
    794 
    795           if outputs is None:

RecursionError: maximum recursion depth exceeded while calling a Python object

So that doesn't actually tell me anything about why it didn't work; it just crashed.

I'd prefer not to have to fiddle around with building a layer class inheriting from Layer if there's any way to get the less complicated method to work.

1 answer

  • answered 2020-02-16 15:56 Daniel Möller

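    Always work with "tensor functions", not with "Python functions". A Python if cannot branch on a symbolic tensor, so the choice has to be expressed with backend operations: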

    import keras.backend as K
    
    def image_scrambler(inp):  # inp = [im1, im2, aux_input]
        im1, im2, aux_input = inp[0], inp[1], inp[2]
    
        # aux_input is expected to be 0 or 1, so 0.5 serves as the threshold
        is_greater = K.greater(aux_input, 0.5)
        return K.switch(is_greater,                   # this is a keras "if"
                        K.concatenate([im2, im1]),    # result if true
                        K.concatenate([im1, im2]))    # result if false
    
    paired_images = Lambda(image_scrambler)([image_input, decoder, aux_input])
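
    To make concrete what the switch computes, here is a minimal NumPy sketch of the same selection. It is an illustration, not Keras code: merge_by_flag and the toy arrays are made-up names, and it assumes the leading batch dimension that Keras adds to every Input. Both orderings are built, and the flag then picks one of them separately for each sample in the batch.

```python
import numpy as np

def merge_by_flag(im1, im2, aux):
    """NumPy stand-in for the Lambda body: choose the channel order per sample.

    im1, im2: arrays of shape (batch, n, m, k)
    aux:      array of shape (batch, 1); values above 0.5 put im2 first
    """
    swapped = np.concatenate([im2, im1], axis=-1)   # (batch, n, m, 2k)
    normal = np.concatenate([im1, im2], axis=-1)    # (batch, n, m, 2k)
    # Reshape the flag to (batch, 1, 1, 1) so it broadcasts over n, m and 2k.
    take_swapped = (aux > 0.5).reshape(-1, 1, 1, 1)
    return np.where(take_swapped, swapped, normal)

# Hypothetical toy batch: two samples of 2x2 "images" with k = 1 channel.
im1 = np.zeros((2, 2, 2, 1))
im2 = np.ones((2, 2, 2, 1))
aux = np.array([[0.0], [1.0]])

out = merge_by_flag(im1, im2, aux)
print(out.shape)     # (2, 2, 2, 2)
print(out[0, 0, 0])  # sample 0 (aux = 0): im1 channel first
print(out[1, 0, 0])  # sample 1 (aux = 1): im2 channel first
```

    Note that the output shape is the same for either ordering, which is why the choice can be made inside a static graph.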
    

    I don't think the assert is a good idea. You should do this when checking your data, not in the model.

    Although you say it's not trainable, you might be expecting it to be trainable somehow? What decides what the value of aux_input will be? If you expect it to be learned elsewhere, I doubt it will work. Maybe it should be a continuous value given by a sigmoid somewhere. Then it "might" have a chance of working, although the if part sort of disrupts (but doesn't break) backpropagation.
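
    One possible reading of "disrupts but doesn't break" is that the output still depends smoothly on the two images, while the hard threshold makes it locally insensitive to aux_input. The finite-difference sketch below illustrates that reading in NumPy. It is an assumption-laden illustration rather than the actual Keras graph, and merge_by_flag and the random test data are hypothetical.

```python
import numpy as np

def merge_by_flag(im1, im2, aux):
    """Hard per-sample switch: aux > 0.5 puts im2 first, otherwise im1 first."""
    swapped = np.concatenate([im2, im1], axis=-1)
    normal = np.concatenate([im1, im2], axis=-1)
    return np.where((aux > 0.5).reshape(-1, 1, 1, 1), swapped, normal)

rng = np.random.default_rng(0)
im1 = rng.normal(size=(1, 2, 2, 3))   # one sample, 2x2 pixels, k = 3
im2 = rng.normal(size=(1, 2, 2, 3))
aux = np.array([[0.9]])               # e.g. a sigmoid output above the threshold
eps = 1e-3

base = merge_by_flag(im1, im2, aux)

# Nudging the flag without crossing 0.5 leaves the output unchanged,
# so no useful gradient signal reaches whatever produced aux.
d_aux = np.abs(merge_by_flag(im1, im2, aux + eps) - base).max()

# Nudging an image changes the channels holding that image by the same amount,
# so gradients can still flow back into the image branches.
d_im1 = np.abs(merge_by_flag(im1 + eps, im2, aux) - base).max()

print(d_aux)  # 0.0
print(d_im1)  # approximately eps
```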