Cannot import weights for 'block1_conv1/kernel' of pre-trained VGG-19 model when implementing perceptual loss in the NiftyNet repo

I am trying to implement perceptual loss in NiftyNet.

Unfortunately, whenever I run the following command to train the model:

net_run train -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression -c ~/niftynet/extensions/mr_ct_regression/net_isampler.ini

I get the following error:

  File "/home_local/rajeevd/miniconda3/envs/ML_research/bin/net_run", line 8, in <module>
    sys.exit(main())
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/__init__.py", line 149, in main
    app_driver.run(app_driver.app)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 191, in run
    is_training_action=self.is_training_action)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/engine/application_driver.py", line 272, in create_graph
    outputs_collector, gradients_collector)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/contrib/regression_weighted_sampler/isample_regression.py", line 89, in connect_data_and_network
    self, outputs_collector, gradients_collector)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/application/regression_application.py", line 313, in connect_data_and_network
    weight_map=weight_map)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/base_layer.py", line 36, in __call__
    return self._op(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/template.py", line 385, in __call__
    return self._call_func(args, kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/template.py", line 355, in _call_func
    result = self._func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/loss_regression.py", line 113, in layer_op
    parallel_iterations=1)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 602, in new_func
    return func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 535, in new_func
    return func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 651, in map_fn_v2
    name=name)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 535, in new_func
    return func(*args, **kwargs)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 507, in map_fn
    maximum_iterations=n)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2816, in while_loop
    return_same_structure)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2298, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2223, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2768, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/map_fn.py", line 491, in compute
    result_value = autographed_fn(elems_value)
  File "/home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 695, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.pyct.error_utils.KeyError: in user code:

    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/niftynet/layer/loss_regression.py:102 _batch_i_loss  *
        return tf.cast(self._data_loss_func(**loss_params), tf.float32)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:889 __call__  **
        result = self._call(*args, **kwds)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:957 _call
        filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/function.py:1974 _call_flat
        flat_outputs = forward_function.call(ctx, args_with_tangents)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/eager/function.py:625 call
        executor_type=executor_type)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py:1221 partitioned_call
        op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:535 new_func
        return func(*args, **kwargs)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3496 create_op
        attrs, op_def, compute_device)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3565 _create_op_internal
        op_def=op_def)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2063 __init__
        self._control_flow_post_processing(input_tensors=inputs)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2081 _control_flow_post_processing
        self._control_flow_context.AddOp(self)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1765 AddOp
        self._AddOpInternal(op)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:1777 _AddOpInternal
        if not op.inputs:
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2393 inputs
        pywrap_tf_session.GetOperationInputs(self._c_op)))
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3946 _get_tensor_by_tf_output
        op = self._get_operation_by_tf_operation(tf_output.oper)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3910 _get_operation_by_tf_operation
        return self._get_operation_by_name_unsafe(op_name)
    /home_local/rajeevd/miniconda3/envs/ML_research/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3906 _get_operation_by_name_unsafe
        return self._nodes_by_name[name]

    KeyError: 'block1_conv1/kernel'

My only relevant change to the repo is that I replaced NiftyNet's RMSE loss function with the code below.

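import tensorflow as tf
from tensorflow.keras import backend as K

# Pre-trained VGG-19 without the classifier head (assumed; the exact call is not
# shown, but it matches the summary below: 320x320x3 input, 20,024,384 params).
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet',
                                  input_shape=(320, 320, 3))
vgg.trainable = False

# Frozen feature extractor truncated at block4_conv2
outputs = vgg.get_layer('block4_conv2').output
model = tf.keras.Model([vgg.input], outputs)
for layer in model.layers:
    layer.trainable = False

@tf.function
def rmse_loss(prediction, ground_truth, weight_map=None):
    # Replaces NiftyNet's RMSE loss; despite the name, it compares VGG-19
    # feature maps (perceptual loss). weight_map is accepted but unused.
    prediction = tf.reshape(prediction, shape=[-1, 320, 320, 3])
    ground_truth = tf.reshape(ground_truth, shape=[-1, 320, 320, 3])
    prediction = tf.keras.applications.vgg19.preprocess_input(prediction)
    ground_truth = tf.keras.applications.vgg19.preprocess_input(ground_truth)
    h1_list = model(prediction)
    h2_list = model(ground_truth)
    rc_loss = 0.0
    h1 = K.batch_flatten(h1_list)
    h2 = K.batch_flatten(h2_list)
    # Per-example sum of squared feature differences
    rc_loss = rc_loss + 1.0 * K.sum(K.square(h1 - h2), axis=-1)
    return rc_loss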
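A minimal sketch of the same loss under different assumptions, not a confirmed fix: the frozen VGG-19 extractor is created by a hypothetical helper `build_vgg19_feature_extractor` and passed in explicitly, there is no `@tf.function` wrapper, and the whole batch goes through the network in one call. The idea (an assumption, not checked against NiftyNet's internals) is to build the extractor in the same graph or eager context where the loss is evaluated, so the loss never refers to variables that were created somewhere else. `perceptual_loss` and `_flatten_per_example` are also hypothetical names.

```python
import tensorflow as tf


def build_vgg19_feature_extractor(layer_name="block4_conv2",
                                  input_shape=(320, 320, 3),
                                  weights="imagenet"):
    """Hypothetical helper: frozen VGG-19 truncated at `layer_name`.

    Assumption: call it in the same graph/eager context where the loss runs,
    so the loss only uses variables that live in that context.
    """
    vgg = tf.keras.applications.VGG19(include_top=False, weights=weights,
                                      input_shape=input_shape)
    vgg.trainable = False  # freezes every layer, including those reused below
    return tf.keras.Model(vgg.input, vgg.get_layer(layer_name).output)


def _flatten_per_example(features):
    # Collapse all non-batch dimensions: [B, H, W, C] -> [B, H*W*C]
    return tf.reshape(features, [tf.shape(features)[0], -1])


def perceptual_loss(feature_extractor, prediction, ground_truth):
    """Hypothetical replacement: per-example sum of squared feature differences.

    No @tf.function wrapper; the whole batch goes through the extractor at once.
    """
    height, width, channels = feature_extractor.input_shape[1:]
    prediction = tf.reshape(prediction, [-1, height, width, channels])
    ground_truth = tf.reshape(ground_truth, [-1, height, width, channels])
    # Like the original: assumes RGB intensities on the 0-255 scale
    prediction = tf.keras.applications.vgg19.preprocess_input(prediction)
    ground_truth = tf.keras.applications.vgg19.preprocess_input(ground_truth)
    h1 = _flatten_per_example(feature_extractor(prediction))
    h2 = _flatten_per_example(feature_extractor(ground_truth))
    return tf.reduce_sum(tf.square(h1 - h2), axis=-1)
```

For example, `fe = build_vgg19_feature_extractor()` followed by `perceptual_loss(fe, prediction, ground_truth)` returns one loss value per example. Taking the input size from `feature_extractor.input_shape` avoids hard-coding 320 in two places.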

Also, when I call vgg.summary(), I get:

Model: "vgg19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 320, 320, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 320, 320, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 320, 320, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 160, 160, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 160, 160, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 160, 160, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 80, 80, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 80, 80, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 80, 80, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 80, 80, 256)       590080    
_________________________________________________________________
block3_conv4 (Conv2D)        (None, 80, 80, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 40, 40, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 40, 40, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 40, 40, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 40, 40, 512)       2359808   
_________________________________________________________________
block4_conv4 (Conv2D)        (None, 40, 40, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 20, 20, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block5_conv4 (Conv2D)        (None, 20, 20, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 10, 10, 512)       0         
=================================================================
Total params: 20,024,384
Trainable params: 0
Non-trainable params: 20,024,384

So, as far as I can tell, the parameters for 'block1_conv1' are there. Does anyone know what is causing this KeyError and how to fix it?
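One possible cause, offered as an assumption rather than a confirmed diagnosis: `vgg` may be created in a different `tf.Graph` from the one NiftyNet builds its training ops in, so `block1_conv1/kernel` exists (as the summary shows) but not in the graph where the lookup happens. The sketch below shows the general effect: the op behind a Conv2D kernel variable can be found in the graph that created it, while looking up the same name in a second graph raises `KeyError`. The helper `has_operation` and the graph names are hypothetical; `weights=None` and a 32x32 input only keep the example small and offline.

```python
import tensorflow as tf


def has_operation(graph, op_name):
    """Hypothetical helper: True if `graph` contains an op called `op_name`."""
    try:
        graph.get_operation_by_name(op_name)
        return True
    except KeyError:  # raised when the name is not in this graph
        return False


# Graph A: where the VGG-19 variables are created
graph_a = tf.Graph()
with graph_a.as_default():
    vgg = tf.keras.applications.VGG19(include_top=False, weights=None,
                                      input_shape=(32, 32, 3))
    # Name of the op that holds block1_conv1's kernel variable
    kernel_op_name = vgg.get_layer("block1_conv1").kernel.op.name

# Graph B: a separate graph, like one a framework might build training ops in
graph_b = tf.Graph()

print(kernel_op_name)
print(has_operation(graph_a, kernel_op_name))  # True
print(has_operation(graph_b, kernel_op_name))  # False
```

If this is what is happening, the usual direction is to create the Keras model inside the graph that runs training, rather than at module import time.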

Thank you for your time!
