How to use random_split with percentage split (sum of input lengths does not equal the length of the input dataset)

I tried to use torch.utils.data.random_split as follows:

import torch
from torch.utils.data import DataLoader, random_split

list_dataset = [1,2,3,4,5,6,7,8,9,10]
dataset = DataLoader(list_dataset, batch_size=1, shuffle=False)

random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(123))

However, this raises the following error: ValueError("Sum of input lengths does not equal the length of the input dataset!")

I looked at the docs and it seems like I should be able to pass in decimals that sum to 1, but clearly it's not working.

I also Googled this error and the closest thing that comes up is this issue.

What am I doing wrong?

Summerlin answered 5/11, 2022 at 11:50 Comment(0)

You're likely using an older version of PyTorch, such as PyTorch 1.10, in which random_split accepts only integer lengths and does not support fractions.
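To confirm whether this is the cause, you can compare your installed version against the release where fractional lengths appeared. The sketch below assumes that release is 1.13, as suggested in the comment on this answer; `supports_fractional_split` is a hypothetical helper name, not part of PyTorch.

```python
import re


def supports_fractional_split(version: str) -> bool:
    """Return True if `version` is 1.13 or newer.

    Hypothetical helper: the 1.13 cutoff is an assumption taken from
    the discussion here, not something read from PyTorch itself.
    """
    # Keep only the leading "major.minor" part, so local build
    # suffixes such as "+cu113" are ignored.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 13)


print(supports_fractional_split("1.10.0+cu113"))  # False
print(supports_fractional_split("1.13.1"))        # True
```

In practice you would pass torch.__version__ as the argument.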

To replicate this functionality in the older version, you can just copy the source code of the newer version:

import math
import warnings
from typing import List

from torch import default_generator, randperm
from torch._utils import _accumulate
from torch.utils.data.dataset import Subset

def random_split(dataset, lengths,
                 generator=default_generator):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
    >>> random_split(range(30), [0.3, 0.3, 0.4], generator=torch.Generator(
    ...   ).manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(f"Length of split at index {i} is 0. "
                              f"This might result in an empty dataset.")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):    # type: ignore[arg-type]
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
    return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
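To see what the fraction branch above computes, here is a minimal pure-Python sketch of only the length calculation: floor each fraction, then hand out the leftover items one at a time in round-robin order. The name `fractions_to_lengths` is made up for illustration, and the validation and warnings from the full function are omitted.

```python
import math
from typing import List, Sequence


def fractions_to_lengths(n_items: int, fractions: Sequence[float]) -> List[int]:
    """Turn fractions that sum to 1 into integer split lengths."""
    # Step 1: floor(frac * n_items) for every fraction.
    lengths = [math.floor(n_items * frac) for frac in fractions]
    # Step 2: distribute the items lost to flooring, one per split,
    # starting from the first split.
    remainder = n_items - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(fractions_to_lengths(10, [0.8, 0.1, 0.1]))  # [8, 1, 1]
print(fractions_to_lengths(11, [0.8, 0.1, 0.1]))  # [9, 1, 1]
```

With 11 items the floors are 8, 1 and 1, which leaves one item over; round-robin gives it to the first split.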
Summerlin answered 5/11, 2022 at 11:50 Comment(1)
As far as I can tell, this wasn't introduced until 1.13. Correct me if I am wrong. - Countryman

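If your dataset has a known length, i.e. it implements __len__, you can convert the proportions to integer lengths yourself and let the last split absorb any rounding leftovers: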

proportions = [.75, .10, .15]
lengths = [int(p * len(dataset)) for p in proportions]
lengths[-1] = len(dataset) - sum(lengths[:-1])
tr_dataset, vl_dataset, ts_dataset = random_split(dataset, lengths)
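As a rough end-to-end sketch of this approach, the example below applies it to the 10-element list from the question. It splits the underlying data first and only then wraps a subset in a DataLoader, which is an assumption about the intended workflow rather than something stated in the question. Variable names are illustrative.

```python
import torch
from torch.utils.data import DataLoader, random_split

# Stand-in dataset: anything with __len__ and __getitem__ works.
data = list(range(1, 11))

proportions = [0.75, 0.10, 0.15]
lengths = [int(p * len(data)) for p in proportions]
# The last split takes whatever is left after truncation.
lengths[-1] = len(data) - sum(lengths[:-1])

generator = torch.Generator().manual_seed(123)
train_set, val_set, test_set = random_split(data, lengths, generator=generator)

# Build loaders from the subsets after splitting.
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

print(lengths)  # [7, 1, 2]
```

Because integer lengths are passed, this should also work on versions that predate fractional support.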
Countryman answered 13/12, 2022 at 17:2 Comment(0)
