Split a dataset created by the TensorFlow Dataset API into Train and Test?

11

72

Does anyone know how to split a dataset created by the dataset API (tf.data.Dataset) in Tensorflow into Test and Train?

Darlinedarling answered 11/1, 2018 at 18:34 Comment(2)
take(), skip(), and shard() all have their own problems. I just posted my answer over here. I hope it better answers your question.Brickey
use Keras - model.fit(dataset, ..., validation_split=0.7, ...); see all its possible argumentsOntina
96

Assuming you have all_dataset variable of tf.data.Dataset type:

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)

The test dataset now holds the first 1000 elements, and the rest goes to training.
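
If you also want a shuffled split (as the first comment below suggests), here is a minimal sketch; the buffer size is an arbitrary placeholder, and reshuffle_each_iteration=False keeps the shuffled order fixed so that take() and skip() stay disjoint:

all_dataset = all_dataset.shuffle(buffer_size=10000, seed=42,
                                  reshuffle_each_iteration=False)
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)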

Leisaleiser answered 5/5, 2018 at 3:10 Comment(3)
As also mentioned in ted's answer, adding all_dataset.shuffle() allows for a shuffled split. Possibly add as code comment in answer like so? # all_dataset = all_dataset.shuffle() # in case you want a shuffled splitStriation
TensorFlow 2.10.0 will have a utility function for splitting, see my answer: https://mcmap.net/q/273173/-split-a-dataset-created-by-tensorflow-dataset-api-in-to-train-and-testGunnel
take and skip return TfTakeDatasets/SkipDatasets which have less functionality than TfDatasets. Does anyone know how to map those to tfDatasets or split into train test splits and get back TfDataset objects?Sinuate
59

You may use Dataset.take() and Dataset.skip():

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle(DATASET_SIZE)  # shuffle() requires a buffer_size; here, the full dataset size
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

For more generality, I gave an example using a 70/15/15 train/val/test split, but if you don't need a test or a val set, just ignore the last 2 lines.

Take:

Creates a Dataset with at most count elements from this dataset.

Skip:

Creates a Dataset that skips count elements from this dataset.

You may also want to look into Dataset.shard():

Creates a Dataset that includes only 1/num_shards of this dataset.
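
As a rough sketch of how shard() could be used for a 1-in-5 validation split (an illustration, not from the quoted docs): shard() only yields one slice, so the complementary training slice here falls back to enumerate() plus filter(); apply both to the dataset before any reshuffling.

NUM_SHARDS = 5  # keep 1 element in 5 for validation
val_dataset = full_dataset.shard(num_shards=NUM_SHARDS, index=0)
train_dataset = (full_dataset
                 .enumerate()
                 .filter(lambda i, x: i % NUM_SHARDS != 0)
                 .map(lambda i, x: x))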


Disclaimer: I stumbled upon this question after answering this one, so I thought I'd spread the love.

Clowers answered 10/7, 2018 at 6:42 Comment(7)
Thank you very much @ted! Is there a way to divide the dataset in a stratified way? Or, alternatively, how can we have an idea of the class proportions (suppose a binary problem) after the train/val/test split? Thanks a lot in advance!Cutlery
Have a look at this blog post I wrote; even though it's for multilabel datasets, it should be easily usable for single-label, multiclass datasets -> vict0rs.ch/2018/06/17/multilabel-text-classification-tensorflowClowers
This causes my train, validation and test datasets to have overlap between them. Is this supposed to happen and not a big deal? I would assume it's not a good idea to have the model train on validation and test data.Miasma
@c_student I had the same problem and I figured out what I was missing: when you shuffle use the option reshuffle_each_iteration=False otherwise elements could be repeated in train, test and valInky
This is very true @xdola, and in particular when using list_files you should use shuffle=False and then shuffle with the .shuffle with reshuffle_each_iteration=False.Salahi
@xdola, Thank you for your comment, a potential disaster avoided!Pride
With this answer, we end up with a TakeDataSet, which does not have all the same properties and methods as a Dataset.Gram
33

Most of the answers here use take() and skip(), which require knowing the size of your dataset beforehand. This isn't always possible, or is difficult/intensive to ascertain.

Instead, what you can do is essentially slice the dataset up so that 1 in every N records becomes a validation record.

To accomplish this, let's start with a simple dataset of 0-9:

dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Now for our example, we're going to slice it so that we have a 3/1 train/validation split: 3 records go to training, then 1 record goes to validation, then repeat.

split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]

So the first dataset.window(split, split + 1) says to grab split (3) elements, then advance split + 1 elements, and repeat. That + 1 effectively skips the 1 element we're going to use in our validation dataset.
The flat_map(lambda ds: ds) is because window() returns the results in batches, which we don't want. So we flatten it back out.

Then for the validation data we first skip(split), which skips over the first split (3) elements that were grabbed in the first training window, so we start our iteration on the 4th element. The window(1, split + 1) then grabs 1 element, advances split + 1 (4), and repeats.

 

Note on nested datasets:
The above example works well for simple datasets, but flat_map() will generate an error if the dataset is nested. To address this, you can swap out the flat_map() with a more complicated version that can handle both simple and nested datasets:

.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
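
As an illustrative sketch (not part of the original answer) of that drop-in on a nested dataset, zipping toy features and labels:

features = tf.data.Dataset.range(10)
labels = features.map(lambda x: x * 10)
nested = tf.data.Dataset.zip((features, labels))

split = 3
flatten = lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds)
dataset_train = nested.window(split, split + 1).flat_map(flatten)
# [(0, 0), (1, 10), (2, 20), (4, 40), (5, 50), (6, 60), (8, 80), (9, 90)]
dataset_validation = nested.skip(split).window(1, split + 1).flat_map(flatten)
# [(3, 30), (7, 70)]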
With answered 3/3, 2020 at 8:32 Comment(7)
Doesn't window just use skip under the hood? How is the disadvantage "The other disadvantage is that with skip() it has to read, and then discard, all the skipped records, which if your data source is slow means you might have a large spool-up time before results are emitted" addressed?Desmond
If you have a dataset of 1000 records, and you want a 10% for validation, you would have to skip the first 900 records before a single validation record is emitted. With this solution, it only has to skip 9 records. It does end up skipping the same amount overall, but if you use dataset.prefetch(), it can read in the background while doing other things. The difference is just saving the initial spool-up time.With
Thinking about it a bit more, and I removed the statement. There's probably a dozen ways to solve that problem, and it's probably minute, if present at all, for most people.With
You should probably set the "without knowing the dataset size beforehand" part in boldface, or make it a header or something; it's pretty important. This should really be the accepted answer, as it fits into the premise of tf.data.Dataset treating data like infinite streams.Desmond
One thing when I was trying this method was that RAM consumption was much higher than when using the method described by @ted. So much higher that I couldn't get it to run on my machine at all. Maybe I'm doing something wrong, but what would be a feasible approach when I don't know the size of the dataset and also have data that doesn't fit into memory?Permenter
@Permenter This solution by itself won't cause any significant difference in memory usage from other solutions. If you're experiencing such, it's going to be because of how the rest of the pipeline is interacting with it. I suggest posting the full details in a new question. I use this method on data sets that are hundreds of GB without issue.With
Not sure if this is good conduct on stackoverflow, but here is the question I created for reference: #68274853 If it's not, let me know and I will delete my comment.Permenter
8

@ted's answer will cause some overlap. Try this.

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
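
The snippet above assumes full_ds_size is known; a rough way to obtain it (see also the comments below) is the dataset's cardinality, falling back to counting when the cardinality is unknown:

full_ds_size = tf.data.experimental.cardinality(full_ds).numpy()
if full_ds_size < 0:  # UNKNOWN_CARDINALITY or INFINITE_CARDINALITY
    # counting only works for finite datasets and can be slow for large ones
    full_ds_size = sum(1 for _ in full_ds)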

Use the code below to test.

tf.enable_eager_execution()  # needed in TF 1.x only; TF 2.x runs eagerly by default

dataset = tf.data.Dataset.range(100)

train_size = 20
valid_size = 30
test_size = 50

train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)

for i in train:
    print(i)

for i in valid:
    print(i)

for i in test:
    print(i)
Frimaire answered 27/3, 2020 at 21:43 Comment(2)
I love how everyone assumes you know the full_ds_size but no one explains how to find itSinistrad
@Sinistrad len(list(dataset)) is the most straightforward #50737692 ...but... my understanding is that datasets can be extremely large (might not fit in memory) so iterating over them can take a very long time. It is probably best to figure out how large the dataset is based on external knowledge of the dataset.Speak
6

The upcoming TensorFlow 2.10.0 will have a tf.keras.utils.split_dataset function, see the rc3 release notes:

Added tf.keras.utils.split_dataset utility to split a Dataset object or a list/tuple of arrays into two Dataset objects (e.g. train/test).
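
A minimal sketch of the utility (assuming TF >= 2.10; left_size can be a fraction or an absolute element count):

import tensorflow as tf

ds = tf.data.Dataset.range(100)
train_ds, test_ds = tf.keras.utils.split_dataset(
    ds, left_size=0.8, shuffle=True, seed=42)
print(len(list(train_ds)), len(list(test_ds)))  # 80 20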

Gunnel answered 3/9, 2022 at 11:18 Comment(1)
By the way, I found that using this separate split_dataset function makes the shuffling of image_dataset_from_directory stable across re-iterations, yielding correctly ordered results from Model.predict later on. See discuss.tensorflow.org/t/…Gunnel
5

Right now, TensorFlow doesn't contain any tools for that.
You could use sklearn.model_selection.train_test_split to generate train/eval/test datasets, then create a tf.data.Dataset for each.
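
For example, a minimal sketch (assuming the data fits in memory as arrays; the names here are placeholders):

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

X = np.random.rand(1000, 4).astype("float32")
y = np.random.randint(0, 2, size=1000)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))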

Nostradamus answered 12/3, 2018 at 8:35 Comment(1)
sklearn requires that stuff fits in memory, TF Data does not.Quindecennial
5

You can use shard:

dataset = dataset.shuffle(1000)  # optional; shuffle() requires a buffer_size, 1000 is a placeholder
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)

See: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard

Paderewski answered 21/11, 2018 at 19:17 Comment(2)
shard is deprecatedCannonball
@Cannonball are you sure? I don't see anything saying it is deprecated.Speak
0

In case the size of the dataset is known:

from typing import Tuple
import tensorflow as tf

def split_dataset(dataset: tf.data.Dataset, 
                  dataset_size: int, 
                  train_ratio: float, 
                  validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    assert (train_ratio + validation_ratio) < 1

    train_count = int(dataset_size * train_ratio)
    validation_count = int(dataset_size * validation_ratio)
    test_count = dataset_size - (train_count + validation_count)

    dataset = dataset.shuffle(dataset_size)

    train_dataset = dataset.take(train_count)
    validation_dataset = dataset.skip(train_count).take(validation_count)
    test_dataset = dataset.skip(validation_count + train_count).take(test_count)

    return train_dataset, validation_dataset, test_dataset

Example:

size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2

ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
Spectre answered 26/1, 2020 at 15:28 Comment(0)
0

A robust way to split a dataset into two parts is to first deterministically map every item in the dataset into a bucket with, for example, tf.strings.to_hash_bucket_fast. Then you can split the dataset in two by filtering on the bucket. If you split your data into five buckets, you get an 80/20 split, assuming that the split is even.

As an example, assume that your dataset contains dictionaries with the key filename. We split the data into five buckets based on this key. With this add_fold function, we add the key "fold" to the dictionaries:

def add_fold(buckets: int):
    def add_(sample, label):
        fold = tf.strings.to_hash_bucket(sample["filename"], num_buckets=buckets)
        return {**sample, "fold": fold}, label

    return add_

dataset = dataset.map(add_fold(buckets=5))

Now we can split the dataset into two disjoint datasets with Dataset.filter:

def pick_fold(fold: int):
    def filter_fn(sample, _):
        return tf.math.equal(sample["fold"], fold)

    return filter_fn


def skip_fold(fold: int):
    def filter_fn(sample, _):
        return tf.math.not_equal(sample["fold"], fold)

    return filter_fn

train_dataset = dataset.filter(skip_fold(0))
val_dataset = dataset.filter(pick_fold(0))

The key that you use for hashing should be one that captures the correlations in the dataset. For example, if your samples collected by the same person are correlated and you want all samples with the same collector end up in the same bucket (and the same split), you should use the collector name or ID as the hashing column.

Of course, you can skip the part with dataset.map and do the hashing and filtering in one filter function. Here's a full example:

dataset = tf.data.Dataset.from_tensor_slices([f"value-{i}" for i in range(10000)])

def to_bucket(sample):
    return tf.strings.to_hash_bucket_fast(sample, 5)

def filter_train_fn(sample):
    return tf.math.not_equal(to_bucket(sample), 0)

def filter_val_fn(sample):
    return tf.math.logical_not(filter_train_fn(sample))

train_ds = dataset.filter(filter_train_fn)
val_ds = dataset.filter(filter_val_fn)

print(f"Length of training set: {len(list(train_ds.as_numpy_iterator()))}")
print(f"Length of validation set: {len(list(val_ds.as_numpy_iterator()))}")

This prints:

Length of training set: 7995
Length of validation set: 2005
Floccule answered 29/1, 2022 at 13:23 Comment(0)
0

Beware of lazy evaluation, which produces two pipelines, shuffle+take and shuffle+skip, that do overlap. Because of this, some of the highly-scored answers produce information leaks. Here is the correct way: repeat and seed the shuffle in both the train and test datasets.

import tensorflow as tf

def gen_data():
    return iter(range(10))

ds = tf.data.Dataset.from_generator(
    gen_data, 
    output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))

SEED = 42 # NOTE: with no seed, you overlap train and test!

ds_train = ds.shuffle(100,seed=SEED).take(8).shuffle(100)
ds_test  = ds.shuffle(100,seed=SEED).skip(8)

A = set(ds_train.as_numpy_iterator())
B = set(ds_test.as_numpy_iterator())

assert A.intersection(B)==set()

print(list(A))
print(list(B))

NOTE: This works for any deterministically ordered iterator.

Gamber answered 27/10, 2023 at 14:54 Comment(0)
-2

Can't comment, but the above answer has overlap and is incorrect. Set BUFFER_SIZE to DATASET_SIZE for a perfect shuffle. Try different val/test sizes to verify. The answer should be:

DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy()
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

BUFFER_SIZE = DATASET_SIZE  # per the note above, a full-size buffer gives a perfect shuffle
full_dataset = full_dataset.shuffle(BUFFER_SIZE)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.take(val_size)
test_dataset = test_dataset.skip(val_size)
Salvador answered 2/3, 2020 at 20:0 Comment(0)
