How to plot ROC curve and precision-recall curve from BinaryClassificationMetrics

I am trying to plot a ROC curve and a precision-recall curve. The points come from Spark MLlib's BinaryClassificationMetrics, following the Spark documentation at https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html. This is the output I get:

[(1.0,1.0), (0.0,0.4444444444444444)] - Precision
[(1.0,1.0), (0.0,1.0)] - Recall
[(1.0,1.0), (0.0,0.6153846153846153)] - F1Measure
[(0.0,1.0), (1.0,1.0), (1.0,0.4444444444444444)] - Precision-Recall curve
[(0.0,0.0), (0.0,1.0), (1.0,1.0), (1.0,1.0)] - ROC curve
Hectograph answered 5/7, 2016 at 16:6

It looks like you have a problem similar to one I ran into. You need to either swap the order of the values you pass to the BinaryClassificationMetrics constructor (it expects (score, label) pairs) or perhaps pass in the probability instead of the prediction. For example, if you are using BinaryClassificationMetrics with a RandomForestClassifier, then according to the classifier's documentation (under outputs) the output columns include "prediction" and "probability".

Then initialize your Metrics thus:

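    // DenseVector: org.apache.spark.ml.linalg on Spark 2.x+, org.apache.spark.mllib.linalg on 1.x
    val metrics = new BinaryClassificationMetrics(predictionsWithResponse
      .select(col("probability"), col("myLabel"))
      // Build (score, label) pairs: probability of class 1 first, then the label
      .rdd.map(r => (r.getAs[DenseVector](0)(1), r.getDouble(1))))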

Here the DenseVector lookup at index 1 extracts the probability of class 1, which serves as the score.

As for the actual plotting, that's up to you (there are many fine tools for that), but at least you will get more than one point on your curve (besides the endpoints).
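To see why probabilities give you more points than hard predictions, here is a minimal plain-Scala sketch. It is not Spark's implementation, only an illustration of the idea that each distinct score becomes one threshold. The object name RocSketch, the function rocPoints and the sample data are all made up for this example, and it assumes the input contains at least one positive and one negative label.

```scala
object RocSketch {
  /** ROC points (FPR, TPR): the (0,0) origin plus one point per distinct score.
    * Input: (score, label) pairs, where label 1.0 is positive and 0.0 is negative. */
  def rocPoints(scoreAndLabels: Seq[(Double, Double)]): Seq[(Double, Double)] = {
    val positives = scoreAndLabels.count(_._2 == 1.0).toDouble
    val negatives = scoreAndLabels.size - positives
    // Each distinct score is used as a threshold, highest first
    val thresholds = scoreAndLabels.map(_._1).distinct.sorted.reverse
    val points = thresholds.map { t =>
      // Everything scoring at or above the threshold is predicted positive
      val predictedPos = scoreAndLabels.filter(_._1 >= t)
      val tp = predictedPos.count(_._2 == 1.0)
      val fp = predictedPos.size - tp
      (fp / negatives, tp / positives)
    }
    (0.0, 0.0) +: points
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical data: the same four rows scored two different ways
    val hard = Seq((1.0, 1.0), (1.0, 0.0), (0.0, 1.0), (0.0, 0.0)) // 0/1 predictions
    val soft = Seq((0.9, 1.0), (0.6, 0.0), (0.4, 1.0), (0.1, 0.0)) // class-1 probabilities
    println(rocPoints(hard)) // only two distinct scores, so a single interior point
    println(rocPoints(soft)) // four distinct scores, so four points after the origin
  }
}
```

With hard 0/1 predictions there are only two possible thresholds, so the curve collapses to the endpoints plus one point, which matches the kind of output shown in the question.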

And in case it's not clear:

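metrics.roc().collect() will give you the data for the ROC curve as tuples of (false positive rate, true positive rate). Likewise, metrics.pr().collect() gives tuples of (recall, precision) for the precision-recall curve.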
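Once the points are collected, one simple way to get them into a plotting tool is to dump them to a CSV file. The following is only a sketch in plain Scala: the object name CurveCsv, the helper writeCurve, the column names and the file name roc.csv are all hypothetical, and the sample values are the ROC points from the question standing in for a real collect() result.

```scala
import java.io.PrintWriter

object CurveCsv {
  /** Writes (x, y) curve points to a CSV file, preceded by a header row. */
  def writeCurve(points: Seq[(Double, Double)], path: String,
                 xName: String, yName: String): Unit = {
    val out = new PrintWriter(path)
    try {
      out.println(s"$xName,$yName")
      points.foreach { case (x, y) => out.println(s"$x,$y") }
    } finally out.close() // always release the file handle
  }

  def main(args: Array[String]): Unit = {
    // In a Spark job these would come from metrics.roc().collect().toSeq;
    // the values below are the ROC points posted in the question.
    val roc = Seq((0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 1.0))
    writeCurve(roc, "roc.csv", "fpr", "tpr")
  }
}
```

The same helper would work for the precision-recall curve by passing the collected metrics.pr() points with column names such as "recall" and "precision". The resulting file can be opened in a spreadsheet, gnuplot, matplotlib, or whatever plotting tool you prefer.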

Lenni answered 28/9, 2016 at 4:3