How to gridsearch over transform arguments within a pipeline in scikit-learn
My goal is to use one model to select the most important variables and another model to use those variables to make predictions. In the example below I am using two instances of RandomForestClassifier, but the second model could be any other classifier.

The RF has a transform method with a threshold argument. I would like to grid search over different possible threshold arguments.

Here is a simplified code snippet:

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV  # sklearn.grid_search in 2014-era releases
from sklearn.pipeline import Pipeline

# Transform object and classifier
rf_filter = RandomForestClassifier(n_estimators=200, n_jobs=-1, random_state=42, oob_score=False)
clf = RandomForestClassifier(n_jobs=-1, random_state=42, oob_score=False)

pipe = Pipeline([("RFF", rf_filter), ("RF", clf)])

# Grid search parameters
rf_n_estimators = [10, 20]
rff_transform = ["median", "mean"]  # Search the threshold parameters

estimator = GridSearchCV(pipe,
                         cv=3,
                         param_grid=dict(RF__n_estimators=rf_n_estimators,
                                         RFF__threshold=rff_transform))

# X_train, y_train: the training data (not shown)
estimator.fit(X_train, y_train)

The error is ValueError: Invalid parameter threshold for estimator RandomForestClassifier

I thought this would work because the docs say:

If None and if available, the object attribute threshold is used.

I tried setting the threshold attribute before the grid search (rf_filter.threshold = "median") and it worked; however, I couldn't figure out how to then grid search over it.

Is there a way to iterate over different arguments that would normally be expected to be provided within the transform method of a classifier?

Botvinnik answered 19/4, 2014 at 20:13 Comment(0)
Following the same method as you are describing, namely doing feature selection and classification with two distinct Random Forest classifiers grouped into a Pipeline, I ran into the same issue.

An instance of the RandomForestClassifier class does not have an attribute called threshold. You can indeed manually add one, either using the way you described or with

setattr(rf_filter, 'threshold', 'mean')

but the main problem seems to be how get_params determines the valid parameters of any BaseEstimator subclass, namely by reading the names from the __init__ signature:

class BaseEstimator(object):
    """Base class for all estimators in scikit-learn

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their __init__ as explicit keyword
    arguments (no *args, **kwargs).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator"""
        try:
            # fetch the constructor or the original constructor before
            # deprecation wrapping if any
            init = getattr(cls.__init__, 'deprecated_original', cls.__init__)

            # introspect the constructor arguments to find the model parameters
            # to represent
            args, varargs, kw, default = inspect.getargspec(init)
            if not varargs is None:
                raise RuntimeError("scikit-learn estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s doesn't follow this convention."
                                   % (cls, ))
            # Remove 'self'
            # XXX: This is going to fail if the init is a staticmethod, but
            # who would do this?
            args.pop(0)
        except TypeError:
            # No explicit __init__
            args = []
        args.sort()
        return args

Indeed, as clearly specified, all estimators should specify all the parameters that can be set at the class level in their __init__ as explicit keyword arguments.
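To make the consequence concrete, here is a minimal sketch (the variable names are mine, not from the question) showing that an attribute attached by hand does not show up in get_params, and that set_params therefore rejects it. It assumes a reasonably recent scikit-learn; the exact error message may differ between versions.

```python
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(n_estimators=10, random_state=42)

# Attaching an attribute by hand does not register it as a parameter,
# because parameter names are read from the __init__ signature.
rf.threshold = "median"
has_threshold_param = "threshold" in rf.get_params()

# set_params validates names against get_params, so the name is rejected.
# GridSearchCV sets candidate values through set_params, hence the error.
try:
    rf.set_params(threshold="mean")
    rejected = False
except ValueError:
    rejected = True

print(has_threshold_param, rejected)
```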

So I tried specifying threshold as an argument of the __init__ function, with a default value of 'mean' (which is its default value in the current implementation anyway)

    def __init__(self,
                 n_estimators=10,
                 criterion="gini",
                 max_depth=None,
                 min_samples_split=2,
                 min_samples_leaf=1,
                 max_features="auto",
                 bootstrap=True,
                 oob_score=False,
                 n_jobs=1,
                 random_state=None,
                 verbose=0,
                 min_density=None,
                 compute_importances=None,
                 threshold="mean"):  # ADD THIS!

and then assign the value of this argument to an attribute of the instance.

    self.threshold = threshold # ADD THIS LINE SOMEWHERE IN THE FUNCTION __INIT__

Of course, this means modifying the RandomForestClassifier class itself (in /python2.7/site-packages/sklearn/ensemble/forest.py), which might not be the best way... But it works for me! I am now able to grid search (and cross-validate) over different threshold arguments, and thus over different numbers of selected features.
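For readers who would rather not edit the installed library, a possible alternative is sketched below. It is not the approach described in this answer: it swaps in SelectFromModel from sklearn.feature_selection, which takes threshold as a constructor argument and so exposes it to the grid search. The digits dataset and the value n_estimators=50 for the filter are illustrative assumptions on my part.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

X, y = load_digits(return_X_y=True)

# SelectFromModel declares threshold in its __init__, so it is a real,
# searchable parameter (unlike an attribute set by hand on the forest).
rf_filter = SelectFromModel(
    RandomForestClassifier(n_estimators=50, random_state=42),
    threshold="mean",
)
clf = RandomForestClassifier(random_state=42)
pipe = Pipeline([("RFF", rf_filter), ("RF", clf)])

# Same step names and grid keys as in the question.
param_grid = {
    "RFF__threshold": ["median", "mean"],
    "RF__n_estimators": [10, 20],
}
search = GridSearchCV(pipe, param_grid=param_grid, cv=3)
search.fit(X, y)

best = search.best_params_
print(best)
```

The parameters of the wrapped forest are also reachable from the grid, under names of the form RFF__estimator__n_estimators.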

Addi answered 9/7, 2014 at 17:21 Comment(2)
Thank you so much! I am ultra-impressed you were able to answer what was a very difficult question. Thank you! - Botvinnik
hey, long time later but thanks again for this great answer. - Efficacy
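from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier

class my_rf_filter(BaseEstimator, TransformerMixin):
    def __init__(self, threshold):
        # Declared in __init__, so GridSearchCV can set it as RFF__threshold
        self.threshold = threshold

    def fit(self, X, y):
        model = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=42, oob_score=False)
        model.fit(X, y)
        self.model = model
        return self

    def transform(self, X):
        # Uses the forest's own transform(X, threshold), as in the question;
        # this method is only present in older scikit-learn releases
        return self.model.transform(X, self.threshold)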

Wrapping RandomForestClassifier in a new class whose __init__ takes threshold makes the grid search work.

rf_filter = my_rf_filter(threshold='mean')
clf = RandomForestClassifier(n_jobs=-1, random_state=42, oob_score=False)

pipe = Pipeline([("RFF", rf_filter), ("RF", clf)])

# Grid search parameters
rf_n_estimators = [10, 20]
rff_transform = ["median", "mean"] # Search the threshold parameters

estimator = GridSearchCV(pipe,
                         cv = 3, 
                         param_grid = dict(RF__n_estimators = rf_n_estimators,
                                           RFF__threshold = rff_transform))

A testing example:

from sklearn import datasets
digits = datasets.load_digits()
X_digits = digits.data
y_digits = digits.target

estimator.fit(X_digits, y_digits)


Out[143]:
GridSearchCV(cv=3,
       estimator=Pipeline(steps=[('RFF', my_rf_filter(threshold='mean')), ('RF', RandomForestClassifier(bootstrap=True, compute_importances=None,
            criterion='gini', max_depth=None, max_features='auto',
            max_leaf_nodes=None, min_density=None, min_samples_leaf=1,
            min_samples_split=2, n_estimators=10, n_jobs=-1,
            oob_score=False, random_state=42, verbose=0))]),
       fit_params={}, iid=True, loss_func=None, n_jobs=1,
       param_grid={'RF__n_estimators': [10, 20], 'RFF__threshold': ['median', 'mean']},
       pre_dispatch='2*n_jobs', refit=True, score_func=None, scoring=None,
       verbose=0)


estimator.grid_scores_

Out[144]:
[mean: 0.89705, std: 0.00912, params: {'RF__n_estimators': 10, 'RFF__threshold': 'median'},
 mean: 0.91597, std: 0.00871, params: {'RF__n_estimators': 20, 'RFF__threshold': 'median'},
 mean: 0.89705, std: 0.00912, params: {'RF__n_estimators': 10, 'RFF__threshold': 'mean'},
 mean: 0.91597, std: 0.00871, params: {'RF__n_estimators': 20, 'RFF__threshold': 'mean'}]

If you need to modify the parameters of the RandomForestClassifier inside the my_rf_filter class, I think you need to add them explicitly rather than using **kwargs in __init__() and model.set_params(**kwargs); I could not get that approach to work. I think adding n_estimators=200 to __init__() and then setting model.n_estimators = self.n_estimators will work.
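The explicit-parameter idea from the last paragraph could look roughly like the sketch below. The class name RFFilter and the mask_ attribute are hypothetical, and two things differ from the wrapper above: the selection is computed from feature_importances_ instead of calling the forest's transform, so that it does not depend on that older method, and only the "mean" and "median" thresholds are handled.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier


class RFFilter(TransformerMixin, BaseEstimator):
    """Hypothetical wrapper: keep features whose importance reaches a cutoff."""

    def __init__(self, threshold="mean", n_estimators=200):
        # Store the arguments unchanged under the same names, so that
        # get_params, set_params and clone can all see them.
        self.threshold = threshold
        self.n_estimators = n_estimators

    def fit(self, X, y):
        self.model_ = RandomForestClassifier(
            n_estimators=self.n_estimators, random_state=42
        ).fit(X, y)
        importances = self.model_.feature_importances_
        # Assumption of this sketch: only "mean" and "median" are supported.
        if self.threshold == "mean":
            cutoff = np.mean(importances)
        else:
            cutoff = np.median(importances)
        self.mask_ = importances >= cutoff
        return self

    def transform(self, X):
        return np.asarray(X)[:, self.mask_]


X, y = load_digits(return_X_y=True)
rf_filter = RFFilter(threshold="median", n_estimators=20)
X_reduced = rf_filter.fit_transform(X, y)
params = rf_filter.get_params()
print(params, X_reduced.shape)
```

Used as the "RFF" step of the pipeline, both RFF__threshold and RFF__n_estimators could then appear in param_grid.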

Decompress answered 20/7, 2015 at 19:38 Comment(0)
You can avoid most of the additional coding with the below hack.

First capture the variable reference for the estimator ("estimator" in this case). You can look up the actual hyperparameter name it refers to during debugging.

For the above question

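pipe = Pipeline([("RFF", rf_filter), ("RF", clf)])
...

# Keys follow the pattern <step name>__<parameter>. The "clf__estimator__" prefix
# assumes a step named "clf" that wraps another estimator; for the pipeline
# above, the matching key would be "RF__n_estimators".
param_grid = {"clf__estimator__n_estimators": [10, 20],
}

estimator = GridSearchCV(pipe,
                         cv=3,
                         param_grid=param_grid)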

So simply replace the hyperparameter name, e.g. max_features, with its prefixed form, clf__estimator__max_features.
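Instead of stepping through a debugger, the valid names can also be listed directly with get_params on the pipeline. The sketch below is illustrative only: it assumes a pipeline whose first step is a SelectFromModel wrapper (my choice, not part of this answer), so that a nested "estimator" level shows up in the names.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.pipeline import Pipeline

pipe = Pipeline([
    ("RFF", SelectFromModel(RandomForestClassifier(random_state=42))),
    ("RF", RandomForestClassifier(random_state=42)),
])

# Every key returned here is a name that param_grid may use.
param_names = sorted(pipe.get_params().keys())
for name in param_names:
    print(name)
```

Names are built as step name, double underscore, parameter name, with one extra level for each wrapped estimator.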

Cordillera answered 9/2, 2017 at 8:30 Comment(0)