What does batch, repeat, and shuffle do with TensorFlow Dataset?

I'm currently learning TensorFlow, and I'm confused by the code snippet below:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

I know that the dataset holds all the data, but what do shuffle(), repeat(), and batch() do to it? Please help me with an example and an explanation.

Uglify asked 28/11, 2018 at 7:47 Comment(0)

Update: Here is a small Colab notebook demonstrating this answer.


Imagine you have a dataset: [1, 2, 3, 4, 5, 6]. Then:

How ds.shuffle() works

dataset.shuffle(buffer_size=3) will allocate a buffer of size 3 for picking random entries. This buffer is connected to the source dataset. We can picture it like this:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

Let's assume that entry 2 was taken from the random buffer. The free slot is filled by the next element from the source dataset, that is, 4:

2 <= [1,3,4] <= [5,6]

We continue reading till nothing is left:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]     <= []
4 <= []      <= []

How ds.repeat() works

As soon as all the entries have been read from the dataset, trying to read the next element raises an end-of-sequence error (tf.errors.OutOfRangeError with the TF1 iterator from the question). That's where ds.repeat() comes into play. It re-initializes the dataset, making it look like this again:

[1,2,3] <= [4,5,6]

What will ds.batch() produce

The ds.batch() will take the first batch_size entries and make a batch out of them. So, a batch size of 3 for our example dataset will produce two batch records:

[2,1,5]
[3,6,4]

As we have ds.repeat() before the batch, the generation of data will continue. But the order of the elements will be different, due to ds.shuffle(). What should be taken into account is that 6 will never be present in the first batch, due to the size of the random buffer: 6 enters the buffer only after three elements have already been drawn, i.e. after the first batch is complete.

Atheism answered 28/11, 2018 at 10:57 Comment(17)
What if I don't want to shuffle the data because it is a time series? Can I still use repeat and batch without shuffle?Polly
@alily, yes, that would be an option. Another option is to make each batch record represent a separate time-series record. That way you could still benefit from using shuffle().Atheism
@Atheism so what is the correct order for using these three operators?Bib
@Seymour: the order is ds.shuffle(...).repeat().batch(...). At least for TensorFlow 2.1.0.Atheism
And why not ds.shuffle(reshuffle_each_iteration=True).batch(...).repeat(...)?Bib
@Bib: this is also possible - tensorflow.org/tutorials/structured_data/time_seriesHawserlaid
I don't get why 6 will never be present in the first batch. Aren't batches taken randomly? So the first batch could be [2, 3, 6], for example.Timothy
Hi @Atheism, thanks for your answer. Should we always use repeat when we create a dataset?Cholera
Why will 6 never be present in the first batch?Repent
@DingxinXu, because the shuffle buffer has size 3. For 6 to get into the shuffle buffer, three elements have to be ejected first, and those three elements form the first batch (of size 3). I've also added a link to the Colab notebook to show this effect.Atheism
@AshwaniK you would probably add repeat() if you write your own training loop. If you use a high-level API (like Keras or Estimators), you don't need to add repeat(), because those APIs do it for you.Atheism
@Rou batches are not taken randomly. Please have a look at the Colab notebook (I've just attached a link to it in the answer). Batches are sequential. If you remove the shuffle() and repeat() operations from the example above, you'll always get the batches [1, 2, 3] and [4, 5, 6], in the same exact order. Batches only look random after adding these operations.Atheism
@Vlad-HC thanks, it makes sense. I tried using repeat with a dataset and Keras and it fetched all the data in one epoch itself :DCholera
@DingxinXu 6 always enters the shuffle buffer after 5. Regardless of whether 5 is selected into the first batch, the third element of the first batch is chosen (randomly) before 6 enters the shuffle buffer (of size 3 in this case).Micturition
@RadwaKhattab In the first position there can only be 1, 2 or 3; in the second position also 4; and in the third position also 5, since those are the elements that have been added to the buffer by then. 6 cannot appear among the first three items.Alethaalethea
From the docs: shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) shuffles the samples in the dataset; buffer_size is the number of samples that are buffered and sampled from at random.Microminiaturization
Does batch(batch_size) need to be the same as the batch size provided to model.fit() and tuner.search()?Lolly

The following methods in tf.data.Dataset:

  1. repeat(count=None) The method repeats the dataset count times (indefinitely if count is None).
  2. shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) The method shuffles the samples in the dataset; buffer_size is the number of elements kept in a buffer, from which the next element is picked uniformly at random.
  3. batch(batch_size, drop_remainder=False) Creates batches of the dataset of length batch_size; drop_remainder controls whether a final, smaller batch is dropped. A quick sketch of all three is shown below.
