The following uses the Scala API, but you can easily port it to Python if you really want to.
First things first:
- Estimator: implements .fit(), which returns a Transformer
- Transformer: implements .transform() and manipulates the DataFrame
- Serialization/Deserialization: Do your best to use built-in Params and leverage the simple DefaultParamsWritable trait plus a companion object extending DefaultParamsReadable[T]. In other words, stay away from MLReader / MLWriter and keep your code simple.
- Parameter passing: Use a common trait extending Params and share it between your Estimator and Model (a.k.a. Transformer)
Skeleton code:
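import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap, Params, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Common Parameters
trait MyCommonParams extends Params {
  // usage: new MyMeanValueStuff().setInputCols(...)
  final val inputCols: StringArrayParam =
    new StringArrayParam(this, "inputCols", "doc...")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
  def getInputCols: Array[String] = $(inputCols)
  final val meanValues: DoubleArrayParam =
    new DoubleArrayParam(this, "meanValues", "doc...")
  def setMeanValues(value: Array[Double]): this.type = set(meanValues, value)
  def getMeanValues: Array[Double] = $(meanValues)
  // more setters and getters
}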
// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel]
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {
  // No-arg constructor so that new MyMeanValueStuff() works
  def this() = this(Identifiable.randomUID("myMeanValueStuff"))
  override def copy(extra: ParamMap): Estimator[MyMeanValueStuffModel] = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
    // your logic here. I can't do all the work for you! ;)
    val computedMeans: Array[Double] = ??? // placeholder: one mean per input column
    setMeanValues(computedMeans)
    copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]
// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel]
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {
  override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def transform(dataset: Dataset[_]): DataFrame = {
    // your logic here: zip inputCols and meanValues, toMap, replace nulls with NA functions
    // you have access to both inputCols and meanValues here!
    ??? // placeholder so the skeleton compiles; replace it with your logic
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]
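If you want a starting point for the two "your logic here" spots, here is a minimal sketch of what they might look like. The object name MeanValueLogic and the helpers computeMeans and fillWithMeans are made up for illustration (they are not part of Spark). The sketch assumes every input column is numeric and contains at least one non-null value; otherwise avg yields null and getDouble will throw.

```scala
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{avg, col}

// Hypothetical helper object, not part of Spark.
object MeanValueLogic {

  // Candidate body for fit(): one aggregate query that returns avg(col) for each input column.
  // Assumes numeric columns with at least one non-null value each.
  def computeMeans(dataset: Dataset[_], inputCols: Array[String]): Array[Double] = {
    val row = dataset.select(inputCols.map(c => avg(col(c))): _*).head()
    inputCols.indices.map(i => row.getDouble(i)).toArray
  }

  // Candidate body for transform(): zip column names with their means,
  // turn the pairs into a Map, and let the NA functions replace the nulls.
  def fillWithMeans(
      dataset: Dataset[_],
      inputCols: Array[String],
      meanValues: Array[Double]): DataFrame =
    dataset.na.fill(inputCols.zip(meanValues).toMap)
}
```

Inside the skeleton, fit could then call setMeanValues(MeanValueLogic.computeMeans(dataset, $(inputCols))) and transform could return MeanValueLogic.fillWithMeans(dataset, $(inputCols), $(meanValues)).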
With the code above you can Serialize/Deserialize a Pipeline containing a MyMeanValueStuff stage.
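A rough sketch of such a round trip, assuming you pass in the stage, a training DataFrame, and a writable path yourself. PipelineRoundTrip and fitSaveLoad are hypothetical names used only for this illustration:

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.sql.DataFrame

// Hypothetical helper, not part of Spark.
object PipelineRoundTrip {

  // Fits a single-stage pipeline, saves the fitted model to `path`, and loads it back.
  // Saving fails with an exception if any stage is not MLWritable, which is
  // exactly what mixing in DefaultParamsWritable takes care of.
  def fitSaveLoad(stage: PipelineStage, training: DataFrame, path: String): PipelineModel = {
    val pipeline = new Pipeline().setStages(Array(stage))
    val fitted = pipeline.fit(training)
    fitted.write.overwrite().save(path)
    PipelineModel.load(path)
  }
}
```

With the skeleton filled in, a call might look like PipelineRoundTrip.fitSaveLoad(new MyMeanValueStuff().setInputCols(Array("x")), df, path), where df and path are whatever you have at hand. Keep the custom classes top-level (not nested inside another object or class), since loading looks them up by class name and calls read on the companion object.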
Want to look at a real, simple implementation of an Estimator? Check out MinMaxScaler! (My example is actually simpler, though...)