Spark custom estimator including persistence

I want to develop a custom estimator for Spark that also handles persistence within the great Pipeline API. But, as How to Roll a Custom Estimator in PySpark mllib puts it, there is not a lot of documentation out there (yet).

I have some data cleansing code written in Spark and would like to wrap it in a custom estimator. It includes some NA substitutions, column deletions, filtering and basic feature generation (e.g. birthdate to age).
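A minimal sketch of the "birthdate to age" step using only built-in Spark SQL functions. The helper name withAge and the column names birthdate / age are hypothetical.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, current_date, floor, months_between}

object AgeFeature {
  // Hypothetical helper: adds the age in completed years derived from a date column.
  // months_between(end, start) returns fractional months; dividing by 12 and
  // flooring yields whole years (a LongType column).
  def withAge(df: DataFrame, birthCol: String = "birthdate", ageCol: String = "age"): DataFrame =
    df.withColumn(ageCol, floor(months_between(current_date(), col(birthCol)) / 12))
}
```

Because the result depends on current_date(), this step is deterministic only per day; inside an Estimator it would normally live in transform rather than fit.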

  • transformSchema will use the case class of the dataset: ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]
  • fit will only compute the fitted values, e.g. the mean age used as an NA substitute
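The schema idea in the first point can be sketched as follows, assuming a hypothetical case class Person for the cleaned records. ScalaReflection lives in the internal org.apache.spark.sql.catalyst package; Encoders.product is the public route to a case-class schema and should give an equivalent StructType.

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType

// Hypothetical record type for the cleaned dataset.
case class Person(name: String, birthdate: java.sql.Date, age: Option[Double])

object PersonSchema {
  // The approach from the question, via Catalyst's (internal) ScalaReflection.
  val viaReflection: StructType =
    ScalaReflection.schemaFor[Person].dataType.asInstanceOf[StructType]

  // Public API alternative: take the schema from the case-class encoder.
  val viaEncoder: StructType = Encoders.product[Person].schema
}
```

A transformSchema that returns a fixed schema like this ignores its input schema, so it only fits a stage that always emits exactly this record layout.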

What is still pretty unclear to me:

  • transform in the custom pipeline model will be used to apply the "fitted" Estimator to new data. Is this correct? If so, how should I transfer the fitted values, e.g. the mean age from above, into the model?

  • How should I handle persistence? I found a generic loadImpl method within private Spark components, but I am unsure how to transfer my own parameters, e.g. the mean age, into the MLReader / MLWriter used for serialization.

It would be great if you could help me with a custom estimator - especially with the persistence part.

Jaco answered 26/11, 2016 at 10:38 Comment(0)

First of all, I believe you're mixing up two different things a bit:

  • Estimators, which represent stages that can be fitted. An Estimator's fit method takes a Dataset and returns a Transformer (a model).
  • Transformers, which represent stages that can transform data.

When you fit a Pipeline, it fits all Estimators and returns a PipelineModel. A PipelineModel transforms data by sequentially calling transform on all Transformers in the model.
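To make the fit / transform split concrete, here is a sketch with a one-stage Pipeline. It assumes Spark 2.2 or later, where the built-in org.apache.spark.ml.feature.Imputer is available: an Estimator that learns column means, much like the mean-age substitution in the question. The column names and data are made up.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession

object PipelineDemo {
  // Fits a one-stage Pipeline on training data and returns the imputed ages of new data.
  def run(spark: SparkSession): Array[Double] = {
    import spark.implicits._
    val train = Seq[(Int, Option[Double])]((1, Some(20.0)), (2, None), (3, Some(40.0)))
      .toDF("id", "age")

    // Estimator: fit() computes the mean of "age" (null values are treated as missing).
    val imputer = new Imputer()
      .setInputCols(Array("age"))
      .setOutputCols(Array("age_filled"))
      .setStrategy("mean")

    // Pipeline.fit fits every Estimator stage and returns a PipelineModel whose
    // stages are all Transformers (here a single ImputerModel).
    val model: PipelineModel = new Pipeline().setStages(Array(imputer)).fit(train)

    // The mean learned from train (30.0) is reused on data the Estimator never saw.
    val test = Seq[(Int, Option[Double])]((4, None), (5, Some(50.0))).toDF("id", "age")
    model.transform(test).orderBy("id").select("age_filled").as[Double].collect()
  }
}
```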

how should I transfer the fitted values

There is no single answer to this question. In general you have two options:

  • Pass parameters of the fitted model as the arguments of the Transformer.
  • Make parameters of the fitted model Params of the Transformer.

The first approach is typically used by the built-in Transformers, but the second one should work in some simple cases.
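A minimal sketch of the first option, assuming a single numeric column. The hypothetical MeanFill estimator computes the mean in fit and hands it to the model's constructor, roughly the way CountVectorizerModel receives its vocabulary. Since the model no longer has a plain (uid: String) constructor, defaultCopy cannot be used and copy rebuilds the instance explicitly.

```scala
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{avg, col}
import org.apache.spark.sql.types.StructType

// Hypothetical shared Param: the column whose nulls should be filled.
trait MeanFillParams extends Params {
  final val inputCol: Param[String] = new Param[String](this, "inputCol", "numeric column to fill")
  def getInputCol: String = $(inputCol)
}

// Estimator: fit() computes the mean and passes it to the model's constructor.
class MeanFill(override val uid: String) extends Estimator[MeanFillModel] with MeanFillParams {
  def this() = this(Identifiable.randomUID("meanFill"))
  def setInputCol(value: String): this.type = set(inputCol, value)

  override def fit(dataset: Dataset[_]): MeanFillModel = {
    val mean = dataset.agg(avg(col($(inputCol)))).head().getDouble(0) // avg ignores nulls
    copyValues(new MeanFillModel(uid, mean).setParent(this))
  }
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): MeanFill = defaultCopy(extra)
}

// Model: the fitted mean is a plain constructor argument, not a Param.
class MeanFillModel(override val uid: String, val mean: Double)
    extends Model[MeanFillModel] with MeanFillParams {
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.na.fill(mean, Seq($(inputCol)))
  override def transformSchema(schema: StructType): StructType = schema
  // defaultCopy needs a (uid: String) constructor, so rebuild the instance explicitly.
  override def copy(extra: ParamMap): MeanFillModel =
    copyValues(new MeanFillModel(uid, mean), extra).setParent(parent)
}
```

The trade-off: the model's state is an ordinary field, but DefaultParamsWritable no longer covers it, so persistence needs a custom writer and reader.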

how to handle persistence

  • If a Transformer is defined only by its Params, you can extend DefaultParamsWritable and give it a companion object extending DefaultParamsReadable.
  • If you use more complex arguments, you should extend MLWritable and implement an MLWriter that makes sense for your data (plus a matching MLReadable / MLReader for loading). There are multiple examples in the Spark source which show how to implement data and metadata reading / writing.
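The second bullet can be sketched for a hypothetical MeanFillModel whose fitted mean is a constructor argument. The writer stores the uid, the inputCol value and the mean as a one-row Parquet file; the reader rebuilds the model from that row. This is only a sketch under those assumptions: it supports model.save(path) / MeanFillModel.load(path), but, as far as I can tell from the Spark source, loading the stage as part of a saved PipelineModel also expects the JSON metadata that built-in stages write through the private[ml] DefaultParamsWriter, which this sketch does not produce.

```scala
import org.apache.spark.ml.Model
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Hypothetical model whose fitted state (mean) is a constructor argument.
class MeanFillModel(override val uid: String, val mean: Double)
    extends Model[MeanFillModel] with MLWritable {
  final val inputCol: Param[String] = new Param[String](this, "inputCol", "numeric column to fill")
  def setInputCol(value: String): this.type = set(inputCol, value)
  def getInputCol: String = $(inputCol)

  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.na.fill(mean, Seq($(inputCol)))
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): MeanFillModel =
    copyValues(new MeanFillModel(uid, mean), extra).setParent(parent)

  // Persistence: delegate to the custom writer defined in the companion object.
  override def write: MLWriter = new MeanFillModel.Writer(this)
}

// Top-level companion object; MLReadable.load(path) calls read.load(path).
object MeanFillModel extends MLReadable[MeanFillModel] {
  override def read: MLReader[MeanFillModel] = new Reader

  // Stores uid, inputCol and mean as a one-row Parquet file under <path>/data.
  class Writer(instance: MeanFillModel) extends MLWriter {
    override protected def saveImpl(path: String): Unit = {
      val data = Seq((instance.uid, instance.getInputCol, instance.mean))
      sparkSession.createDataFrame(data).toDF("uid", "inputCol", "mean")
        .write.parquet(s"$path/data")
    }
  }

  // Rebuilds the model from the row written by Writer.
  class Reader extends MLReader[MeanFillModel] {
    override def load(path: String): MeanFillModel = {
      val row = sparkSession.read.parquet(s"$path/data").head()
      new MeanFillModel(row.getString(0), row.getDouble(2)).setInputCol(row.getString(1))
    }
  }
}
```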

If you're looking for an easy-to-comprehend example, take a look at CountVectorizer(Model) where:

Khedive answered 26/11, 2016 at 12:0 Comment(3)
Merely extending DefaultParamsReadable and DefaultParamsWritable gives java.lang.NoSuchMethodException: read(). Examples seem to use DefaultParamsWriter and DefaultParamsReader, but they are marked as private[ml] and can't be used in custom transformers. Am I missing something? - Winterize
@Winterize Doesn't sound like you do. By contract, to read / write a stage you need MLReadable[_] / MLWritable[_], which need an accompanying MLReader[_] / MLWriter[_] to satisfy their read and write methods. - Khedive
Do you know of a good example of this? Every example I can find involves DefaultParamsReader and DefaultParamsWriter, both of which are private[ml], making them inaccessible to me. - Winterize
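One frequent cause of java.lang.NoSuchMethodException: read() (an assumption about the setup in these comments, not something they confirm): when a saved Pipeline is loaded, Spark looks up a static read() method on each stage's class via reflection. Scala only emits that static forwarder for a companion object defined at the top level of a source file, so a stage declared inside another object or class, or in a spark-shell / notebook cell, can fail this way. A minimal sketch with a hypothetical params-only ColumnDropper stage, using only the public DefaultParamsWritable / DefaultParamsReadable traits:

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, Transformer}
import org.apache.spark.ml.param.{ParamMap, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType

// Hypothetical params-only stage, defined at the top level of a .scala file.
class ColumnDropper(override val uid: String) extends Transformer with DefaultParamsWritable {
  def this() = this(Identifiable.randomUID("columnDropper"))

  final val dropCols: StringArrayParam = new StringArrayParam(this, "dropCols", "columns to drop")
  def setDropCols(value: Array[String]): this.type = set(dropCols, value)

  override def transform(dataset: Dataset[_]): DataFrame = dataset.drop($(dropCols): _*)
  override def transformSchema(schema: StructType): StructType =
    StructType(schema.filterNot(f => $(dropCols).contains(f.name)))
  override def copy(extra: ParamMap): ColumnDropper = defaultCopy(extra)
}

// Top-level companion object: Scala emits a static ColumnDropper.read() forwarder,
// the method the Pipeline loader invokes reflectively.
object ColumnDropper extends DefaultParamsReadable[ColumnDropper]

object RoundTrip {
  // Saves a fitted Pipeline containing ColumnDropper, loads it back, returns the output columns.
  def run(spark: SparkSession, path: String): Array[String] = {
    import spark.implicits._
    val df = Seq((1, "scratch", 2.0)).toDF("id", "tmp", "x")
    val stage = new ColumnDropper().setDropCols(Array("tmp"))
    new Pipeline().setStages(Array(stage)).fit(df).write.overwrite().save(path)
    PipelineModel.load(path).transform(df).columns
  }
}
```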

The following uses the Scala API, but you can easily refactor it to Python if you really want to...

First things first:

  • Estimator: implements .fit(), which returns a Transformer
  • Transformer: implements .transform() and manipulates the DataFrame
  • Serialization/Deserialization: do your best to use built-in Params and leverage the simple DefaultParamsWritable trait plus a companion object extending DefaultParamsReadable[T]. In other words, stay away from MLReader / MLWriter and keep your code simple.
  • Parameter passing: use a common trait extending Params and share it between your Estimator and Model (a.k.a. Transformer)

Skeleton code:

import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap, Params, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{avg, col}
import org.apache.spark.sql.types.StructType

// Common Parameters
trait MyCommonParams extends Params {
  final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
    new StringArrayParam(this, "inputCols", "doc...")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
  def getInputCols: Array[String] = $(inputCols)

  final val meanValues: DoubleArrayParam =
    new DoubleArrayParam(this, "meanValues", "doc...")
  def setMeanValues(value: Array[Double]): this.type = set(meanValues, value)
  def getMeanValues: Array[Double] = $(meanValues)
}

// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel]
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  def this() = this(Identifiable.randomUID("myMeanValueStuff"))

  override def copy(extra: ParamMap): MyMeanValueStuff = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
    // one global aggregation computes the mean of every input column (avg ignores nulls)
    val row = dataset.select($(inputCols).map(c => avg(col(c))): _*).head()
    val means = $(inputCols).indices.map(i => row.getDouble(i)).toArray
    // the fitted means become a Param of the model, so DefaultParamsWritable persists them
    copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this)).setMeanValues(means)
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]

// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel]
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  override def copy(extra: ParamMap): MyMeanValueStuffModel =
    defaultCopy[MyMeanValueStuffModel](extra).setParent(parent) // default copy, keeping the parent
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def transform(dataset: Dataset[_]): DataFrame = {
    // zip inputCols and meanValues, toMap, replace nulls with the NA functions
    // you have access to both inputCols and meanValues here!
    dataset.na.fill($(inputCols).zip($(meanValues)).toMap)
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]

With the code above you can serialize/deserialize a Pipeline containing a MyMeanValueStuff stage.

Want to look at a really simple implementation of an Estimator? MinMaxScaler! (My example is actually simpler though...)

Toaster answered 31/5, 2017 at 3:22 Comment(1)
Working off your example, I get java.lang.NoSuchMethodException: read() when I try to load the model. What am I not understanding in your example? - Winterize
