Split .tfrecords file into many .tfrecords files

6

14

Is there any way to split a .tfrecords file into many .tfrecords files directly, without writing back each Dataset example?

Vidda asked 4/2, 2019 at 15:25 Comment(0)
10

You can use a function like this (TF 1.x graph/session API):

import tensorflow as tf

def split_tfrecord(tfrecord_path, split_size):
    with tf.Graph().as_default(), tf.Session() as sess:
        ds = tf.data.TFRecordDataset(tfrecord_path).batch(split_size)
        batch = ds.make_one_shot_iterator().get_next()
        part_num = 0
        while True:
            try:
                records = sess.run(batch)
                part_path = tfrecord_path + '.{:03d}'.format(part_num)
                with tf.python_io.TFRecordWriter(part_path) as writer:
                    for record in records:
                        writer.write(record)
                part_num += 1
            except tf.errors.OutOfRangeError:
                break

For example, to split the file my_records.tfrecord into parts of 100 records each, you would do:

split_tfrecord('my_records.tfrecord', 100)

This would create multiple smaller record files my_records.tfrecord.000, my_records.tfrecord.001, etc.
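For reference, a minimal sketch of the same idea on TensorFlow 2.x, where eager execution is the default (the function name split_tfrecord_v2 is illustrative):

import tensorflow as tf

def split_tfrecord_v2(tfrecord_path, split_size):
    # Each batch is a rank-1 tensor of up to `split_size` serialized records (bytes),
    # which can be written out unchanged.
    ds = tf.data.TFRecordDataset(tfrecord_path).batch(split_size)
    for part_num, records in enumerate(ds):
        part_path = tfrecord_path + '.{:03d}'.format(part_num)
        with tf.io.TFRecordWriter(part_path) as writer:
            for record in records.numpy():
                writer.write(record)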

Blayze answered 4/2, 2019 at 16:4 Comment(0)
18

In TensorFlow 2.0.0, this will work:

import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset("input_file.tfrecord")

shards = 10

for i in range(shards):
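    # shard(shards, i) keeps every shards-th record starting at offset i;
    # note that each pass re-reads the whole input file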
    writer = tf.data.experimental.TFRecordWriter(f"output_file-part-{i}.tfrecord")
    writer.write(raw_dataset.shard(shards, i))
Jot answered 14/12, 2019 at 19:26 Comment(1)
now how to merge all these split files? (Samuels)
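For reading the pieces back as one dataset, tf.data.TFRecordDataset accepts a list of filenames, so a minimal sketch of the reverse direction (the file names below simply follow the pattern used above) is:

import tensorflow as tf

shards = 10
filenames = [f"output_file-part-{i}.tfrecord" for i in range(shards)]

# Reading several part files as one dataset; records come back file by file
# in the order the filenames are listed.
merged_dataset = tf.data.TFRecordDataset(filenames)

# Optionally write the merged stream back into a single file.
with tf.io.TFRecordWriter("merged.tfrecord") as writer:
    for record in merged_dataset:
        writer.write(record.numpy())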
3

Very efficient way for TensorFlow 2.x

As mentioned by @yongjieyongjie, you should use .batch() instead of .shard() to avoid iterating over the dataset more often than needed. But in case you have a very large dataset, too big for memory, that approach will fail silently (no error), leaving you with just a few files and a fraction of your original dataset.

First, batch your dataset, using as batch size the number of records you want per file (I assume your dataset is already in serialized format; otherwise see here).

dataset = dataset.batch(ITEMS_PER_FILE)

The next thing you want to do is use a generator to avoid running out of memory.

def write_generator():
    i = 0
    iterator = iter(dataset)
    optional = iterator.get_next_as_optional()
    while optional.has_value().numpy():
        ds = optional.get_value()
        optional = iterator.get_next_as_optional()
        batch_ds = tf.data.Dataset.from_tensor_slices(ds)
        writer = tf.data.experimental.TFRecordWriter(save_to + "\\" + name + "-" + str(i) + ".tfrecord", compression_type='GZIP')
        yield batch_ds, writer, i
        i += 1
    return

Now simply use the generator in a normal for-loop

for data, wri, i in write_generator():
    start_time = time.time()
    wri.write(data)
    print("Time needed: ", time.time() - start_time, "s", "\t", NAME_OF_FILES + "-" + str(i) + ".tfrecord")

As long as a single file's worth of records fits raw in memory, this should work just fine.
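The snippets above assume a few names (save_to, name, dataset, ITEMS_PER_FILE) that are not defined in the answer itself; a purely illustrative setup could look like this:

import time
import tensorflow as tf

ITEMS_PER_FILE = 100             # records per output file
save_to = r"C:\tfrecord_parts"   # output directory (Windows-style path, as above)
name = "my_records"              # base name of the output parts

# A dataset of already-serialized records, then batched as shown above.
dataset = tf.data.TFRecordDataset("my_records.tfrecord")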

Heron answered 26/10, 2020 at 15:42 Comment(0)
2

Using .batch() instead of .shard() to avoid iterating over dataset multiple times

A more performant approach (compared to using tf.data.Dataset.shard()) would be to use batching:

import tensorflow as tf

ITEMS_PER_FILE = 100 # Assuming we are saving 100 items per .tfrecord file


raw_dataset = tf.data.TFRecordDataset('in.tfrecord')

batch_idx = 0
for batch in raw_dataset.batch(ITEMS_PER_FILE):

    # Convert `batch` (a rank-1 tensor of serialized records) back into a `Dataset` of scalar records
    batch_ds = tf.data.Dataset.from_tensor_slices(batch)
    filename = f'out.tfrecord.{batch_idx:03d}'

    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(batch_ds)

    batch_idx += 1
Selfinductance answered 2/7, 2020 at 11:34 Comment(0)
1

Divide into N splits (tested in TensorFlow 1.13.1)

import os
import hashlib
import tensorflow as tf
from tqdm import tqdm


def split_tfrecord(tfrecord_path, n_splits):
    dataset = tf.data.TFRecordDataset(tfrecord_path)
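    # Iterating the dataset and calling .numpy() below requires eager execution
    # (tf.enable_eager_execution() in TF 1.x; the default in TF 2.x).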
    outfiles=[]
    for n_split in range(n_splits):
        output_tfrecord_dir = f"{os.path.splitext(tfrecord_path)[0]}"
        if not os.path.exists(output_tfrecord_dir):
            os.makedirs(output_tfrecord_dir)
        output_tfrecord_path=os.path.join(output_tfrecord_dir, f"{n_split:03d}.tfrecord")
        out_f = tf.io.TFRecordWriter(output_tfrecord_path)
        outfiles.append(out_f)

    for record in tqdm(dataset):
        sample = tf.train.Example()
        record = record.numpy()
        sample.ParseFromString(record)

        idx = int(hashlib.sha1(record).hexdigest(), 16) % n_splits
        outfiles[idx].write(sample.SerializeToString())

    for file in outfiles:
        file.close()
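For example, a call like the following (file name illustrative) would hash-distribute the records of my_records.tfrecord across my_records/000.tfrecord through my_records/003.tfrecord:

split_tfrecord('my_records.tfrecord', 4)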
Group answered 4/5, 2021 at 18:19 Comment(0)
0

Uneven splits

Most of the other answers work if you want to split evenly into files of equal size. This will work with uneven splits:

from typing import Iterator, List

import tensorflow as tf

# `splits` is a list of the number of records you want in each output file
def split_files(filename: str, splits: List[int]) -> None:
    dataset: tf.data.Dataset = tf.data.TFRecordDataset(filename)
    rec_counter: int = 0

    # An extra iteration over the data to get the size
    total_records: int = len([r for r in dataset])
    print(f"Found {total_records} records in source file.")

    if sum(splits) != total_records:
        raise ValueError(f"Sum of splits {sum(splits)} does not equal "
                         f"total number of records {total_records}")

    rec_iter: Iterator = iter(dataset)
    split: int
    for split_idx, split in enumerate(splits):
        outfile: str = filename + f".{split_idx}-{split}"
        with tf.io.TFRecordWriter(outfile) as writer:
            for out_idx in range(split):
                rec: tf.Tensor = next(rec_iter, None)
                rec_counter += 1
                writer.write(rec.numpy())
        print(f"Finished writing {split} records to file {split_idx}")

Though I suppose technically the OP asked how to do this without writing back each Dataset example (which is exactly what this does), this at least does it without deserializing each example.

It is a bit slow for very large files. There is probably a way to modify some of the other batching-based answers in order to use batched input reading but still write uneven splits, but I haven't tried.
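As a purely illustrative call, assuming in.tfrecord holds exactly 250 records, the splits just have to sum to that total:

split_files("in.tfrecord", [100, 100, 50])
# -> in.tfrecord.0-100, in.tfrecord.1-100, in.tfrecord.2-50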

Caddy answered 22/1, 2021 at 1:55 Comment(0)
