Cannot fit Keras model that uses tf.gather (gradients do not exist)

I'm trying to put together a Keras model in TensorFlow 2 that uses tf.gather to multiply each input by several different model parameters. For instance, in the layer below, if the inputs are [i1, i2] and the model weights are [w1, w2], then the outputs will be [i1 * w1 + i2 * w2, i1 * w2 + i2 * w1]:

import tensorflow as tf

class GatherLayer(tf.keras.layers.Layer):
    def __init__(self, activation='linear', **kwargs):
        super(GatherLayer, self).__init__(**kwargs)
        self.activation = tf.keras.activations.get(activation)

    def build(self, _):
        self.w = self.add_weight(
                shape=(2,), name='weights', initializer='random_normal', trainable=True)
        self.b = self.add_weight(
                shape=(1,), name='offset', initializer='random_normal', trainable=True)
        self.m = tf.concat(
                [tf.reshape(tf.gather(self.w, ix), (-1, 1)) for ix in [[0, 1], [1, 0]]], axis=-1)

    def call(self, inputs):
        lin = tf.matmul(inputs, self.m) + self.b
        return self.activation(lin)

    def get_config(self):
        config = super(GatherLayer, self).get_config()
        config.update({'activation': tf.keras.activations.serialize(self.activation)})
        return config

I can construct this layer just fine, and can compute gradients of its outputs with respect to the weights using tf.GradientTape, but when I try to use it in a tf.keras.Sequential model I get warnings about missing gradients:

WARNING:tensorflow:Gradients do not exist for variables ['sequential/gather_layer/weights:0'] when minimizing the loss.
WARNING:tensorflow:Gradients do not exist for variables ['sequential/gather_layer/weights:0'] when minimizing the loss.

A very similar setup using tf.stack fails in the same way. It seems like TensorFlow has trouble computing gradients for these tensors in graph (non-eager) mode, but I can't find anything about such a limitation in the documentation or elsewhere online. Does anyone have any ideas?
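
The tf.stack variant differs only in how m is assembled in build(); a minimal sketch (the class name StackLayer is mine, activation again dropped):

```python
import tensorflow as tf

class StackLayer(tf.keras.layers.Layer):
    # Same as GatherLayer above, but m is assembled with tf.stack instead
    # of tf.gather. Fitting it through Keras produces the same warnings.
    def build(self, _):
        self.w = self.add_weight(
                shape=(2,), name='weights', initializer='random_normal', trainable=True)
        self.b = self.add_weight(
                shape=(1,), name='offset', initializer='random_normal', trainable=True)
        # Columns are [w1, w2] and [w2, w1], matching the gather version.
        self.m = tf.stack([self.w, self.w[::-1]], axis=-1)

    def call(self, inputs):
        return tf.matmul(inputs, self.m) + self.b

layer = StackLayer()
y = layer(tf.constant([[1.0, 0.0]]))  # first call triggers build()
print(layer.m.shape)  # (2, 2)
```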

Full example code:

def gen_data():
    x = [[0, 0],
         [0, 1],
         [1, 0],
         [1, 1]]
    y = [[1, 0],
         [0, 0],
         [1, 1],
         [0, 1]]
    while True:
        yield (x, y)

batchsize = 4

def make_dataset():
    return tf.data.Dataset.from_generator(
            gen_data, (tf.float32, tf.float32),
            output_shapes=([batchsize, 2], [batchsize, 2]))

mdl = tf.keras.Sequential([GatherLayer()])
mdl.compile(optimizer=tf.keras.optimizers.Adam(), loss='categorical_crossentropy')
ds = make_dataset()
mdl.fit(ds, steps_per_epoch=4, epochs=1)