How can we benefit from sharding the data to speed up training time?

My main issue is: I have 204 GB of training TFRecords for 2 million images, and a 28 GB validation TFRecords file for 302,900 images. It takes 8 hours to train one epoch, so training will take 33 days. I want to speed this up by using multiple threads and shards, but I am a little bit confused about a couple of things.
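For context, the 33-day figure is just my own back-of-the-envelope arithmetic; the epoch count below is my assumption, not something fixed:

hours_per_epoch = 8
planned_epochs = 100                             # my rough target
print(hours_per_epoch * planned_epochs / 24)     # ~33 days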

The tf.data.Dataset API has a shard function, and the documentation says the following about it:

Creates a Dataset that includes only 1/num_shards of this dataset.

This dataset operator is very useful when running distributed training, as it allows each worker to read a unique subset.

When reading a single input file, you can skip elements as follows:

d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)

Important caveats:

Be sure to shard before you use any randomizing operator (such as shuffle). Generally it is best if the shard operator is used early in the dataset pipeline. For example, when reading from a set of TFRecord files, shard before converting the dataset to input samples. This avoids reading every file on every worker. The following is an example of an efficient sharding strategy within a complete pipeline:

d = Dataset.list_files(FLAGS.pattern)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
                 cycle_length=FLAGS.num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)

So my question about the code above: when I shard my data using the shard function and set the number of shards (num_workers) to 10, I will have 10 splits of my data. Should I then set num_readers in the d.interleave call to 10 as well, to guarantee that each reader takes one split of the 10?
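To make my current understanding concrete, this is roughly what I imagine the pipeline would look like; the file pattern, the value 10, and the assumption that the data is split across several files are all just my own placeholders:

import tensorflow as tf

NUM_SHARDS = 10      # hypothetical number of workers / shards
WORKER_INDEX = 0     # hypothetical index of this particular worker

d = tf.data.Dataset.list_files("train-*.tfrecord", shuffle=False)
d = d.shard(NUM_SHARDS, WORKER_INDEX)        # this worker keeps every 10th file
d = d.interleave(tf.data.TFRecordDataset,
                 cycle_length=NUM_SHARDS,    # is 10 the right value here?
                 block_length=1)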

Also, how can I control which split the interleave function will take? Because if I set the shard index (worker_index) in the shard function to 1, it will give me the first split. Can anyone give me an idea how to perform this distributed training using the functions above?
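For example, is the idea that every worker runs the same script but passes in its own index, something like the sketch below? The flag names and the argparse setup are just my guess:

import argparse
import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", type=int, default=10)
parser.add_argument("--worker_index", type=int, default=0)    # 0-based?
args = parser.parse_args()

d = tf.data.TFRecordDataset("train.tfrecords")      # my single training file
d = d.shard(args.num_workers, args.worker_index)    # keep every num_workers-th record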

And what about num_parallel_calls? Should I set it to 10 as well?
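In other words, is this the right place for it? (parse_example here is just a stand-in for my real parsing function):

import tensorflow as tf

def parse_example(serialized):
    # stand-in parser; my real one decodes the image and label features
    return serialized

d = tf.data.TFRecordDataset("train.tfrecords")
d = d.shard(10, 0)
d = d.map(parse_example, num_parallel_calls=10)   # 10 like the shards, or
                                                  # rather the number of CPU cores?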

Note that I have a single TFRecords file for training and another one for validation; I have not split the TFRecords files into multiple files.
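If it turns out that the file-level sharding strategy from the documentation requires multiple files, I assume I would first have to re-shard my single file into smaller ones, roughly like this (an untested sketch of mine, assuming eager execution so the dataset can be iterated directly):

import tensorflow as tf

NUM_FILES = 10   # hypothetical number of output shards

dataset = tf.data.TFRecordDataset("train.tfrecords")
writers = [tf.io.TFRecordWriter("train-%05d-of-%05d.tfrecord" % (i, NUM_FILES))
           for i in range(NUM_FILES)]
for i, record in enumerate(dataset):
    writers[i % NUM_FILES].write(record.numpy())   # round-robin records into files
for w in writers:
    w.close()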