scikit-learn clustering: predict(X) vs. fit_predict(X)

In scikit-learn, some clustering algorithms have both predict(X) and fit_predict(X) methods, like KMeans and MeanShift, while others only have the latter, like SpectralClustering. According to the doc:

fit_predict(X[, y]):    Performs clustering on X and returns cluster labels.
predict(X): Predict the closest cluster each sample in X belongs to.

I don't really understand the difference between the two; they seem equivalent to me.
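To illustrate what I mean, here is a quick check (just a sketch; exact behaviour may vary across scikit-learn versions):

    from sklearn.cluster import KMeans, MeanShift, SpectralClustering, DBSCAN

    # Which clusterers expose a predict() method at all?
    for Est in (KMeans, MeanShift, SpectralClustering, DBSCAN):
        print(Est.__name__, hasattr(Est(), "predict"))
    # KMeans True, MeanShift True, SpectralClustering False, DBSCAN False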

Jonette answered 9/5, 2016 at 2:25 Comment(1)
Does predict() return the same thing as kmeans.labels_, or more accurate labels? – Negativism

In order to use predict() you must call fit() first, so calling fit() and then predict() on the same data is equivalent to calling fit_predict(). However, you may prefer plain fit() when you want access to the fitted model itself (its learned parameters and attributes), whereas fit_predict() only gives you the labels obtained by running the model on the data.
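To make that concrete, here is a minimal sketch with KMeans (toy data and parameter values chosen only for illustration):

    import numpy as np
    from sklearn.cluster import KMeans

    X = np.array([[1.0, 1.0], [1.2, 0.8], [8.0, 8.0], [8.1, 7.9]])

    # Option 1: fit, then predict on the same data. Keeping the fitted model
    # lets you inspect attributes such as the learned centroids.
    km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
    labels_a = km.predict(X)
    print(km.cluster_centers_)

    # Option 2: fit_predict in a single call; you only get the labels back.
    labels_b = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)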

Ovida answered 9/5, 2016 at 3:20 Comment(4)
This doesn't quite answer the question. They're asking why KMeans has a predict method but SpectralClustering doesn't, and I actually can't work out the answer to this... maybe it's a bug/missing feature? My understanding is that, as part of the scikit-learn API design, all classifiers should have a fit and a predict method... – Titanism
Yeah, you are right. My answer was aimed more at the "I don't really understand the difference between the two; they seem equivalent to me" part. – Ovida
So... do we have an answer to @Titanism's question? – Steamtight
This does not answer the question. – Firstfoot

fit_predict is typically used for transductive estimators in unsupervised learning, i.e. models that can only label the data they were fitted on.

Basically, fit_predict(X) is equivalent to calling fit(X) and reading the resulting labels (fit(X).labels_); for estimators that also implement predict, this gives the same result as fit(X).predict(X).
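A small sketch to check this claim for an estimator that implements both methods (here KMeans, with a fixed random_state so both runs fit identically):

    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs

    X, _ = make_blobs(n_samples=100, centers=3, random_state=0)

    labels_1 = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)
    labels_2 = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X).predict(X)
    print(np.array_equal(labels_1, labels_2))  # expected: True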

Recrement answered 27/5, 2018 at 6:7 Comment(1)
From a comp-sci point of view, fit() will affect the state of your object/model, yes/no? Whereas predict() will use the existing model to label the input data, and no change will be made to the object/model, yes/no? – Regal

This might be very late to add an answer here, but it may benefit someone in the future.

The reason I can see for having predict() in KMeans but only fit_predict() in DBSCAN is:

  • In KMeans you get centroids based on the number of clusters you choose. So once you have trained the model on your data points using fit(), you can use predict() to assign a new, single data point to a specific cluster.

  • In DBSCAN you don't have centroids; clusters are formed based on the min_samples and eps (the maximum distance between two points for them to be considered neighbors) that you define. The algorithm returns cluster labels only for the data points it was fitted on, which explains why there is no predict() method for a single new data point. The difference between fit() and fit_predict() was already explained by other users.

Another spatial clustering algorithm, HDBSCAN, gives us the option to predict for new points using approximate_predict(); it's worth exploring.
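A rough sketch of the points above: KMeans can label a brand-new point from its stored centroids, DBSCAN has no predict() at all, and the third-party hdbscan package (assumed installed) offers approximate_predict() for new points:

    import numpy as np
    from sklearn.cluster import KMeans, DBSCAN

    X = np.array([[1.0, 1.0], [1.1, 0.9], [8.0, 8.0], [8.2, 7.8]])
    new_point = np.array([[7.5, 8.1]])

    km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
    print(km.predict(new_point))      # assigned to the nearest centroid

    db = DBSCAN(eps=0.5, min_samples=2).fit(X)
    print(hasattr(db, "predict"))     # False: no way to label new points

    # With the third-party hdbscan package (not part of scikit-learn):
    # import hdbscan
    # clusterer = hdbscan.HDBSCAN(min_cluster_size=2, prediction_data=True).fit(X)
    # labels, strengths = hdbscan.approximate_predict(clusterer, new_point)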

Again, this is my understanding based on the source code I explored; any experts are welcome to point out differences.

Millican answered 6/1, 2021 at 13:26 Comment(0)

Models such as KMeans can predict cluster labels for any data after fitting. Other methods, like SpectralClustering, have no notion of where unseen data belongs; they can only label data that was part of the fitting process. This is why the fit_predict(X) method exists.

Example: DBSCAN finds clusters and outliers. Suppose a new data point is added; it may bridge two clusters together. DBSCAN would need to repeat the fitting process to find out. This is why a predict() method alone would not be sufficient.
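A small sketch of that scenario, with toy 1-D data chosen so that one extra point merges two clusters (eps and min_samples are illustrative values):

    import numpy as np
    from sklearn.cluster import DBSCAN

    X = np.array([[0.0], [0.4], [0.8], [1.2], [2.0], [2.4], [2.8]])
    print(DBSCAN(eps=0.5, min_samples=2).fit_predict(X))
    # two clusters, e.g. [0 0 0 0 1 1 1]

    # A new point in the gap bridges them, so everything must be refitted:
    X_new = np.vstack([X, [[1.6]]])
    print(DBSCAN(eps=0.5, min_samples=2).fit_predict(X_new))
    # now a single cluster, e.g. [0 0 0 0 0 0 0 0]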

Colet answered 15/7, 2024 at 8:57 Comment(0)
