I am new to Spark, PySpark DataFrames, and Spark ML. How can I create a custom cross-validator for the ML library? For example, I want to change the way the training folds are formed, e.g. to use stratified splits.
This is my current code:
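from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

numFolds = 10  # number of outer (stratified) folds
predictions = []

lr = LogisticRegression()\
    .setFeaturesCol("features")\
    .setLabelCol("label")

# Grid search on LR model: 5 x 5 x 3 = 75 parameter combinations
lrparamGrid = ParamGridBuilder()\
    .addGrid(lr.regParam, [0.01, 0.1, 0.5, 1.0, 2.0])\
    .addGrid(lr.elasticNetParam, [0.0, 0.1, 0.5, 0.8, 1.0])\
    .addGrid(lr.maxIter, [5, 10, 20])\
    .build()

pipelineModel = Pipeline(stages=[lr])
evaluator = BinaryClassificationEvaluator()

# Inner cross-validation used for the grid search (random, not stratified, folds)
cv = CrossValidator()\
    .setEstimator(pipelineModel)\
    .setEvaluator(evaluator)\
    .setEstimatorParamMaps(lrparamGrid)\
    .setNumFolds(5)

# My own cross-validation with stratified splits.
# indexOfStratifiedSplits: list of (train_ids, test_ids) pairs of Python lists,
# computed beforehand (e.g. with sklearn); df has an "ID" column.
for i in range(numFolds):
    trainingData = df.filter(df.ID.isin(indexOfStratifiedSplits[i][0]))
    testingData = df.filter(df.ID.isin(indexOfStratifiedSplits[i][1]))
    # Training and grid search (inner 5-fold CV) on this outer training fold
    cvModel = cv.fit(trainingData)
    predictions.append(cvModel.transform(testingData))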
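The loop above assumes `indexOfStratifiedSplits` already exists as a list of `(train_ids, test_ids)` pairs. As a minimal sketch, assuming the IDs and labels have been collected to the driver as two Python lists, such a list could be built as below. The helper name `stratified_kfold_ids` is hypothetical, and it uses a plain round-robin stratification rather than sklearn's `StratifiedKFold`.

```python
import random
from collections import defaultdict


def stratified_kfold_ids(ids, labels, num_folds, seed=42):
    """Hypothetical helper: return a list of (train_ids, test_ids) pairs.

    Each label's IDs are shuffled and dealt round-robin into the folds, so
    every fold receives roughly the same share of each class.
    """
    if len(ids) != len(labels):
        raise ValueError("ids and labels must have the same length")
    rng = random.Random(seed)

    # Group IDs by class label.
    by_label = defaultdict(list)
    for id_, label in zip(ids, labels):
        by_label[label].append(id_)

    # Deal each class's shuffled IDs into the folds in turn.
    folds = [[] for _ in range(num_folds)]
    offset = 0
    for label in sorted(by_label):
        members = by_label[label]
        rng.shuffle(members)
        for i, id_ in enumerate(members):
            folds[(offset + i) % num_folds].append(id_)
        # Continue dealing where the previous class stopped, to balance fold sizes.
        offset += len(members)

    # Fold k is the test set of split k; all other folds form its training set.
    splits = []
    for k in range(num_folds):
        test_ids = folds[k]
        train_ids = [i for j, fold in enumerate(folds) if j != k for i in fold]
        splits.append((train_ids, test_ids))
    return splits
```

For example, `stratified_kfold_ids(df_ids, df_labels, numFolds)` with lists collected from `df` would give one split per outer fold.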
I would like to have a cross-validation class that can be called in one of these two ways:
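# Option 1: pass precomputed splits
cv = MyCrossValidator()\
    .setEstimator(pipelineModel)\
    .setEvaluator(evaluator)\
    .setEstimatorParamMaps(lrparamGrid)\
    .setNumFolds(5)\
    .setSplitIndexes(indexOfStratifiedSplits)

# Option 2: let the validator build stratified splits on a column
cv = MyCrossValidator()\
    .setEstimator(pipelineModel)\
    .setEvaluator(evaluator)\
    .setEstimatorParamMaps(lrparamGrid)\
    .setNumFolds(5)\
    .setSplitType("Stratified", ColumnName)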
I don't know whether the best option is to create a class that extends CrossValidator (overriding its fit method) or to pass functions to Spark. Either option is challenging for me as a newbie: I tried copying code from GitHub but got lots of errors, especially since I don't speak Scala, even though this pipeline is faster in the Scala API.
Although I have my own (sklearn-based) functions to split the data the way I want, I want to use pipelines, grid search, and cross-validation together so that all parameter combinations are distributed across the cluster rather than executed on the master. The "My own cross-validation" loop above only uses part of the cluster nodes, because the loop runs on the master/driver.
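To keep the existing sklearn-based split functions (Option 1) while still letting Spark run the grid search, one possibility is to turn the precomputed splits into one fold number per ID: in split k, the IDs in the test set belong to fold k. A minimal sketch, assuming `indexOfStratifiedSplits` is a list of `(train_ids, test_ids)` pairs whose test sets partition the IDs; the helper name `splits_to_fold_map` is hypothetical.

```python
def splits_to_fold_map(splits):
    """Hypothetical helper: map each ID to the index of the split that tests it.

    Raises ValueError if the splits cannot be expressed as a single fold column.
    """
    fold_of = {}
    for k, (_train_ids, test_ids) in enumerate(splits):
        for id_ in test_ids:
            # A fold column holds only one fold number per row.
            if id_ in fold_of:
                raise ValueError(f"ID {id_!r} is in the test sets of folds "
                                 f"{fold_of[id_]} and {k}")
            fold_of[id_] = k

    all_ids = set(fold_of)
    for k, (train_ids, test_ids) in enumerate(splits):
        # With a fold column, split k trains on every row outside fold k,
        # so the precomputed training set must be exactly that complement.
        if set(train_ids) != all_ids - set(test_ids):
            raise ValueError(f"split {k}: training IDs are not the "
                             f"complement of its test IDs")
    return fold_of
```

The mapping could then be turned into a DataFrame with `spark.createDataFrame(list(fold_of.items()), ["ID", "fold"])`, joined to `df` on `ID`, and passed to `CrossValidator.setFoldCol("fold")` (Spark 3.1+). Separately, `CrossValidator.setParallelism(n)` lets up to `n` models be trained concurrently, each fit itself running as distributed Spark jobs.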
Either the Python or the Scala API is fine, but I prefer Scala.
Thanks