Parameter "stratify" from method "train_test_split" (scikit-learn)

I am trying to use train_test_split from the scikit-learn package, but I am having trouble with the stratify parameter. Here is the code:

from sklearn import cross_validation, datasets

iris = datasets.load_iris()
X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

However, I keep getting the following error:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

Does anyone have an idea what is going on? Below is the relevant excerpt from the function documentation.

[...]

stratify : array-like or None (default is None)

If not None, data is split in a stratified fashion, using this as the labels array.

New in version 0.17: stratify splitting

[...]

Nejd answered 17/1, 2016 at 19:5 Comment(1)
Nope, all solved. - Nejd

scikit-learn is just telling you it doesn't recognise the argument "stratify", not that you're using it incorrectly. This is because the parameter was added in version 0.17, as indicated in the documentation you quoted.

So you just need to update scikit-learn.
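A quick way to confirm this diagnosis is to check the installed version before calling train_test_split. The sketch below uses hypothetical helpers (parse_major_minor and supports_stratify are illustrative names, not part of scikit-learn) that only compare the version string against 0.17; if it reports False, upgrading with pip install --upgrade scikit-learn (or the conda equivalent) should fix the error.

```python
import re

import sklearn


def parse_major_minor(version):
    """Return (major, minor) from a version string such as '0.17.1' or '1.4.2'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    return int(match.group(1)), int(match.group(2))


def supports_stratify(version):
    # Hypothetical check: 'stratify' was added to train_test_split in 0.17,
    # so anything older rejects it with "Invalid parameters passed".
    return parse_major_minor(version) >= (0, 17)


print("scikit-learn", sklearn.__version__)
print("train_test_split accepts stratify:", supports_stratify(sklearn.__version__))
```

Parsing only the leading digits keeps the check working for pre-release strings such as "1.6rc1".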

Berserk answered 10/12, 2016 at 10:33 Comment(1)
I'm getting the same error, although I have version 0.21.2 of scikit-learn: scikit-learn 0.21.2 py37h2a6a0b8_0 conda-forge - Birchfield

The stratify parameter makes the split so that the proportion of each value in the produced samples is the same as the proportion of values in the array passed to stratify.

For example, in a binary classification problem, suppose y, the target (label) column of your DataFrame, has the following values:

  • 0: 25% of the data is zeros
  • 1: 75% of the data is ones

Then stratify=y will make sure that both parts of your random split have:

  • 25% of 0's
  • 75% of 1's
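A minimal sketch of the 25%/75% example above, assuming a hypothetical DataFrame with made-up columns "feature" and "y" (modern scikit-learn, where train_test_split lives in sklearn.model_selection). Because 100 rows split 80/20 divide evenly, the class counts come out exact: 20 zeros and 60 ones in training, 5 zeros and 15 ones in test.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 100 rows, label column "y" is 25% zeros and 75% ones
df = pd.DataFrame({
    "feature": np.arange(100),
    "y": [0] * 25 + [1] * 75,
})

# Passing the label column as stratify keeps the 25/75 ratio in both parts
train_df, test_df = train_test_split(
    df, test_size=0.2, stratify=df["y"], random_state=0)

# Class proportions in the full data, the training part and the test part
print(df["y"].value_counts(normalize=True).sort_index())
print(train_df["y"].value_counts(normalize=True).sort_index())
print(test_df["y"].value_counts(normalize=True).sort_index())
```

When the counts do not divide evenly, the proportions match only up to rounding.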
Hauser answered 11/8, 2016 at 7:0 Comment(8)
This doesn't really answer the question, but it is super useful for understanding how it works. Thanks a ton. - Minatory
I still struggle to understand why this stratification is necessary: if there's class imbalance in the data, wouldn't it be preserved on average when doing a random split of the data? - Moleskin
@HolgerBrandl It will be preserved on average; with stratify, it will be preserved for sure. - Enidenigma
@HolgerBrandl With very small or very imbalanced data sets, it's quite possible that a random split could completely eliminate a class from one of the splits. - Emilie
@HolgerBrandl Nice question! We could add that, first, you have to split into training and test sets using stratify. Second, to correct the imbalance you may need to run oversampling or undersampling on the training set. Many scikit-learn classifiers have a parameter called class_weight, which you can set to 'balanced'. Finally, you could also use a more appropriate metric than accuracy for imbalanced datasets: try F1 or the area under the ROC curve. - Keciakeck
Doesn't it violate the temporal order of time-series data? - Alialia
So what counts as a small or large dataset? I have a fairly well-balanced dataset with shape (130000, 23). Should I be using stratify? If stratify preserves the class proportions of a split, why isn't it on all the time? If it's enabled on a balanced or large set, we are simply preserving those proportions. So what's the drawback to using stratify on a large or balanced dataset? - Anglo
So if I use stratify, it isn't strictly necessary to use shuffle=True, right? - Hodometer
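Two points from the comments can be checked directly: a plain random split can drop a rare class from the test set entirely, while a stratified split cannot; and stratification depends on shuffling, so train_test_split raises a ValueError if you pass shuffle=False together with stratify. The sketch below assumes made-up, heavily imbalanced labels (95 of class 0, 5 of class 1); minority_missing_from_test is a hypothetical helper, and the unstratified count varies with the seeds, so no exact number is claimed for it.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced data: 95 samples of class 0, 5 samples of class 1
X = np.arange(100).reshape(-1, 1)
y = np.array([0] * 95 + [1] * 5)


def minority_missing_from_test(stratify, seeds=range(200)):
    """Count how many of the seeded splits leave class 1 out of the test set."""
    misses = 0
    for seed in seeds:
        _, _, _, y_test = train_test_split(
            X, y, test_size=0.2, random_state=seed,
            stratify=y if stratify else None)
        if not np.any(y_test == 1):
            misses += 1
    return misses


print("plain random split:", minority_missing_from_test(False), "of 200 test sets lack class 1")
print("stratified split:  ", minority_missing_from_test(True), "of 200 test sets lack class 1")

# Stratified splitting is only implemented with shuffling enabled
try:
    train_test_split(X, y, shuffle=False, stratify=y)
except ValueError as err:
    print("ValueError:", err)
```

With stratify=y, every test set here receives exactly one of the five minority samples, so the stratified count is always 0.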

For my future self who comes here via Google:

train_test_split is now in model_selection, hence:

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                        test_size=0.33,
                                                        random_state=0,
                                                        stratify=ys)

is the way to use it. Setting the random_state is desirable for reproducibility.
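For a concrete, runnable version of the snippet above, here is a minimal sketch on the iris data from the question (150 samples, 50 per class), using the same test_size, random_state and stratify arguments. Since 50 per class cannot be split exactly 33%/67%, the 50 test samples come out as 16, 17 and 17 per class (which class gets 16 depends on the random state), so the proportions match the original only up to rounding.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Features and labels of the iris dataset: 150 samples, 50 in each of 3 classes
xs, ys = load_iris(return_X_y=True)

x_train, x_test, y_train, y_test = train_test_split(
    xs, ys, test_size=0.33, random_state=0, stratify=ys)

# Per-class counts in each part; every class is split close to 67/33
print("train:", np.bincount(y_train))
print("test: ", np.bincount(y_test))
```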

Chiapas answered 12/10, 2017 at 18:36 Comment(0)

In this context, stratification means that the train_test_split method returns training and test subsets that have the same proportions of class labels as the input dataset (up to rounding when the class counts do not divide evenly).

Sniggle answered 1/12, 2017 at 0:14 Comment(0)

Stratifying preserves the distribution of values in the target column and reproduces that same distribution in both outputs of train_test_split. For example, suppose the problem is binary classification and the target column has the following proportions:

  • 80% = yes
  • 20% = no

Since there are 4 times as many 'yes' as 'no' values in the target column, splitting into train and test sets without stratifying could, in the worst case, put only 'yes' values into the training set and all the 'no' values into the test set (i.e., the training set might not contain any 'no' in its target column).

By stratifying, the target column of:

  • the training set has 80% 'yes' and 20% 'no', and
  • the test set also has 80% 'yes' and 20% 'no'.

Hence, stratify distributes the target (label) values across the training and test sets in the same proportions as in the original dataset.

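import numpy as np
from sklearn.model_selection import train_test_split

# hypothetical toy data: 80 'yes' and 20 'no' labels, one dummy feature
target = np.array(["yes"] * 80 + ["no"] * 20)
features = np.arange(100).reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(features,
                                                    target,
                                                    test_size=0.25,
                                                    stratify=target,
                                                    random_state=43)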
Metalinguistic answered 2/5, 2022 at 21:37 Comment(1)
It would be better to point out the class-imbalance problem in your binary classification example, for better understanding. - Commensurate

Try running this code, it "just works":

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])
Reinwald answered 17/1, 2016 at 20:45 Comment(2)
@user5767535 As you can see, it works on my Ubuntu machine with sklearn 0.17 (Anaconda distribution for Python 3.5). I can only suggest double-checking that you entered the code correctly and updating your software. - Reinwald
@user5767535 BTW, "New in version 0.17: stratify splitting" makes me almost certain that you need to update your sklearn. - Reinwald
