I understand the Dataset API is a sort of iterator which does not load the entire dataset into memory, and because of that it cannot report the size of the dataset. I am talking in the context of a large corpus of data that is stored in text files or TFRecord files. Such files are generally read using tf.data.TextLineDataset or something similar. In contrast, it is trivial to find the size of a dataset loaded with tf.data.Dataset.from_tensor_slices.
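For concreteness, a minimal sketch of the contrast (the file name corpus.txt is just a placeholder):

    import numpy as np
    import tensorflow as tf

    # In-memory data: the size is known before the Dataset is even built.
    features = np.random.rand(1000, 10).astype(np.float32)
    in_memory_ds = tf.data.Dataset.from_tensor_slices(features)
    print(features.shape[0])  # 1000

    # File-backed data: the number of lines is not known up front.
    file_ds = tf.data.TextLineDataset(["corpus.txt"])  # placeholder file name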
The reason I am asking for the size of the Dataset is the following: let's say my dataset size is 1000 elements and the batch size is 50 elements. Then the number of training steps/batches (assuming 1 epoch) is 20. During these 20 steps, I would like to exponentially decay my learning rate from 0.1 to 0.01 as
decayed_learning_rate = tf.train.exponential_decay(
    learning_rate=0.1,        # initial learning rate
    global_step=global_step,  # current training step
    decay_steps=20,           # steps/batches per epoch: 1000 elements / batch size 50
    decay_rate=0.1,           # 0.1 * 0.1**(global_step/20) -> 0.01 after 20 steps
    staircase=False,
    name=None
)
In the above code, I would like to set decay_steps = number of steps/batches per epoch = num_elements/batch_size. This can be calculated only if the number of elements in the dataset is known in advance.
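To make the intent concrete, this is roughly what I would write if the element count were available (num_elements = 1000 is assumed here; obtaining that number is exactly the problem):

    batch_size = 50
    num_elements = 1000                       # assumed known; this is what I cannot query
    decay_steps = num_elements // batch_size  # 20 steps/batches per epoch

    decayed_learning_rate = tf.train.exponential_decay(
        learning_rate=0.1,
        global_step=tf.train.get_or_create_global_step(),
        decay_steps=decay_steps,
        decay_rate=0.1)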
Another reason to know the size in advance is to split the data into train and test sets using the tf.data.Dataset.take() and tf.data.Dataset.skip() methods.
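For example, a sketch of the split I have in mind, assuming the size were known (the 1000 elements, the 80/20 ratio and the file name are placeholders):

    dataset = tf.data.TextLineDataset(["corpus.txt"])  # placeholder file name
    dataset_size = 1000                                # assumed; not obtainable from the API
    train_size = int(0.8 * dataset_size)

    train_ds = dataset.take(train_size)  # first 80% of the elements
    test_ds = dataset.skip(train_size)   # remaining 20%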
PS: I am not looking for brute-force approaches like iterating through the whole dataset and updating a counter to count the number of elements, or using a very large batch size and then finding the size of the resulting dataset, etc.
dataset.__len__() does not work for a dataset made with tf.data.TextLineDataset. The error is TypeError: dataset length is unknown. – Vick
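A minimal sketch of the behavior described in that comment (assuming TF 2.x, where Dataset implements __len__; the file name is a placeholder):

    import tensorflow as tf

    ds = tf.data.TextLineDataset(["corpus.txt"])  # placeholder file name
    # len(ds) raises TypeError: dataset length is unknown.
    print(tf.data.experimental.cardinality(ds))   # -2, i.e. UNKNOWN_CARDINALITY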