scikit-learn has several options other than train_test_split. One of them aims at solving exactly what you're asking for: GroupShuffleSplit, which, as mentioned in the docs, provides randomized train/test indices to split data according to a third-party provided group. This is useful when you're doing cross-validation and want to split into train and validation sets multiple times, ensuring that the sets are split by the group field. GroupKFold is also available for these cases and is very useful.
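As a rough illustration of the GroupKFold option, here is a minimal sketch on made-up data. The arrays and the group labels are assumptions for the example only, not taken from the question; the point is just that each group lands entirely in either the train or the test side of every fold.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Toy data (hypothetical): 8 samples, 4 groups with two samples each.
X_demo = np.arange(16).reshape(8, 2)
y_demo = np.array([0, 1, 0, 1, 0, 1, 0, 1])
groups_demo = np.array([0, 0, 1, 1, 2, 2, 3, 3])

gkf = GroupKFold(n_splits=4)
folds = list(gkf.split(X_demo, y_demo, groups=groups_demo))

for fold, (train_ix, test_ix) in enumerate(folds):
    train_groups = set(groups_demo[train_ix].tolist())
    test_groups = set(groups_demo[test_ix].tolist())
    # A group never appears on both sides of the same fold.
    assert train_groups.isdisjoint(test_groups)
    print(f"fold {fold}: test groups = {sorted(test_groups)}")
```

Unlike GroupShuffleSplit, GroupKFold is not randomized: every group ends up in the test side of exactly one fold.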
So, adapting your example, here's what you could do.
Say you have for instance:
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

cols = ['customer', 'calibrat', 'churn', 'churndep', 'revenue', 'mou']
X = pd.DataFrame(np.random.rand(10, 6), columns=cols)
# 'calibrat' (0 or 1) is the column later used as the group label
X['calibrat'] = np.random.choice([0, 1], size=10)
   customer  calibrat     churn  churndep   revenue       mou
0  0.523571         1  0.394896  0.933637  0.232630  0.103486
1  0.456720         1  0.850961  0.183556  0.885724  0.993898
2  0.411568         1  0.003360  0.774391  0.822560  0.840763
3  0.148390         0  0.115748  0.089891  0.842580  0.565432
4  0.505548         0  0.370198  0.566005  0.498009  0.601986
5  0.527433         0  0.550194  0.991227  0.516154  0.283175
6  0.983699         0  0.514049  0.958328  0.005034  0.050860
7  0.923172         0  0.531747  0.026763  0.450077  0.961465
8  0.344771         1  0.332537  0.046829  0.047598  0.324098
9  0.195655         0  0.903370  0.399686  0.170009  0.578925
y = X.pop('churn')
You can now instantiate GroupShuffleSplit and use it much as you would train_test_split, with the only difference being that you specify a groups column, which is used to split X and y so that rows sharing the same group value always end up in the same set:
gs = GroupShuffleSplit(n_splits=2, train_size=.7, random_state=42)
As mentioned, this is handier when you want to generate multiple splits, generally for cross-validation purposes. Here's just an example of how you'd get the two sets (train and test), as asked in the question:
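train_ix, test_ix = next(gs.split(X, y, groups=X.calibrat))
# split() yields positional indices, so .iloc is the safe accessor
# (.loc only gives the same rows when the index is the default 0..n-1)
X_train = X.iloc[train_ix]
y_train = y.iloc[train_ix]
X_test = X.iloc[test_ix]
y_test = y.iloc[test_ix]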
   customer  calibrat  churndep   revenue       mou
3  0.148390         0  0.089891  0.842580  0.565432
4  0.505548         0  0.566005  0.498009  0.601986
5  0.527433         0  0.991227  0.516154  0.283175
6  0.983699         0  0.958328  0.005034  0.050860
7  0.923172         0  0.026763  0.450077  0.961465
9  0.195655         0  0.399686  0.170009  0.578925

   customer  calibrat  churndep   revenue       mou
0  0.523571         1  0.933637  0.232630  0.103486
1  0.456720         1  0.183556  0.885724  0.993898
2  0.411568         1  0.774391  0.822560  0.840763
8  0.344771         1  0.046829  0.047598  0.324098
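For the cross-validation use mentioned earlier, you would typically loop over every split rather than take only the first one with next(). The following is a minimal sketch under assumed data: the DataFrame, its 'group' column and the churn target are hypothetical stand-ins, not the data from the question.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

rng = np.random.default_rng(0)

# Hypothetical stand-in data: 12 rows, 4 group ids with 3 rows each.
df = pd.DataFrame({
    "group": np.repeat([0, 1, 2, 3], 3),
    "revenue": rng.random(12),
    "mou": rng.random(12),
})
target = pd.Series(rng.integers(0, 2, size=12), name="churn")

gss = GroupShuffleSplit(n_splits=3, test_size=0.25, random_state=42)

splits = []
for train_ix, test_ix in gss.split(df, target, groups=df["group"]):
    # split() yields positional indices, hence .iloc
    X_tr, X_te = df.iloc[train_ix], df.iloc[test_ix]
    y_tr, y_te = target.iloc[train_ix], target.iloc[test_ix]
    splits.append((X_tr, X_te, y_tr, y_te))

    # Sanity check: no group id is shared between train and test.
    shared = set(X_tr["group"]) & set(X_te["group"])
    print(f"train rows={len(X_tr)}, test rows={len(X_te)}, shared groups={shared}")
```

Note that train_size and test_size refer to the proportion of groups, not of rows, so the row counts can differ from the requested ratio when groups have unequal sizes.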
X1_train, X1_test, y1_train, y1_test = train_test_split(X1.loc[X1['calibrat']==1], y1.loc[X1['calibrat']!=1], test_size=0.3, random_state=0)? – Gautious