How to get Precision/Recall using CrossValidator for training NaiveBayes Model using Spark

Suppose I have a Pipeline like this:

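import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Feature stages: tokenize the tweet, hash to term frequencies, apply IDF weighting
val tokenizer = new Tokenizer().setInputCol("tweet").setOutputCol("words")
val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol("words").setOutputCol("features")
val idf = new IDF().setInputCol("features").setOutputCol("idffeatures")
// Read the IDF output; the default featuresCol ("features") would bypass the IDF stage
val nb = new org.apache.spark.ml.classification.NaiveBayes().setFeaturesCol("idffeatures")
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, idf, nb))
val paramGrid = new ParamGridBuilder().addGrid(hashingTF.numFeatures, Array(10, 100, 1000)).addGrid(nb.smoothing, Array(0.01, 0.1, 1.0)).build()
val cv = new CrossValidator().setEstimator(pipeline).setEvaluator(new BinaryClassificationEvaluator()).setEstimatorParamMaps(paramGrid).setNumFolds(10)
val cvModel = cv.fit(df)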

As you can see, I defined a CrossValidator using a BinaryClassificationEvaluator. I have seen many examples that obtain metrics such as precision and recall during testing, but those metrics are computed on a separate data set held out for testing (see, for example, this documentation).

From my understanding, CrossValidator creates folds, uses one fold at a time for testing, and then chooses the best model. My question is: is it possible to get precision/recall metrics during the training process?

Divided answered 12/6, 2016 at 19:59 Comment(0)

Well, the only metric that is actually stored is the one you define when you create an instance of an Evaluator. For the BinaryClassificationEvaluator, the metric can take one of two values:

  • areaUnderROC
  • areaUnderPR

The former is the default; the metric can be changed with the setMetricName method.
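As a rough sketch of switching the metric, assuming spark-mllib is on the classpath (for example in spark-shell): the name prEvaluator is only illustrative, and the commented usage line assumes the pipeline and paramGrid values from the question.

```scala
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// Score each fold by area under the precision-recall curve
// instead of the default areaUnderROC.
val prEvaluator = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderPR")

// Possible usage with the objects defined in the question (not run here):
// val cv = new CrossValidator()
//   .setEstimator(pipeline)
//   .setEvaluator(prEvaluator)
//   .setEstimatorParamMaps(paramGrid)
//   .setNumFolds(10)
```

Note that this still yields a single summary number per parameter combination, not separate precision and recall values.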

These values are collected during the training process and can be accessed using CrossValidatorModel.avgMetrics. Each value is the metric averaged over the folds for one parameter combination, and the order of the values corresponds to the order of the EstimatorParamMaps (CrossValidatorModel.getEstimatorParamMaps).
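Because the two arrays share the same order, they can be zipped together to see which parameter combination produced which score. Here is a minimal sketch; rankByMetric is a hypothetical helper (not part of Spark), written generically so it runs without a cluster, and it assumes a larger metric is better, which holds for areaUnderROC and areaUnderPR.

```scala
// Hypothetical helper: pair each parameter combination with its averaged
// metric and sort so the highest metric comes first.
def rankByMetric[P](paramMaps: Seq[P], avgMetrics: Seq[Double]): Seq[(P, Double)] =
  paramMaps.zip(avgMetrics).sortBy { case (_, metric) => -metric }

// Possible usage with the cvModel from the question (not run here):
// rankByMetric(cvModel.getEstimatorParamMaps.toSeq, cvModel.avgMetrics.toSeq)
//   .foreach { case (params, metric) => println(s"$metric <- $params") }
```

The first entry of the result should correspond to the combination that CrossValidator selects for its best model when the evaluator treats larger values as better.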

Collum answered 13/6, 2016 at 16:58 Comment(1)
@zero233 I don't seem to find avgMetrics and the other attributes in my Spark. Calling cvModel.avgMetrics gives:

AttributeError Traceback (most recent call last)
<ipython-input-53-d454795422c1> in <module>()
----> 1 cvModel.avgMetrics
AttributeError: 'CrossValidatorModel' object has no attribute 'avgMetrics'

Quickman
