Why should we call the split() function when passing StratifiedKFold() as a parameter to GridSearchCV?
What I am trying to do?

I am trying to use StratifiedKFold() in GridSearchCV().

What confuses me?

When we use K Fold Cross Validation, we just pass the number of folds to the cv parameter of GridSearchCV(), like the following.

grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=5, scoring='f1', return_train_score=True, n_jobs=2)

Then, when I need to use StratifiedKFold(), I think the procedure should remain the same. That is, pass only the number of splits, StratifiedKFold(n_splits=5), to cv.

grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=StratifiedKFold(n_splits=5), scoring='f1', return_train_score=True, n_jobs=2)

But this answer says

whatever the cross validation strategy used, all that is needed is to provide the generator using the function split, as suggested:

kfolds = StratifiedKFold(5)
clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain,ytrain))
clf.fit(xtrain, ytrain)

Moreover, one of the answers to this question also suggests doing this. That is, they suggest calling the split function, StratifiedKFold(n_splits=5).split(xtrain, ytrain), when using GridSearchCV(). But I have found that calling split() and not calling split() give me the same f1 score.

Hence, my questions

  • I do not understand why we need to call the split() function with Stratified K Fold when we do not need to do anything like that with K Fold CV.

  • If the split() function is called, how will GridSearchCV() work, given that split() returns training and testing data set indices? That is, I want to know how GridSearchCV() will use those indices.

Fish answered 2/6, 2020 at 7:35 Comment(0)

Basically, GridSearchCV is clever and can take multiple options for that cv parameter: a number, an iterable of split indices, or an object with a split function. You can look at the code here, copied below.

cv = 5 if cv is None else cv
if isinstance(cv, numbers.Integral):
    if (classifier and (y is not None) and
            (type_of_target(y) in ('binary', 'multiclass'))):
        return StratifiedKFold(cv)
    else:
        return KFold(cv)

if not hasattr(cv, 'split') or isinstance(cv, str):
    if not isinstance(cv, Iterable) or isinstance(cv, str):
        raise ValueError("Expected cv as an integer, cross-validation "
                         "object (from sklearn.model_selection) "
                         "or an iterable. Got %s." % cv)
    return _CVIterableWrapper(cv)

return cv  # New style cv objects are passed without any modification

Basically, if you don't pass anything, it uses a KFold with 5 splits. It's also clever enough to automatically use StratifiedKFold if it's a classification problem and the target is binary/multiclass.
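You can see that resolution logic directly by calling scikit-learn's check_cv helper (the function the code above comes from) with the same inputs GridSearchCV would pass it; this is a small sketch using made-up toy data:

```python
import numpy as np
from sklearn.model_selection import check_cv

# A binary classification target (toy data)
y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

# What GridSearchCV builds internally from cv=5 when the estimator
# is a classifier and the target is binary/multiclass:
cv_clf = check_cv(cv=5, y=y, classifier=True)
print(type(cv_clf).__name__)  # StratifiedKFold

# Same call for a non-classifier (or a continuous target):
cv_reg = check_cv(cv=5, y=None, classifier=False)
print(type(cv_reg).__name__)  # KFold
```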

If you pass an object with a split function, it just uses that. And if you pass neither of those but do pass an iterable, it assumes that it is an iterable of the split indices and wraps it up for you.

So in your case, assuming it's a classification problem with a binary/multiclass target, all of the below will give exactly the same results/splits; it does not matter which one you use!

cv=5
cv=StratifiedKFold(5)
cv=StratifiedKFold(5).split(xtrain,ytrain)
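You can verify this equivalence by listing out the indices each option produces; this is a sketch on toy data (with the default shuffle=False, all three routes are deterministic and identical):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, check_cv

X = np.zeros((10, 2))
y = np.array([0, 1] * 5)

# Route 1: what GridSearchCV builds internally from cv=5 for a classifier
from_int = list(check_cv(cv=5, y=y, classifier=True).split(X, y))
# Route 2: an explicit StratifiedKFold object
from_obj = list(StratifiedKFold(n_splits=5).split(X, y))
# Route 3: a pre-made generator of splits
from_gen = list(StratifiedKFold(n_splits=5).split(X, y))

for (tr1, te1), (tr2, te2), (tr3, te3) in zip(from_int, from_obj, from_gen):
    assert np.array_equal(tr1, tr2) and np.array_equal(tr2, tr3)
    assert np.array_equal(te1, te2) and np.array_equal(te2, te3)
print("all three cv options produce identical splits")
```

One caveat worth knowing: the third form is a generator, which is exhausted after a single pass, so calling fit() a second time on the same GridSearchCV would find no splits left; the first two forms do not have that problem.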
Efface answered 4/6, 2020 at 12:20 Comment(2)
Thanks for your response. You remarked that "If you pass an object with a split function, it just uses that." But I do not understand how GridSearchCV() will use those indices found by splitting. Can you please describe it? – Fish
So for each parameter set in the grid search, it will use the splits to run the cross validation. So if you have 3 options for one parameter and 2 for another in the param grid (6 sets), and 5 fold cross validation, then you really train and validate 30 models. The parameter set with the highest mean validation score across the cross validation runs is then picked as the "best". – Efface
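The 6-sets-times-5-folds arithmetic in the comment above can be checked on a fitted GridSearchCV; a sketch with a made-up grid and toy data:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

rng = np.random.RandomState(0)
X = rng.rand(30, 4)
y = np.array([0, 1] * 15)

# 3 options x 2 options = 6 candidate parameter sets
param_grid = {"n_estimators": [5, 10, 20], "max_depth": [2, 3]}
gs = GridSearchCV(RandomForestClassifier(random_state=0), param_grid, cv=5)
gs.fit(X, y)

print(len(gs.cv_results_["params"]))  # 6 candidate parameter sets
print(gs.n_splits_)                   # 5 folds each -> 30 fitted models
```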
