How do I split TensorFlow datasets?

I have a TensorFlow dataset based on one .tfrecord file. How do I split the dataset into test and train datasets? E.g. 70% train and 30% test?

Edit:

My TensorFlow version: 1.8. I've checked; there is no "split_v" function as mentioned in the possible duplicate. Also, I am working with a .tfrecord file.

Reiser answered 1/7, 2018 at 17:0 Comment(4)
Possible duplicate of Split inputs into training and test sets (Vicky)
Does this answer your question? Split a dataset created by Tensorflow dataset API in to Train and Test? (Odawa)
The question was already answered years ago, but thanks for the link. (Reiser)
Also related: #54519809 (Beaverbrook)

You may use Dataset.take() and Dataset.skip():

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle(buffer_size=DATASET_SIZE)  # shuffle() requires a buffer size
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)

For more generality, this example uses a 70/15/15 train/val/test split, but if you don't need a test or a validation set, just ignore the last two lines.

Take:

Creates a Dataset with at most count elements from this dataset.

Skip:

Creates a Dataset that skips count elements from this dataset.

You may also want to look into Dataset.shard():

Creates a Dataset that includes only 1/num_shards of this dataset.
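
As an illustration of that route, here is a minimal sketch (the file name is a placeholder) of a 70/30 split built from ten shards; note that shard() selects every num_shards-th record deterministically, so you may still want to shuffle each split afterwards:

import tensorflow as tf

full_dataset = tf.data.TFRecordDataset("data.tfrecord")  # placeholder path

# shard(num_shards, index) keeps every num_shards-th record, starting at index.
# Three of ten shards give roughly 30% for test; the remaining seven give ~70% for train.
test_dataset = (full_dataset.shard(10, 0)
                .concatenate(full_dataset.shard(10, 1))
                .concatenate(full_dataset.shard(10, 2)))

train_dataset = full_dataset.shard(10, 3)
for i in range(4, 10):
    train_dataset = train_dataset.concatenate(full_dataset.shard(10, i))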

Vicky answered 1/7, 2018 at 20:40 Comment(6)
Isn't there a randomness issue here if the dataset is much larger than the shuffle buffer size? Since samples are shuffled only within the (relatively) small buffer, approximately the first 70% of samples will be the training set, the next 15% will be the test set, etc. If the data is ordered somehow, this would introduce bias into the training results. Probably the solution is to shard the data, then shuffle it, then split it. (Ordinate)
I agree. Good comment. I suppose most use cases would then simply need to shuffle the whole dataset at once, but to be truly scalable you're right. (Vicky)
Note that skip actually iterates over the dataset, so it can cause big latency on a large dataset. (Adrenalin)
I don't recommend this, as the train and test sets are not disjoint: it happens that the test set contains elements of the training set. (Documentary)
You'll want to set shuffle(reshuffle_each_iteration=False). Without that, each time a new iteration of training starts, items from the training set and validation set will get shuffled into each other. (Inessive)
Do take() and skip() happen randomly? So for an RNN you wouldn't split the dataset using take() and skip()? (Gaziantep)

This question is similar to this one and this one, and I am afraid we have not had a satisfactory answer yet.

  • Using take() and skip() requires knowing the dataset size. What if I don't know that, or don't want to find out?

  • Using shard() only gives 1 / num_shards of dataset. What if I want the rest?

I try to present a better solution below, tested on TensorFlow 2 only. Assuming you already have a shuffled dataset, you can then use filter() to split it into two:

import tensorflow as tf

all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
        .shuffle(10, reshuffle_each_iteration=False)

test_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 == 0) \
                    .map(lambda x,y: y)

train_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 != 0) \
                    .map(lambda x,y: y)

for i in test_dataset:
    print(i)

print()

for i in train_dataset:
    print(i)

The parameter reshuffle_each_iteration=False is important. It makes sure the original dataset is shuffled once and only once. Otherwise, the two resulting sets may overlap.

Use enumerate() to add an index.

Use filter(lambda x,y: x % 4 == 0) to take 1 sample out of 4. Likewise, x % 4 != 0 takes 3 out of 4.

Use map(lambda x,y: y) to strip the index and recover the original sample.

This example achieves a 75/25 split.

x % 5 == 0 and x % 5 != 0 give an 80/20 split.

If you really want a 70/30 split, x % 10 < 3 and x % 10 >= 3 should do.
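
For illustration, a minimal sketch of that 70/30 variant, reusing the shuffled all dataset from the example above; only the filter predicates change:

train_dataset = all.enumerate() \
                    .filter(lambda x, y: x % 10 >= 3) \
                    .map(lambda x, y: y)

test_dataset = all.enumerate() \
                    .filter(lambda x, y: x % 10 < 3) \
                    .map(lambda x, y: y)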

UPDATE:

As of TensorFlow 2.0.0, the above code may result in some warnings due to AutoGraph's limitations. To eliminate those warnings, declare all the lambda functions separately:

def is_test(x, y):
    return x % 4 == 0

def is_train(x, y):
    return not is_test(x, y)

recover = lambda x,y: y

test_dataset = all.enumerate() \
                    .filter(is_test) \
                    .map(recover)

train_dataset = all.enumerate() \
                    .filter(is_train) \
                    .map(recover)

This gives no warning on my machine. And defining is_train() as not is_test() is definitely good practice.

Botryoidal answered 18/10, 2019 at 13:51 Comment(10)
Nice answer. Suggestion: caching (with .cache()) each subsample will prevent TensorFlow from performing a full iteration each time (the first full iteration for each subset seems unavoidable). (Orethaorferd)
I would assume this reads the entire dataset once, but are the test_dataset and train_dataset variables then also Datasets that are iterable, or are they fully loaded in memory? Asking because I have a large file that I would prefer not to load completely into RAM. (Prophase)
@Tominator, I am not sure. The way my example is set up, with test_dataset being read in full before train_dataset is read, train_dataset has to be fully stored in RAM for some time, especially because I tell it to shuffle only once. But what if the reading is controlled so that test_dataset is read once for every 3 times train_dataset is read? That way, the data does not have to be fully stored in RAM. Is that the actual implementation? I suspect so. TF datasets (and this kind of data-pulling API in general) are designed precisely to deal with huge datasets. (Botryoidal)
However, without inspecting the source code, I cannot confirm my suspicion. (Botryoidal)
The next question is, what test can we do to find out? (Botryoidal)
You should not use all as a variable name, as it shadows Python's built-in all() function. (Vicky)
It would be good to mention that a 70/20/10% split for train/val/test datasets is possible too with modulo 10: test_dataset = dataset.enumerate().filter(lambda x,y: x%10==7).map(lambda x,y: y) val_dataset = dataset.enumerate().filter(lambda x,y: x%10>7).map(lambda x,y: y) train_dataset = dataset.enumerate().filter(lambda x,y: x%10<7).map(lambda x,y: y) (Chickpea)
@JavierJC where do you put the .cache() call? Could you post a code snippet with the other train or test dataset derivation? (Waltner)
One unfortunate side effect of .filter is that it breaks cardinality checks. (Buonomo)
The AutoGraph warnings can also be resolved by giving a different name to x and y for each lambda on the same line, if you want to keep the lambdas. (Jardena)

I will first explain why the accepted answer is wrong, and then provide a simple working solution using take(), skip(), and a seed.

When working with pipelines, such as TF/Torch Datasets, beware of lazy evaluation. Avoid:

# DON'T
full_dataset = full_dataset.shuffle(10)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)

because take and skip do not synchronize on a single shuffle; they actually get executed as shuffle+take and shuffle+skip separately (yes!), so the two sets typically overlap in about 80% * 20% = 16% of the elements. In other words, information leaks between train and test.

Play with this code in case of doubt:

import tensorflow as tf

def gen_data():
    return iter(range(10))

full_dataset = tf.data.Dataset.from_generator(
  gen_data, 
  output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))

train_size = 8

# WRONG WAY
full_dataset = full_dataset.shuffle(10)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)

A = set(train_dataset.as_numpy_iterator())
B = set(test_dataset.as_numpy_iterator())

# EXPECT OVERLAP: this assertion will typically fail, demonstrating the leak
assert A.intersection(B)==set()

print(list(A))
print(list(B))

Now, what works is repeating the shuffle with the same seed for both the train and test datasets, which is also good for reproducibility. This should work with any deterministically ordered iterator:

import tensorflow as tf

def gen_data():
    return iter(range(10))

ds = tf.data.Dataset.from_generator(
    gen_data, 
    output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))

SEED = 42 # NOTE: change this

ds_train = ds.shuffle(100,seed=SEED).take(8).shuffle(100)
ds_test  = ds.shuffle(100,seed=SEED).skip(8)

A = set(ds_train.as_numpy_iterator())
B = set(ds_test.as_numpy_iterator())

assert A.intersection(B)==set()

print(list(A))
print(list(B))

By playing with SEED you can, for instance, inspect/estimate generalization (bootstrapping in place of cross-validation).
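
As a rough sketch of that idea (evaluate_model() here is a hypothetical placeholder for your own training and evaluation code, not part of TensorFlow):

# Re-split with different seeds and collect a test metric for each split.
scores = []
for seed in [1, 2, 3, 4, 5]:
    ds_train = ds.shuffle(100, seed=seed).take(8).shuffle(100)
    ds_test  = ds.shuffle(100, seed=seed).skip(8)
    scores.append(evaluate_model(ds_train, ds_test))  # hypothetical helper
print(scores)  # the spread of the scores hints at generalization variance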

Obara answered 27/10, 2023 at 14:44 Comment(3)
Then, do we just need to use the same seed? Or do we also need to use iterators? (Agram)
What about: full_dataset = full_dataset.shuffle(10, seed=SEED, reshuffle_each_iteration=False) train_dataset = full_dataset.take(train_size) test_dataset = full_dataset.skip(train_size) (Trier)
full_dataset = full_dataset.shuffle(10, seed=SEED, reshuffle_each_iteration=False) train_dataset = full_dataset.take(train_size).shuffle(10) test_dataset = full_dataset.skip(train_size) (Trier)
