How to split data based on a column value in sklearn

I have a data file with the following columns:

'customer', 'calibrat' (calibration sample = 1; validation sample = 0), 'churn', 'churndep', 'revenue', 'mou',

The data file contains about 40,000 rows, of which 20,000 have calibrat = 1. I want to split this data into training and test sets. This is what I have so far:

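from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split

# Features are every column except the target 'churn'
X1 = data.loc[:, data.columns != 'churn']
y1 = data.loc[:, data.columns == 'churn']
os = SMOTE(random_state=0)
# Random 70/30 split; this ignores the 'calibrat' column
X1_train, X1_test, y1_train, y1_test = train_test_split(X1, y1, test_size=0.3, random_state=0)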

What I want is for X1_train to contain the calibration data (calibrat = 1) and X1_test to contain all the validation data (calibrat = 0).

Clue answered 9/4, 2020 at 6:56 Comment(2)
Have you tried X1_train, X1_test, y1_train, y1_test = train_test_split(X1.loc[X1['calibrat']==1], y1.loc[X1['calibrat']!=1], test_size=0.3, random_state=0)? - Gautious
No, this does not work. - Clue

sklearn.model_selection offers several options besides train_test_split, and one of them is aimed at what you're asking for. In this case you could use GroupShuffleSplit, which, as mentioned in the docs, provides randomized train/test indices to split data according to a third-party provided group. This is useful when you're doing cross-validation and want to split into train and validation sets multiple times while ensuring that the sets are split by the group field. There is also GroupKFold for these cases, which is very useful.
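To make the group-based behaviour concrete, here is a minimal sketch using GroupKFold on made-up data. The arrays and the four group labels are purely illustrative assumptions, not the asker's dataset; the point is only that rows sharing a group label never end up on both sides of the same split.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical toy data: 8 rows belonging to 4 groups (two rows per group).
X = np.arange(16).reshape(8, 2)
y = np.array([0, 1, 0, 1, 0, 1, 0, 1])
groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])

# One fold per group: each group is held out exactly once.
gkf = GroupKFold(n_splits=4)
folds = list(gkf.split(X, y, groups=groups))

for train_ix, test_ix in folds:
    train_groups = np.unique(groups[train_ix]).tolist()
    test_groups = np.unique(groups[test_ix]).tolist()
    # A group never appears in both the train and the test indices.
    assert set(train_groups).isdisjoint(test_groups)
    print("train groups:", train_groups, "test groups:", test_groups)
```

With only two distinct group values, as with calibrat, a single split can do no more than put one value on each side.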

So, adapting your example, here's what you could do.

Say you have for instance:

import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

cols = ['customer', 'calibrat', 'churn', 'churndep', 'revenue', 'mou']
X = pd.DataFrame(np.random.rand(10, 6), columns=cols)
X['calibrat'] = np.random.choice([0, 1], size=10)

print(X)

   customer  calibrat     churn  churndep   revenue       mou
0  0.523571         1  0.394896  0.933637  0.232630  0.103486
1  0.456720         1  0.850961  0.183556  0.885724  0.993898
2  0.411568         1  0.003360  0.774391  0.822560  0.840763
3  0.148390         0  0.115748  0.089891  0.842580  0.565432
4  0.505548         0  0.370198  0.566005  0.498009  0.601986
5  0.527433         0  0.550194  0.991227  0.516154  0.283175
6  0.983699         0  0.514049  0.958328  0.005034  0.050860
7  0.923172         0  0.531747  0.026763  0.450077  0.961465
8  0.344771         1  0.332537  0.046829  0.047598  0.324098
9  0.195655         0  0.903370  0.399686  0.170009  0.578925

y = X.pop('churn')

You can now instantiate GroupShuffleSplit and use it much as you would train_test_split, the only difference being that you specify a group column, which is used to split X and y so that rows sharing a group value stay on the same side of the split:

gs = GroupShuffleSplit(n_splits=2, train_size=.7, random_state=42)

As mentioned, this is handier when you want several splits, generally for cross-validation purposes. Here's just an example of how you'd get the two sets mentioned in the question:

train_ix, test_ix = next(gs.split(X, y, groups=X.calibrat))

X_train = X.loc[train_ix]
y_train = y.loc[train_ix]

X_test = X.loc[test_ix]
y_test = y.loc[test_ix]

Giving:

print(X_train)

   customer  calibrat  churndep   revenue       mou
3  0.148390         0  0.089891  0.842580  0.565432
4  0.505548         0  0.566005  0.498009  0.601986
5  0.527433         0  0.991227  0.516154  0.283175
6  0.983699         0  0.958328  0.005034  0.050860
7  0.923172         0  0.026763  0.450077  0.961465
9  0.195655         0  0.399686  0.170009  0.578925

print(X_test)

   customer  calibrat  churndep   revenue       mou
0  0.523571         1  0.933637  0.232630  0.103486
1  0.456720         1  0.183556  0.885724  0.993898
2  0.411568         1  0.774391  0.822560  0.840763
8  0.344771         1  0.046829  0.047598  0.324098
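
Note that GroupShuffleSplit picks at random which group goes to which side, so the calibrat = 1 rows can land in the test set, as they do in the output above. If the goal is simply calibrat = 1 for training and calibrat = 0 for testing, a plain boolean mask may be enough. The following is a minimal sketch under assumptions: the DataFrame stands in for the question's data, only the column names come from the question, the values are made up, and dropping calibrat from the features is a choice made here for illustration.

```python
import numpy as np
import pandas as pd

# Made-up stand-in for the question's `data`; only the column names are from the question.
rng = np.random.default_rng(0)
data = pd.DataFrame({
    "customer": np.arange(10),
    "calibrat": [1, 1, 1, 0, 0, 0, 0, 0, 1, 0],
    "churn": rng.integers(0, 2, size=10),
    "revenue": rng.random(10),
    "mou": rng.random(10),
})

# True for calibration rows, False for validation rows.
is_calibration = data["calibrat"] == 1

# Assumption: the split indicator is not used as a feature.
X = data.drop(columns=["churn", "calibrat"])
y = data["churn"]

# Deterministic split: calibration rows train, validation rows test.
X_train, y_train = X[is_calibration], y[is_calibration]
X_test, y_test = X[~is_calibration], y[~is_calibration]

print(len(X_train), "training rows,", len(X_test), "test rows")
```

Because the mask is deterministic, no random_state is involved and the same rows end up in each set on every run.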
Discontinue answered 9/4, 2020 at 7:33 Comment(3)
Thanks for answering. I do not want to split the dataset into multiple sets. I have a dataset in the variable data and want the calibration data to be the rows with calibrat = 1 and the validation data to be the rows with calibrat = 0. - Clue
This is great! Easier to implement than the other examples I've found. - Melton
Love that little trick of doing y = X.pop('churn'). - Annalee
