Using cross_val_predict against test data set
I'm confused about using cross_val_predict with a test data set.

I created a simple Random Forest model and used cross_val_predict to make predictions:

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
# sklearn.cross_validation has since been removed; these now live in model_selection.
from sklearn.model_selection import cross_val_predict, KFold

lr = RandomForestClassifier(random_state=1, class_weight="balanced", n_estimators=25, max_depth=6)
kf = KFold(n_splits=3, shuffle=True, random_state=1)
predictions = cross_val_predict(lr, train_df[features_columns], train_df["target"], cv=kf)
predictions = pd.Series(predictions)

I'm confused about the next step. How do I use what was learned above to make predictions on the test data set?

Glyptography answered 10/1, 2017 at 2:28 Comment(2)
You have to fit your model first, then you can call predict on it.Pritchard
The chosen answer is irrelevant to the question, and partly wrong.Rawls

As @DmitryPolonskiy commented, the model has to be trained (with the fit method) before it can be used to predict.

# Train the model (a.k.a. `fit` training data to it).
lr.fit(train_df[features_columns], train_df["target"])
# Use the model to make predictions based on testing data.
y_pred = lr.predict(test_df[features_columns])
# Compare the predicted y values to actual y values.
accuracy = (y_pred == test_df["target"]).mean()

cross_val_predict is a cross-validation utility: it returns out-of-fold predictions for the training data, which lets you estimate how well your model generalizes. Take a look at sklearn's cross-validation page.
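
To connect the two ideas, here is a minimal sketch (reusing train_df, test_df, features_columns and the "target" column from the question, so treat those names as assumptions): cross_val_predict only estimates performance on the training data, while the test-set predictions come from a model fitted on all of the training data.

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict, KFold

# Estimate generalization with out-of-fold predictions on the training data.
lr = RandomForestClassifier(random_state=1, class_weight="balanced", n_estimators=25, max_depth=6)
kf = KFold(n_splits=3, shuffle=True, random_state=1)
cv_pred = cross_val_predict(lr, train_df[features_columns], train_df["target"], cv=kf)
cv_accuracy = (cv_pred == train_df["target"]).mean()

# Fit a final model on all of the training data, then predict on the test set.
lr.fit(train_df[features_columns], train_df["target"])
y_pred = lr.predict(test_df[features_columns])
test_accuracy = (y_pred == test_df["target"]).mean()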

Penstemon answered 10/1, 2017 at 4:17 Comment(2)
Unfortunately this answer is wrong; it doesn't mention 'cross_val_predict' anywhere and, as others have answered, I don't think that you have to train before using the function - one of its features is precisely that it does the training, iteratively on the folds.Rawls
I think you're right @Helen. You should write your own answer.Penstemon

I don't think you need to call fit before using cross_val_score or cross_val_predict; they fit the estimator on the fly, on each training fold. If you look at the documentation (section 3.1.1.1), you'll see that an explicit fit is never mentioned.
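
A small self-contained sketch of that behaviour (the iris data is used purely for illustration): the estimator is passed in unfitted, and cross_val_predict clones and fits it fold by fold.

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict

X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(random_state=1)

# No explicit clf.fit(...) call: each sample is predicted by a model that was
# fitted on the folds that do not contain it.
oof_pred = cross_val_predict(clf, X, y, cv=5)
print((oof_pred == y).mean())  # out-of-fold accuracy estimate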

Hipped answered 9/8, 2017 at 23:56 Comment(0)

I am not sure the question was answered. I had a similar thought: I wanted to compare the results (accuracy, for example) with a method that does not apply CV. The cross-validated accuracy is computed on X_train and y_train, while the other method fits the model on X_train and y_train but is tested on X_test and y_test, so the comparison is not fair because the scores come from different datasets.

What you can do is use the estimators returned by cross_validate:

from sklearn.model_selection import cross_validate

cv_results = cross_validate(lr, train_df[features_columns], train_df["target"],
                            cv=kf, return_estimator=True)

# cross_validate returns a dict; "estimator" holds one fitted model per fold.
y_pred = cv_results["estimator"][0].predict(test_df[features_columns])

accuracy = (y_pred == test_df["target"]).mean()
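
If the aim is the fair comparison described above, one sketch (assuming the cv_results dict from the snippet above) is to score every fold's estimator on the same held-out test set and compare that with the cross-validated scores:

import numpy as np

# Accuracy of each fold's fitted estimator on the held-out test set.
test_scores = [
    est.score(test_df[features_columns], test_df["target"])
    for est in cv_results["estimator"]
]

print("mean CV accuracy (validation folds):", cv_results["test_score"].mean())
print("mean accuracy on the test set:", np.mean(test_scores))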

Mahoney answered 20/3, 2022 at 9:23 Comment(1)
As it’s currently written, your answer is unclear. Please edit to add additional details that will help others understand how this addresses the question asked. You can find more information on how to write good answers in the help center.Hawsepiece
