Spark train test split

I am curious whether there is something similar to sklearn's StratifiedShuffleSplit (http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html) for Apache Spark in the latest release, 2.0.1.

So far I could only find https://spark.apache.org/docs/latest/mllib-statistics.html#stratified-sampling, which does not seem to be a great fit for splitting a heavily imbalanced dataset into train/test samples.

Towhead answered 12/10, 2016 at 9:2 Comment(7)
See "Example: model selection via train validation split": TrainValidationSplit creates a single (training, test) dataset pair. It splits the dataset into these two parts using the trainRatio parameter. - Nilsson
Thanks, I did not know about that one. However, TrainValidationSplit does not seem to be randomized, nor does it seem to support stratified splits. Am I missing something here? - Towhead
You're right, there's a Jira ticket about this: "Support balanced class labels when splitting train/cross validation sets". So MLlib doesn't support this feature yet. - Nilsson
Do you know of a decent workaround until this is merged? - Towhead
Have you already seen this answer? - Nilsson
I will need to try that. - Towhead
TrainValidationSplit is not helpful unless you are also doing your model training in Spark's MLlib. It requires an Estimator, ParamMaps and more. - Lorikeet

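Spark supports stratified sampling, as outlined in https://s3.amazonaws.com/sparksummit-share/ml-ams-1.0.1/6-sampling/scala/6-sampling_student.html. sampleBy takes the stratum column, a map from each stratum value to its sampling fraction, and a random seed: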

df.stat.sampleBy("label", Map(0 -> .10, 1 -> .20, 2 -> .3), 0)
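The Spark docs describe sampleBy as a stratified sample without replacement in which strata missing from the fractions map are treated as fraction zero. Below is a plain-Python sketch of that behaviour (not Spark code; `sample_by` is a hypothetical helper), assuming each row is kept by an independent random draw. The consequence is that the per-label counts are only approximately fraction × count, not exact.

```python
import random


def sample_by(rows, col, fractions, seed):
    """Keep each row independently with the fraction given for its stratum.

    rows: list of dicts; col: stratum key; fractions: {stratum: fraction}.
    Strata missing from `fractions` are sampled with fraction 0.
    """
    rng = random.Random(seed)
    return [r for r in rows if rng.random() < fractions.get(r[col], 0.0)]


# Ten rows for each label 0, 1, 2 and 3; label 3 is not in the fraction map.
rows = [{"id": i, "label": i % 4} for i in range(40)]
sample = sample_by(rows, "label", {0: 0.10, 1: 0.20, 2: 0.3}, seed=0)
```

Because every row is decided independently, running it twice with different seeds gives samples of different sizes, and nothing guarantees that a second call produces rows disjoint from the first.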
Towhead answered 19/6, 2018 at 14:46 Comment(3)
How can you use sampleBy for a train/test split? The way I see it, if you have two lines, train = df.stat.sampleBy(...) and test = df.stat.sampleBy(...), then the two samples may contain duplicate records. Unless I'm missing something? - Factorize
After obtaining train with sampleBy, left-anti join the original df with train and you have test. - Unsupportable
@Unsupportable this is very helpful. I also found no use in sampleBy on its own, because sampling is different from splitting, but this is the solution. I should add that if you have a prepared dataset, you will need to add a key column to perform that join. - Lorikeet
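Why the key column matters can be shown with a plain-Python sketch of the left-anti join (not Spark code; `add_key` and `left_anti` are hypothetical helpers). In Spark the equivalent would presumably be something like `df.join(train, Seq("row_id"), "left_anti")`, with the key added beforehand, e.g. via `monotonically_increasing_id()`; that choice is an assumption, not something from the comments. If the join is done on the feature values instead, identical rows cannot be told apart and rows silently disappear from the test set.

```python
def add_key(rows, key="row_id"):
    """Return copies of the rows with a unique integer key added."""
    return [{**r, key: i} for i, r in enumerate(rows)]


def left_anti(rows, other, cols):
    """Rows whose values in `cols` do not occur in `other` (a left-anti join)."""
    seen = {tuple(r[c] for c in cols) for r in other}
    return [r for r in rows if tuple(r[c] for c in cols) not in seen]


# The first two rows are identical apart from the key we add.
data = add_key([{"x": 1, "label": 0}, {"x": 1, "label": 0}, {"x": 2, "label": 1}])
train = [data[0]]                               # pretend sampleBy picked row 0
test = left_anti(data, train, ["row_id"])       # complement by key: rows 1 and 2
lossy = left_anti(data, train, ["x", "label"])  # joining on values also drops row 1
```

With the key, train and test together cover every row exactly once; without it, the duplicate of a training row is lost.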

Perhaps this method wasn't available when the OP posted this question, but I'm leaving this here for future reference:

# splitting dataset into train and test set
train, test = df.randomSplit([0.7, 0.3], seed=42)
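To see why this is not stratified, here is a plain-Python sketch of a weighted random split (an illustration of the idea, not Spark's implementation; `random_split` is a hypothetical name): the weights are normalized, and each row gets a single uniform draw that selects its bucket. Nothing in it looks at the label, so a rare class can end up over- or under-represented in either split.

```python
import random
from itertools import accumulate


def random_split(rows, weights, seed):
    """Assign each row to one bucket with probability proportional to its weight."""
    total = sum(weights)
    bounds = list(accumulate(w / total for w in weights))  # cumulative upper bounds
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        # First bucket whose upper bound exceeds u; the last bucket absorbs any
        # floating-point shortfall in the final bound.
        idx = next((i for i, b in enumerate(bounds) if u < b), len(weights) - 1)
        splits[idx].append(row)
    return splits


rows = list(range(10000))
train, test = random_split(rows, [0.7, 0.3], seed=42)
```

The bucket sizes are only approximately 70% and 30%, and the same is true of each label's share inside them.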
Erst answered 12/5, 2020 at 16:34 Comment(4)
But that won't be stratified as the OP asked, right? Also, a comma is missing in the tuple (?) - Nutlet
Yeah, you got me there. - Erst
Thanks! Still, mind the "stratified" sampling requested by the OP: en.wikipedia.org/wiki/Stratified_sampling (not random!) - Nutlet
If you are planning on using randomSplit, it is wise to take a look at this article: medium.com/udemy-engineering/… - Hendecagon

Let's assume we have a dataset like this:

+---+-----+
| id|label|
+---+-----+
|  0|  0.0|
|  1|  1.0|
|  2|  0.0|
|  3|  1.0|
|  4|  0.0|
|  5|  1.0|
|  6|  0.0|
|  7|  1.0|
|  8|  0.0|
|  9|  1.0|
+---+-----+

This dataset is perfectly balanced, but this approach will work for unbalanced data as well.

Now, let's augment this DataFrame with additional information that will be useful in deciding which rows should go to the train set. The steps are as follows:

  • Determine how many examples of every label should be part of the train set, given some ratio.
  • Shuffle the rows of the DataFrame.
  • Use a window function to partition and order the DataFrame by label, and then rank each label's observations using row_number().

We end up with the following data frame:

+---+-----+----------+
| id|label|row_number|
+---+-----+----------+
|  6|  0.0|         1|
|  2|  0.0|         2|
|  0|  0.0|         3|
|  4|  0.0|         4|
|  8|  0.0|         5|
|  9|  1.0|         1|
|  5|  1.0|         2|
|  3|  1.0|         3|
|  1|  1.0|         4|
|  7|  1.0|         5|
+---+-----+----------+

Note: the rows are shuffled (see the random order in the id column), partitioned by label (see the label column) and ranked.
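The three steps can be sketched in plain Python to make the logic explicit (not Spark code; `rank_within_labels` is a hypothetical name). It truncates count × ratio the way the SQL cast in the full example below does; with a fixed seed its shuffle will not reproduce the exact ids in the table above.

```python
import random
from collections import Counter, defaultdict


def rank_within_labels(rows, label, train_ratio, seed):
    """Sketch of the three steps: per-label quotas, shuffle, per-label row numbers.

    rows: list of dicts; label: name of the label key.
    Returns (ranked_rows, train_counts).
    """
    # Step 1: how many rows of each label go to the train set, truncated
    # like cast(count * trainRatio as long).
    counts = Counter(r[label] for r in rows)
    train_counts = {k: int(n * train_ratio) for k, n in counts.items()}

    # Step 2: shuffle a copy of the rows.
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)

    # Step 3: group by label (keeping the shuffled order) and number the rows
    # 1, 2, ... within each label, like row_number() over a per-label window.
    by_label = defaultdict(list)
    for r in shuffled:
        by_label[r[label]].append(r)
    ranked = [
        {**r, "row_number": i}
        for k in sorted(by_label)
        for i, r in enumerate(by_label[k], start=1)
    ]
    return ranked, train_counts


rows = [{"id": i, "label": float(i % 2)} for i in range(10)]
ranked, train_counts = rank_within_labels(rows, "label", 0.8, seed=0)
```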

Let's assume that we would like an 80/20 split. In this case, we want four 1.0 labels and four 0.0 labels to go to the training dataset, and one 1.0 label and one 0.0 label to go to the test dataset. We have this information in the row_number column, so we can simply use it in a user-defined function: if row_number is less than or equal to four, the example goes to the train set.

After applying the UDF, the resulting data frame is as follows:

+---+-----+----------+----------+
| id|label|row_number|isTrainSet|
+---+-----+----------+----------+
|  6|  0.0|         1|      true|
|  2|  0.0|         2|      true|
|  0|  0.0|         3|      true|
|  4|  0.0|         4|      true|
|  8|  0.0|         5|     false|
|  9|  1.0|         1|      true|
|  5|  1.0|         2|      true|
|  3|  1.0|         3|      true|
|  1|  1.0|         4|      true|
|  7|  1.0|         5|     false|
+---+-----+----------+----------+
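The flagging rule can be checked by hand against the table above. This plain-Python sketch copies the ranked rows, computes each label's quota the same way (count × ratio, truncated) and applies the UDF's comparison; the variable names are illustrative only.

```python
from collections import Counter

# (id, label, row_number) rows copied from the ranked table above.
ranked = [
    (6, 0.0, 1), (2, 0.0, 2), (0, 0.0, 3), (4, 0.0, 4), (8, 0.0, 5),
    (9, 1.0, 1), (5, 1.0, 2), (3, 1.0, 3), (1, 1.0, 4), (7, 1.0, 5),
]
train_ratio = 0.8

# Per-label train quota, truncated like the SQL cast: int(5 * 0.8) == 4.
counts = Counter(lbl for _, lbl, _ in ranked)
quota = {lbl: int(n * train_ratio) for lbl, n in counts.items()}

# The UDF's rule: isTrainSet = row_number <= quota for that label.
flagged = [(i, lbl, rn, rn <= quota[lbl]) for i, lbl, rn in ranked]
train_ids = [i for i, _, _, is_train in flagged if is_train]
test_ids = [i for i, _, _, is_train in flagged if not is_train]

# Edge case: a label with a single row gets a quota of int(1 * 0.8) == 0,
# so that row always lands in the test set.
single_row_quota = int(1 * train_ratio)
```

train_ids comes out as [6, 2, 0, 4, 9, 5, 3, 1] and test_ids as [8, 7], matching the isTrainSet column. The edge case applies to the Scala code as well, since it uses the same truncating cast.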

Now, to get the train and test data, one has to do:

val train = df.where(col("isTrainSet") === true)
val test = df.where(col("isTrainSet") === false)

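These sorting and partitioning steps might be prohibitively expensive for very large datasets: the plan contains two shuffles (a range partitioning for the random sort and a hash partitioning by label for the window), and because the window is partitioned by label, all rows of a given label are processed by a single task. So I suggest filtering the dataset as much as possible first. The physical plan is as follows: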

== Physical Plan ==
*(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
+- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
   +- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(label#5, 200)
         +- *(1) Project [id#4, label#5]
            +- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
               +- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
                  +- LocalTableScan [id#4, label#5, _nondeterministic#9]

Here's a full working example (tested with Spark 2.3.0 and Scala 2.11.12):

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, row_number, udf, rand}

class StratifiedTrainTestSplitter {

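  // Number of training examples per label: count * trainRatio, truncated by the cast,
  // e.g. Map(0.0 -> 4, 1.0 -> 4) for the example data and trainRatio = 0.8.
  def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
    // Count the rows per label and expose the counts to Spark SQL.
    df.groupBy(label).count().createOrReplaceTempView("labelCounts")
    val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
    import ss.implicits._
    ss.sql(query)
      .select("ratioLabel", "trainExamples")
      .map((r: Row) => r.getDouble(0) -> r.getLong(1))
      .collect()
      .toMap
  }

  def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
    // One window per label value. Every row in a window ties on the ordering column,
    // so row_number() just enumerates the rows; the preceding random sort is what
    // is meant to randomize that order.
    val w = Window.partitionBy(col(label)).orderBy(col(label))

    val rowNumPartitioner = row_number().over(w)

    // Shuffle the rows, then number them 1, 2, ... within each label.
    val dfRowNum = df.sort(rand()).select(col("*"), rowNumPartitioner as "row_number")

    // Debug output: prints the intermediate table with the row_number column.
    dfRowNum.show()

    val observationsPerLabel: Map[Double, Long] = getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)

    // A row belongs to the train set if its rank within its label is within that label's quota.
    val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))

    dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
  }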


}

object StratifiedTrainTestSplitter {

  def getDf(ss: SparkSession): DataFrame = {
    val data = Seq(
      (0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0), (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
    )
    ss.createDataFrame(data).toDF("id", "label")
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .config(new SparkConf().setMaster("local[1]"))
      .getOrCreate()

    val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)

    df.cache()

    df.where(col("isTrainSet") === true).show()
    df.where(col("isTrainSet") === false).show()
  }
}

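Note: the labels are Doubles in this case. If your labels are Strings, you'll have to change the types in a few places: the key type of Map[Double, Long], r.getDouble(0) in getNumExamplesPerClass, and the label parameter of the UDF.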

Mientao answered 22/5, 2018 at 21:13 Comment(0)

Although this answer is not specific to Spark, in Apache Beam I do the following to split the data into 66% train and 33% test. This is just an illustrative example; you can customize the partition_fn below to be more sophisticated, e.g. to accept arguments that specify the number of buckets, bias the selection towards something, or ensure the randomization is fair across dimensions:

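import random

import apache_beam as beam
from apache_beam.io import Read

# p is an existing beam.Pipeline; the source given to Read(...) and the
# CleanFieldsFn DoFn are specific to your own pipeline.
raw_data = p | 'Read Data' >> Read(...)

clean_data = (raw_data
              | "Clean Data" >> beam.ParDo(CleanFieldsFn()))


def partition_fn(element, num_partitions):
    # beam.Partition calls fn(element, num_partitions); pick a bucket uniformly.
    return random.randint(0, num_partitions - 1)

# Three roughly equal random buckets.
random_buckets = (clean_data | beam.Partition(partition_fn, 3))

# Two buckets (about 66%) form the training set...
clean_train_data = ((random_buckets[0], random_buckets[1])
                    | beam.Flatten())

# ...and the remaining bucket (about 33%) is the evaluation set.
clean_eval_data = random_buckets[2]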
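The answer suggests making partition_fn accept arguments. A hedged sketch, assuming beam.Partition forwards extra positional arguments to the function after (element, num_partitions): a weighted variant that takes the bucket weights directly. `weighted_partition_fn` is a hypothetical name, and the weights are illustrative.

```python
import random


def weighted_partition_fn(element, num_partitions, weights, rng=random):
    """Hypothetical partition_fn: pick a bucket with probability proportional to its weight."""
    if len(weights) != num_partitions:
        raise ValueError("need exactly one weight per partition")
    # random.choices returns a list of k picks; we need a single bucket index.
    return rng.choices(range(num_partitions), weights=weights, k=1)[0]
```

In a pipeline this would be used as `clean_data | beam.Partition(weighted_partition_fn, 2, [2, 1])` to get two buckets at roughly 2:1. Like the uniform version, it ignores the label, so it is still not a stratified split.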
Endoplasm answered 12/12, 2017 at 21:7 Comment(0)