Exciter answered 28/11, 2018 at 10:53 Comment(2)
Thank you. I was confused that tensorflow.keras.preprocessing.timeseries_dataset_from_array() didn't have a drop_remainder argument.Breton
Does batch(batch_size) need to be the same as the batch_size provided to model.fit() and tuner.search()?Lolly

Batch

Combines consecutive elements of the dataset into groups (batches):

without batching

dataset = tf.data.Dataset.range(6)
for i in dataset:
    print(i.numpy())

Output:

0
1
2
3
4
5

with batching

dataset = tf.data.Dataset.range(6)
for i in dataset.batch(2):
    print(i.numpy())

Output:

[0 1]
[2 3]
[4 5]

Shuffle

Randomly shuffles the input data. According to the docs, the Dataset.shuffle() transformation maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer. At first I could not understand the result of chaining shuffle and batch when the buffer holds only part of the data. Why do we get values greater than 19 in the first batch when only the first 20 values fit in the shuffle buffer?

dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(20).batch(10)
print(next(iter(dataset)).numpy())

Output:

[ 6  3 13 18 20 21  5  0  2 15]

It looks like after the batch has fetched a value from the buffer, the next input value (20, 21, 22, ...) takes its place in the buffer, and the batch can select it as one of its following values. This way the first batch gets 10 values ranging from 0 to 28: the buffer starts with the first 20 values and is refilled up to 9 times while the batch is being drawn.

Repeat

According to the TensorFlow documentation, repeat is used to iterate over a dataset for multiple epochs (an epoch being one complete pass over the dataset). In other words, it simply replicates the input data.

dataset = tf.data.Dataset.range(5) # [0 1 2 3 4]
for i in dataset.repeat(2).batch(3):
    print(i.numpy())

Output:

[0 1 2]
[3 4 0]
[1 2 3]
[4]

If we don't want to mix data from different epochs in one batch, we need to put repeat after batch.

dataset = tf.data.Dataset.range(5)
for i in dataset.batch(3).repeat(2):
    print(i.numpy())

Output:

[0 1 2]
[3 4]
[0 1 2]
[3 4]
Distortion answered 26/6, 2023 at 9:27 Comment(2)
Do you need to use repeat(2).batch(3) or batch(3).repeat(2)?Lolly
@Lolly It depends on whether you want to mix data from different epochs or not (see above).Distortion

An example that shows looping over epochs. When you run this script, notice the difference between

  • dataset_gen1 - the shuffle operation produces more random outputs (this may be more useful while running machine learning experiments)
  • dataset_gen2 - the lack of a shuffle operation produces the elements in sequence

Other additions in this script

  • tf.data.experimental.sample_from_datasets - used to combine two datasets. Note that the shuffle operation in this case will create a buffer that samples equally from both datasets.
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"

import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')):
    tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)

class Augmentations:

    def __init__(self):
        pass

    @tf.function
    def filter_even(self, x):
        # Keep only the odd numbers (the even ones are filtered out)
        return x % 2 != 0

class Dataset:

    def __init__(self, aug, range_min=0, range_max=100):
        self.range_min = range_min
        self.range_max = range_max
        self.aug = aug

    def generator(self):
        dataset = tf.data.Dataset.from_generator(self._generator
                        , output_types=(tf.float32), args=())

        dataset = dataset.filter(self.aug.filter_even)

        return dataset
    
    def _generator(self):
        for item in range(self.range_min, self.range_max):
            yield(item)

# Can be used when you have multiple datasets that you wish to combine
class ZipDataset:

    def __init__(self, datasets):
        self.datasets = datasets
        self.datasets_generators = []
    
    def generator(self):
        for dataset in self.datasets:
            self.datasets_generators.append(dataset.generator())
        return tf.data.experimental.sample_from_datasets(self.datasets_generators)

if __name__ == "__main__":
    aug = Augmentations()
    dataset1 = Dataset(aug, 0, 100)
    dataset2 = Dataset(aug, 100, 200)
    dataset = ZipDataset([dataset1, dataset2])

    epochs = 2
    shuffle_buffer = 10
    batch_size = 4
    prefetch_buffer = 5

    dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
    # dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence 

    for epoch in range(epochs):
        print ('\n ------------------ Epoch: {} ------------------'.format(epoch))
        for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
            print (X)
        
        # Do some stuff at end of loop
Watteau answered 20/11, 2020 at 13:37 Comment(0)
