How to split data into 3 sets (train, validation and test)?

I have a pandas dataframe and I wish to divide it into 3 separate sets. I know that using train_test_split from sklearn.cross_validation, one can divide the data into two sets (train and test). However, I couldn't find any solution for splitting the data into three sets. Preferably, I'd like to keep the indices of the original data.

I know that a workaround would be to use train_test_split two times and somehow adjust the indices. But is there a more standard / built-in way to split the data into 3 sets instead of 2?
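For reference, the two-call workaround I have in mind looks roughly like this (the 60/20/20 split is just an example; in newer sklearn versions train_test_split lives in sklearn.model_selection):

from sklearn.model_selection import train_test_split

# First carve off 20% for test, then split the remaining 80% into 60/20 train/validation.
# The returned dataframes keep the original indices.
train_val, test = train_test_split(df, test_size=0.2, random_state=42)
train, validation = train_test_split(train_val, test_size=0.25, random_state=42)  # 0.25 * 0.8 = 0.2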

Polyvalent answered 7/7, 2016 at 16:26 Comment(9)
This doesn't answer your specific question, but I think the more standard approach for this would be splitting into two sets, train and test, and running cross-validation on the training set thus eliminating the need for a stand alone "development" set.Rootlet
This came up before, and as far as I know there is no built-in method for that yet.Fortney
I suggest Hastie et al.'s The Elements of Statistical Learning for a discussion on why to use three sets instead of two (web.stanford.edu/~hastie/local.ftp/Springer/OLD/… Model assessment and selection chapter)Fortney
@Rootlet In some models, three sets instead of two are needed to prevent overfitting: through your design choices you end up tuning parameters to improve performance on the test set. To prevent that, a development set is required, so using cross validation alone will not be sufficient.Polyvalent
@Polyvalent I don't understand. You tune your model in cross-validation and then do a final test with your test set to ensure that your CV results line up. There are use cases where CV isn't a good call, like when there is inherent bias/leakage in the data (ex: forecasting), but you wouldn't be looking for a general purpose function to split your dataset into 3 random chunks if that was the case.Rootlet
@Rootlet Ah I see. If you tune on CV and then evaluate on final test set it is fine then. I misinterpreted your comment. +1Polyvalent
@ayhan, a corrected URL for that book is statweb.stanford.edu/~tibs/ElemStatLearn/printings/…, chapter 7 (p. 219).Eisenstein
Using train_test_split two times is two lines of code. If you absolutely want only one line of code, then just wrap them up in a function. That way you keep the benefits of using a well-tested function with community support etc.Integration
Also, there isn't anything magical about the number 3. You could in principle need a fourth set (and a fifth and...) at a later point if you need to do more layers of cross-validation and testing with more models that have already been tested with the existing "test" set. So having an implementation in sklearn that is specific to exactly 3 sets may not be the way to go. It could, however, make sense to have an implementation that can split into n sets. Could roll your own by recursively calling train_test_split.Integration

Numpy solution. We will shuffle the whole dataset first (df.sample(frac=1, random_state=42)) and then split our data set into the following parts:

  • 60% - train set,
  • 20% - validation set,
  • 20% - test set

In [305]: train, validate, test = \
              np.split(df.sample(frac=1, random_state=42), 
                       [int(.6*len(df)), int(.8*len(df))])

In [306]: train
Out[306]:
          A         B         C         D         E
0  0.046919  0.792216  0.206294  0.440346  0.038960
2  0.301010  0.625697  0.604724  0.936968  0.870064
1  0.642237  0.690403  0.813658  0.525379  0.396053
9  0.488484  0.389640  0.599637  0.122919  0.106505
8  0.842717  0.793315  0.554084  0.100361  0.367465
7  0.185214  0.603661  0.217677  0.281780  0.938540

In [307]: validate
Out[307]:
          A         B         C         D         E
5  0.806176  0.008896  0.362878  0.058903  0.026328
6  0.145777  0.485765  0.589272  0.806329  0.703479

In [308]: test
Out[308]:
          A         B         C         D         E
4  0.521640  0.332210  0.370177  0.859169  0.401087
3  0.333348  0.964011  0.083498  0.670386  0.169619

[int(.6*len(df)), int(.8*len(df))] is the indices_or_sections argument for numpy.split() - the data is cut at the 60% and 80% marks.

Here is a small demo of np.split() usage - let's split a 20-element array into the following parts: 80%, 10%, 10%:

In [45]: a = np.arange(1, 21)

In [46]: a
Out[46]: array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])

In [47]: np.split(a, [int(.8 * len(a)), int(.9 * len(a))])
Out[47]:
[array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16]),
 array([17, 18]),
 array([19, 20])]
Wixted answered 7/7, 2016 at 16:56 Comment(14)
@root what exactly is the frac=1 parameter doing?Foiled
@SpiderWasp42, frac=1 instructs sample() function to return all (100% or fraction = 1.0) rowsWixted
Thanks @MaxU. I'd like to mention 2 things to keep things simplified. First, use np.random.seed(any_number) before the split line to obtain same result with every run. Second, to make unequal ratio like train:test:val::50:40:10 use [int(.5*len(dfn)), int(.9*len(dfn))]. Here first element denotes size for train (0.5%), second element denotes size for val (1-0.9 = 0.1%) and difference between the two denotes size for test(0.9-0.5 = 0.4%). Correct me if I'm wrong :)Tautomer
hrmm is it a mistake when you say "Here is a small demo for np.split() usage - let's split 20-elements array into the following parts: 90%, 10%, 10%:" I am pretty sure you mean 80%, 10%, 10%Straggle
Hey, @MaxU I had a case, something somewhat similar. I was wondering if you could look at it for me to see if it is and help me there. Here is my question #54848168Haik
@DeepakM, you can easily use provided solution two times - for X and yWixted
I am trying to use this method for splitting a 4D data (e.g, a 3D image with multiple channels). But pandas is complaining that I can only pass a 2D input. Is there a way to use this method if the input data is in the form of a multidimensional array?Yoon
Does this np.split shuffle the data ?Odessa
Good solution and much needed for all practitioners.Odessa
@KathiravanNatarajan, yes. In the first example it will be shuffled first, using - df.sample(frac=1)Wixted
I would add the random seed to the Pandas sample method explicitly. The default is None, not the Numpy seed. So pass df.sample(frac=1., random_state=42) or some variable or class instead of 42.Rochelle
@grofte, it's a good point, thank you! I've updated the answer correspondingly...Wixted
Note that by using Numpy you're missing out on sklearn functionality such as split stratification.Seniority
@Tautomer I'm afraid you are wrong! According to the docs (numpy.org/doc/stable/reference/generated/numpy.split.html), and considering your comment, the split data would be: train = data[0:50%], val = data[50%:90%] and test = data[90%:], so train is 50%, val is 40% and test is 10%.Aldrin

One approach to dividing the dataset into train, test, and cv with 0.6, 0.2, 0.2 is to use the train_test_split method twice. Note that the second call's test_size=0.25 is applied to the remaining 80%, which works out to 20% of the original data.

from sklearn.model_selection import train_test_split

x, x_test, y, y_test = train_test_split(xtrain,labels,test_size=0.2,train_size=0.8)
x_train, x_cv, y_train, y_cv = train_test_split(x,y,test_size = 0.25,train_size =0.75)
Sinotibetan answered 21/3, 2017 at 16:10 Comment(10)
Suboptimal for large datasetsMisinterpret
@MaksymGanenko Can you please elaborate ?Sinotibetan
You suggest to split data with two separate operations. Each data split involves data copying. So when you suggest to use two separate split operations instead of one you artificially create burden on both RAM and CPU. So your solution is suboptimal. Data split should be done with a single operation like np.split(). Also, it doesn't require additional dependency on sklearn.Misinterpret
@MaksymGanenko Agreed on the extra memory burden; to mitigate it we can delete the original data (xtrain & labels) from memory. And your suggestion to use numpy is somewhat limited to integer data types only; what about other data types?Sinotibetan
With np.split() you can split indices and so you may reindex any datatype. If you look into train_test_split() you'll see that it does exactly the same way: define np.arange(), shuffle it and then reindex original data. But train_test_split() can't split data into three datasets, so its use is limited. In the context of the answer it's suboptimal (== wrong).Misinterpret
another benefit of this approach is that you can use the stratification parameters.Juggernaut
@MaksymGanenko: Who cares if its performance is suboptimal? Typically, you'd want to perform train/val/test splitting only once, so therefore you want to make sure it's done correctly, not just efficiently. For train/val/test splitting, you need to have stratified sampling, which is not available with Numpy split(); you have to implement stratification yourself. The sci-kit learn function does all that for you using train_test_split().Blacksnake
@Blacksnake Well, let's talk about it when you've got a dataset that barely fits into memory, or better, doesn't fit into RAM at all, which is quite common in production tasks.Misinterpret
@MaksymGanenko: If data can't fit into memory, then I'd use Spark and its libraries to split the data.Blacksnake
If you want to make a small optimization, make an empty dataframe with the same indices as your original data. Pass that through this double split, then in the end take slices from the original dataframe using the indexes of the splits.Elaterid
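A rough sketch of that index-only idea (my own illustration, not the commenter's exact code; the fractions are arbitrary): split only the row labels, then slice the original dataframe once at the end, so the wide columns are never copied by the intermediate splits.

from sklearn.model_selection import train_test_split

labels = df.index.to_numpy()                 # only the row labels are split
labels_train_val, labels_test = train_test_split(labels, test_size=0.2, random_state=42)
labels_train, labels_val = train_test_split(labels_train_val, test_size=0.25, random_state=42)

train, val, test = df.loc[labels_train], df.loc[labels_val], df.loc[labels_test]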

Note:

Function was written to handle seeding of randomized set creation. You should not rely on set splitting that doesn't randomize the sets.

import numpy as np
import pandas as pd

def train_validate_test_split(df, train_percent=.6, validate_percent=.2, seed=None):
    np.random.seed(seed)
    perm = np.random.permutation(df.index)
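    # note: permuting df.index and then selecting with iloc assumes a default RangeIndex (0..n-1); with a different index, switch to df.loc (see comments below)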
    m = len(df.index)
    train_end = int(train_percent * m)
    validate_end = int(validate_percent * m) + train_end
    train = df.iloc[perm[:train_end]]
    validate = df.iloc[perm[train_end:validate_end]]
    test = df.iloc[perm[validate_end:]]
    return train, validate, test

Demonstration

np.random.seed([3,1415])
df = pd.DataFrame(np.random.rand(10, 5), columns=list('ABCDE'))
df

[image: the 10-row, 5-column random dataframe]

train, validate, test = train_validate_test_split(df)

train

[image: train split (6 rows)]

validate

[image: validate split (2 rows)]

test

[image: test split (2 rows)]

Riki answered 7/7, 2016 at 16:47 Comment(2)
I believe this function requires a df with index values ranging from 1 to n. In my case, I modified the function to use df.loc as my index values were not necessarily in this range.Twittery
I'd add one more np.random.seed() (without the passed seed) to the end, to avoid seeding the rest of your code.Halden

Here is a Python function that splits a Pandas dataframe into train, validation, and test dataframes with stratified sampling. It performs this split by calling scikit-learn's function train_test_split() twice.

import pandas as pd
from sklearn.model_selection import train_test_split

def split_stratified_into_train_val_test(df_input, stratify_colname='y',
                                         frac_train=0.6, frac_val=0.15, frac_test=0.25,
                                         random_state=None):
    '''
    Splits a Pandas dataframe into three subsets (train, val, and test)
    following fractional ratios provided by the user, where each subset is
    stratified by the values in a specific column (that is, each subset has
    the same relative frequency of the values in the column). It performs this
    splitting by running train_test_split() twice.

    Parameters
    ----------
    df_input : Pandas dataframe
        Input dataframe to be split.
    stratify_colname : str
        The name of the column that will be used for stratification. Usually
        this column would be for the label.
    frac_train : float
    frac_val   : float
    frac_test  : float
        The ratios with which the dataframe will be split into train, val, and
        test data. The values should be expressed as float fractions and should
        sum to 1.0.
    random_state : int, None, or RandomStateInstance
        Value to be passed to train_test_split().

    Returns
    -------
    df_train, df_val, df_test :
        Dataframes containing the three splits.
    '''

    if frac_train + frac_val + frac_test != 1.0:
        raise ValueError('fractions %f, %f, %f do not add up to 1.0' % \
                         (frac_train, frac_val, frac_test))

    if stratify_colname not in df_input.columns:
        raise ValueError('%s is not a column in the dataframe' % (stratify_colname))

    X = df_input # Contains all columns.
    y = df_input[[stratify_colname]] # Dataframe of just the column on which to stratify.

    # Split original dataframe into train and temp dataframes.
    df_train, df_temp, y_train, y_temp = train_test_split(X,
                                                          y,
                                                          stratify=y,
                                                          test_size=(1.0 - frac_train),
                                                          random_state=random_state)

    # Split the temp dataframe into val and test dataframes.
    relative_frac_test = frac_test / (frac_val + frac_test)
    df_val, df_test, y_val, y_test = train_test_split(df_temp,
                                                      y_temp,
                                                      stratify=y_temp,
                                                      test_size=relative_frac_test,
                                                      random_state=random_state)

    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)

    return df_train, df_val, df_test

Below is a complete working example.

Consider a dataset that has a label upon which you want to perform the stratification. This label has its own distribution in the original dataset, say 75% foo, 15% bar and 10% baz. Now let's split the dataset into train, validation, and test subsets using a 60/20/20 ratio, where each split retains the same distribution of the labels. See the illustration below:

[illustration: a 60/20/20 split where each subset keeps the 75/15/10 foo/bar/baz label distribution]

Here is the example dataset:

df = pd.DataFrame( { 'A': list(range(0, 100)),
                     'B': list(range(100, 0, -1)),
                     'label': ['foo'] * 75 + ['bar'] * 15 + ['baz'] * 10 } )

df.head()
#    A    B label
# 0  0  100   foo
# 1  1   99   foo
# 2  2   98   foo
# 3  3   97   foo
# 4  4   96   foo

df.shape
# (100, 3)

df.label.value_counts()
# foo    75
# bar    15
# baz    10
# Name: label, dtype: int64

Now, let's call the split_stratified_into_train_val_test() function from above to get train, validation, and test dataframes following a 60/20/20 ratio.

df_train, df_val, df_test = \
    split_stratified_into_train_val_test(df, stratify_colname='label', frac_train=0.60, frac_val=0.20, frac_test=0.20)

The three dataframes df_train, df_val, and df_test contain all the original rows but their sizes will follow the above ratio.

df_train.shape
#(60, 3)

df_val.shape
#(20, 3)

df_test.shape
#(20, 3)

Further, each of the three splits will have the same distribution of the label, namely 75% foo, 15% bar and 10% baz.

df_train.label.value_counts()
# foo    45
# bar     9
# baz     6
# Name: label, dtype: int64

df_val.label.value_counts()
# foo    15
# bar     3
# baz     2
# Name: label, dtype: int64

df_test.label.value_counts()
# foo    15
# bar     3
# baz     2
# Name: label, dtype: int64
Blacksnake answered 22/3, 2020 at 19:43 Comment(2)
NameError: name 'df' is not defined. The 'df' in split_stratified_into_train_val_test() should be replaced with 'df_input'.Humberto
Thanks. I fixed it. The problem was in an error-handling path of the code.Blacksnake

In the case of supervised learning, you may want to split both X and y (where X is your input and y the ground truth output). You just have to pay attention to shuffle X and y the same way before splitting.

There are two cases: either X and y are in the same dataframe, in which case we shuffle it, separate X and y, and apply the split to each (just like in the chosen answer); or X and y are in two different dataframes, in which case we shuffle X, reorder y the same way as the shuffled X, and apply the split to each.

import numpy as np

# 1st case: df contains X and y (where y is the "target" column of df)
df_shuffled = df.sample(frac=1)
X_shuffled = df_shuffled.drop("target", axis = 1)
y_shuffled = df_shuffled["target"]

# 2nd case: X and y are two separated dataframes
X_shuffled = X.sample(frac=1)
y_shuffled = y[X_shuffled.index]

# We do the split as in the chosen answer
X_train, X_validation, X_test = np.split(X_shuffled, [int(0.6*len(X)),int(0.8*len(X))])
y_train, y_validation, y_test = np.split(y_shuffled, [int(0.6*len(X)),int(0.8*len(X))])
Shalne answered 4/9, 2020 at 8:21 Comment(0)

It is very convenient to use train_test_split without having to reindex after dividing into several sets and without writing extra code. However, the best answer above does not mention that separating two times with train_test_split, without adjusting the partition sizes in the second call, won't give the initially intended partition:

x_train, x_remain = train_test_split(x, test_size=(val_size + test_size))

The proportions of the validation and test sets within x_remain then change and can be computed as

new_test_size = np.around(test_size / (val_size + test_size), 2)
# To preserve (new_test_size + new_val_size) = 1.0 
new_val_size = 1.0 - new_test_size

x_val, x_test = train_test_split(x_remain, test_size=new_test_size)

This way, all of the initially intended partition sizes are preserved.
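Putting the pieces together, a minimal end-to-end sketch (the variable x and the 0.6/0.2/0.2 fractions are just an example):

import numpy as np
from sklearn.model_selection import train_test_split

x = np.arange(1000)              # stand-in for your data
val_size, test_size = 0.2, 0.2   # together with an implicit 0.6 train share

x_train, x_remain = train_test_split(x, test_size=(val_size + test_size))

new_test_size = np.around(test_size / (val_size + test_size), 2)
new_val_size = 1.0 - new_test_size            # keeps new_val_size + new_test_size == 1.0

x_val, x_test = train_test_split(x_remain, test_size=new_test_size)

print(len(x_train), len(x_val), len(x_test))  # 600 200 200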

Aquarist answered 16/11, 2018 at 9:35 Comment(0)
from sklearn.model_selection import train_test_split

def train_val_test_split(X, y, train_size, val_size, test_size):
    X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size = test_size)
    relative_train_size = train_size / (val_size + train_size)
    X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val,
                                                      train_size = relative_train_size, test_size = 1-relative_train_size)
    return X_train, X_val, X_test, y_train, y_val, y_test

Here we split the data twice with sklearn's train_test_split.

Intorsion answered 21/11, 2020 at 12:34 Comment(0)

Considering that df is your original dataframe:

1 - First you split the data between train and test (10% for test):

my_test_size = 0.10

X_train_, X_test, y_train_, y_test = train_test_split(
    df.index.values,
    df.label.values,
    test_size=my_test_size,
    random_state=42,
    stratify=df.label.values,    
)

2 - Then you split the train set between train and validation (20% of the train set for validation):

my_val_size = 0.20

X_train, X_val, y_train, y_val = train_test_split(
    df.loc[X_train_].index.values,
    df.loc[X_train_].label.values,
    test_size=my_val_size,
    random_state=42,
    stratify=df.loc[X_train_].label.values,  
)

3 - Then, you slice the original dataframe according to the indices generated in the steps above:

# The data_type column is just for bookkeeping and is not strictly necessary.
df['data_type'] = ['not_set']*df.shape[0]
df.loc[X_train, 'data_type'] = 'train'
df.loc[X_val, 'data_type'] = 'val'
df.loc[X_test, 'data_type'] = 'test'

The result is going to be like this:

[image: the dataframe with a data_type column marking each row as train, val or test]

Note: This solution uses the workaround mentioned in the question.

Bagging answered 30/11, 2020 at 22:20 Comment(0)

Split the dataset into training and test sets as in the other answers, using

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Then, when you fit your model, you can pass validation_split as a parameter, so you do not need to create the validation set in advance. For example:

from tensorflow.keras import Model

model = Model(input_layer, out)

[...]

history = model.fit(x=X_train, y=y_train, [...], validation_split = 0.3)

The validation set is meant to serve as a representative on-the-run test set during training, and it is taken entirely from the training set, whether by k-fold cross-validation (recommended) or by validation_split. That way you do not need to create a validation set separately, and you still end up with the three sets you are asking for.
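If you go the recommended k-fold route instead, here is a minimal sketch with scikit-learn's KFold (assuming X_train and y_train are NumPy arrays; in practice you would rebuild and recompile the model for each fold):

from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, val_idx in kf.split(X_train):
    # each fold trains on one portion of the training data and validates on the rest;
    # the held-out X_test / y_test from above is never touched here
    history = model.fit(x=X_train[train_idx], y=y_train[train_idx],
                        validation_data=(X_train[val_idx], y_train[val_idx]))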

Jedjedd answered 23/5, 2021 at 18:28 Comment(0)

ANSWER FOR ANY NUMBER OF SUBSETS:

from sklearn.model_selection import train_test_split

def _separate_dataset(patches, label_patches, percentage, shuffle: bool = True):
    """
    :param patches: data patches
    :param label_patches: label patches
    :param percentage: list of percentages for each value, example [0.9, 0.02, 0.08] to get 90% train, 2% val and 8% test.
    :param shuffle: Shuffle dataset before split.
    :return: tuple of two lists of size = len(percentage), one with data x and other with labels y.
    """
    x_test = patches
    y_test = label_patches
    percentage = list(percentage)       # need it to be mutable
    assert sum(percentage) == 1., f"percentage must add to 1, but sum({percentage}) = {sum(percentage)}"
    x = []
    y = []
    for i, per in enumerate(percentage[:-1]):
        x_train, x_test, y_train, y_test = train_test_split(x_test, y_test, test_size=1-per, shuffle=shuffle)
        percentage[i+1:] = [value / (1-percentage[i]) for value in percentage[i+1:]]
        x.append(x_train)
        y.append(y_train)
    x.append(x_test)
    y.append(y_test)
    return x, y

This works for any number of subsets. In your case, you should pass percentage = [train_percentage, val_percentage, test_percentage].
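A hypothetical call, following the docstring (the array shapes and the 0.5/0.25/0.25 fractions are made up):

import numpy as np

patches = np.random.rand(100, 8, 8)                 # made-up data patches
label_patches = np.random.randint(0, 2, size=100)   # made-up labels

x, y = _separate_dataset(patches, label_patches, percentage=[0.5, 0.25, 0.25])
x_train, x_val, x_test = x
y_train, y_val, y_test = y
print(len(x_train), len(x_val), len(x_test))        # 50 25 25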

Gerome answered 29/10, 2021 at 15:32 Comment(0)

The easiest way I could think of is mapping split fractions to array indices as follows:

train_set = data[:int((len(data)+1)*train_fraction)]
test_set = data[int((len(data)+1)*train_fraction):int((len(data)+1)*(train_fraction+test_fraction))]
val_set = data[int((len(data)+1)*(train_fraction+test_fraction)):]

where the data has been shuffled beforehand with random.shuffle(data) (note that random.shuffle shuffles the list in place and returns None, so don't assign its return value back to data).

Turtleback answered 22/3, 2022 at 15:33 Comment(0)

I always use this method to do a train, test, validation split. It splits your data into the desired sizes.

from sklearn.model_selection import train_test_split

def train_test_val_split(df, train_size, val_size, test_size, random_state=42):
    """
    Splits a pandas dataframe into training, validation, and test sets.

    Args:
    - df: pandas dataframe to split.
    - train_size: float between 0 and 1 indicating the proportion of the dataframe to include in the training set.
    - val_size: float between 0 and 1 indicating the proportion of the dataframe to include in the validation set.
    - test_size: float between 0 and 1 indicating the proportion of the dataframe to include in the test set.
    - random_state: int or None, optional (default=42). The seed used by the random number generator.

    Returns:
    - train_df: pandas dataframe containing the training set.
    - val_df: pandas dataframe containing the validation set.
    - test_df: pandas dataframe containing the test set.

    Raises:
    - AssertionError: if the sum of train_size, val_size, and test_size is not equal to 1.
    """

    assert train_size + val_size + test_size == 1, "Train, validation, and test sizes must add up to 1."
    
    # Split the dataframe into training and test sets
    train_df, test_df = train_test_split(df, test_size=test_size, random_state=random_state)
    
    # Calculate the size of the validation set relative to the remaining (train + validation) data
    val_ratio = val_size / (1 - test_size)
    
    # Split the training set into training and validation sets
    train_df, val_df = train_test_split(train_df, test_size=val_ratio, random_state=random_state)
    
    return train_df, val_df, test_df
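
For example, a minimal usage sketch (df is any pandas dataframe; the 0.5/0.25/0.25 fractions are placeholders):

train_df, val_df, test_df = train_test_val_split(df, train_size=0.5, val_size=0.25, test_size=0.25)
print(len(train_df), len(val_df), len(test_df))   # 500 250 250 for a 1000-row df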

EDIT: You can also split X and y into train, test, and validation sets directly:

def train_test_val_split(X, y, train_size, val_size, test_size, random_state=42):
    """
    Splits X and y into training, validation, and test sets.

    Args:
    - X: pandas dataframe or array containing the independent variables.
    - y: pandas series or array containing the dependent variable.
    - train_size: float between 0 and 1 indicating the proportion of the data to include in the training set.
    - val_size: float between 0 and 1 indicating the proportion of the data to include in the validation set.
    - test_size: float between 0 and 1 indicating the proportion of the data to include in the test set.
    - random_state: int or None, optional (default=42). The seed used by the random number generator.

    Returns:
    - X_train: pandas dataframe or array containing the independent variables for the training set.
    - X_val: pandas dataframe or array containing the independent variables for the validation set.
    - X_test: pandas dataframe or array containing the independent variables for the test set.
    - y_train: pandas series or array containing the dependent variable for the training set.
    - y_val: pandas series or array containing the dependent variable for the validation set.
    - y_test: pandas series or array containing the dependent variable for the test set.

    Raises:
    - AssertionError: if the sum of train_size, val_size, and test_size is not equal to 1.
    """

    assert train_size + val_size + test_size == 1, "Train, validation, and test sizes must add up to 1."
    
    # Concatenate X and y into a single dataframe
    df = pd.concat([X, y], axis=1)
    
    # Split the dataframe into training and test sets
    train_df, test_df = train_test_split(df, test_size=test_size, random_state=random_state)
    
    # Calculate the size of the validation set relative to the remaining (train + validation) data
    val_ratio = val_size / (1 - test_size)
    
    # Split the training set into training and validation sets
    train_df, val_df = train_test_split(train_df, test_size=val_ratio, random_state=random_state)
    
    # Split the training, validation, and test dataframes into X and y values
    X_train, y_train = train_df.drop(columns=y.name), train_df[y.name]
    X_val, y_val = val_df.drop(columns=y.name), val_df[y.name]
    X_test, y_test = test_df.drop(columns=y.name), test_df[y.name]
    
    return X_train, X_val, X_test, y_train, y_val, y_test

Nkrumah answered 18/4, 2023 at 9:31 Comment(0)

Here's a version that recursively splits the input arrays using sklearn's train_test_split function. It handles an arbitrary number of input arrays (just like the sklearn version) as well as an arbitrary number of splits. I've tested it on 3 input arrays and four splits.

from typing import Any, Iterable, List, Union

import numpy as np
from sklearn.model_selection import train_test_split as _train_test_split


def train_test_split(
    *arrays: Any,
    sizes: Union[float, Iterable[float]] = None,
    random_state: Any = None,
    shuffle: bool = True,
    stratify: Any = None
) -> List:
    """Like ``sklearn.model_selection.train_test_split`` but handles multiple splits.

    Returns:
        A list of array splits. The first ``n`` elements are the splits of the first array and so
        on. For example, ``X_train``, ``X_valid``, ``X_test``.

    Examples:
        >>> train_test_split(list(range(10)), sizes=(0.7, 0.2, 0.1))
        [[0, 1, 9, 6, 5, 8, 2], [3, 7], [4]]
        >>> train_test_split(features, labels, sizes=(0.7, 0.2, 0.1))
    """
    if isinstance(sizes, float):
        sizes = [sizes]
    else:
        sizes = np.array(sizes, dtype="float")
        if len(sizes) > 1:
            sizes /= sizes.sum()

    train_size = sizes[0]
    common_args = dict(random_state=random_state, shuffle=shuffle, stratify=stratify)

    if len(sizes) <= 2:
        return _train_test_split(*arrays, train_size=train_size, **common_args)
    else:
        n = len(arrays)
        left_arrays = _train_test_split(*arrays, train_size=train_size, **common_args)
        right_arrays = train_test_split(*left_arrays[1::2], sizes=sizes[1:], **common_args)

        interleaved = []
        s = len(sizes) - 1
        for i in range(n):
            interleaved.append(left_arrays[i * 2])
            interleaved.extend(right_arrays[i * s : (i + 1) * s])

        return interleaved
Undervest answered 16/5, 2023 at 20:4 Comment(0)
