Is there any means to serialize a custom Transformer in a Spark ML Pipeline?

I use an ML pipeline with various custom UDF-based transformers. What I'm looking for is a way to serialize/deserialize this pipeline.

I serialize the PipelineModel using

ObjectOutputStream.write() 

However, whenever I try to deserialize the pipeline I get:

java.lang.ClassNotFoundException: org.sparkexample.DateTransformer

where DateTransformer is my custom transformer. Is there any method/interface to implement for proper serialization?
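
For what it's worth, the save/load is plain Java serialization, roughly along these lines (the helper names and paths below are just illustrative):

import java.io._
import org.apache.spark.ml.PipelineModel

// plain Java serialization of the fitted pipeline (illustrative helpers)
def savePipeline(model: PipelineModel, path: String): Unit = {
  val oos = new ObjectOutputStream(new FileOutputStream(path))
  try oos.writeObject(model) finally oos.close()
}

def loadPipeline(path: String): PipelineModel = {
  val ois = new ObjectInputStream(new FileInputStream(path))
  // the ClassNotFoundException above is thrown here when the custom transformer's
  // class is not visible to the deserializing JVM/classloader
  try ois.readObject().asInstanceOf[PipelineModel] finally ois.close()
}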

I've found out there is

MLWritable

interface that might be implemented by my class (DateTransformer extends Transformer); however, I can't find a useful example of it.

Clastic answered 27/10, 2016 at 12:8

If you are using Spark 2.x+, then extend your transformer with DefaultParamsWritable,

for example:

class ProbabilityMaxer extends Transformer with DefaultParamsWritable {

Then create a constructor with a string parameter:

 def this(_uid: String) {
    this()
  }

Finally, for a successful read, add a companion object:

object ProbabilityMaxer extends DefaultParamsReadable[ProbabilityMaxer]
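
Putting the three pieces together, a minimal sketch could look like the one below. The column params, the uid prefix and the UDF body are made up for illustration, and here the uid is taken by the primary constructor rather than discarded by an auxiliary one, which is the usual Spark pattern. (Whether DefaultParamsWritable/DefaultParamsReadable are public can depend on the exact Spark version.)

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

class ProbabilityMaxer(override val uid: String)
  extends Transformer with DefaultParamsWritable {

  // the no-arg constructor just generates a uid; the String constructor is the
  // primary one, which is what DefaultParamsReader instantiates on load
  def this() = this(Identifiable.randomUID("probabilityMaxer"))

  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: Dataset[_]): DataFrame = {
    // placeholder logic: clamp the input probability at 0.0
    val clamp = udf { p: Double => math.max(p, 0.0) }
    dataset.withColumn($(outputCol), clamp(col($(inputCol))))
  }

  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields :+ StructField($(outputCol), DoubleType, nullable = false))

  override def copy(extra: ParamMap): ProbabilityMaxer = defaultCopy(extra)
}

// companion object, needed so the stage can be found and read back on load
object ProbabilityMaxer extends DefaultParamsReadable[ProbabilityMaxer]

Saving and loading then go through the normal ML persistence path, e.g. new ProbabilityMaxer().setInputCol("p").setOutputCol("pMax").save(path) followed by ProbabilityMaxer.load(path), and the same should work when the transformer sits inside a Pipeline.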

I have this working on my production server. I will add a GitLab link to the project later when I upload it.

Ileanaileane answered 9/7, 2017 at 8:11

The short answer is you can't, at least not easily.

The devs have gone out of their way to make adding a new transformer/estimator as difficult as possible. Basically everything in org.apache.spark.ml.util.ReadWrite is private (except for MLWritable and MLReadable), so there is no way to use any of the utility methods/classes/objects there. There is also (as I'm sure you've already discovered) absolutely no documentation on how this should be done, but hey, good code documents itself, right?!

Digging through the code in org.apache.spark.ml.util.ReadWrite and org.apache.spark.ml.feature.HashingTF, it seems that you need to override MLWritable.write and MLReadable.read. The DefaultParamsWriter and DefaultParamsReader, which seem to contain the actual save/load implementations, save and load a bunch of metadata:

  • class
  • timestamp
  • sparkVersion
  • uid
  • paramMap
  • (optionally, extra metadata)

so any implementation would at least need to cover those, and a transformer that doesn't need to learn any model would probably get away with just that. A model that needs to be fitted also needs to save its own data in its implementation of save/write. For instance, this is what the LocalLDAModel does (https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala#L523), so the learned model is just saved as a Parquet file (it seems):

val data = sqlContext.read.parquet(dataPath)
        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
          "gammaShape")
        .head()

As a test, I copied everything from org.apache.spark.ml.util.ReadWrite that seemed to be needed and tested the following transformer, which does not do anything useful.

WARNING: this is almost certainly the wrong thing to do and is very likely to break in the future. I sincerely hope I've misunderstood something and that someone is going to correct me on how to actually create a transformer that can be serialised/deserialised.

This is for Spark 1.6.3 and may already be broken if you're using 2.x.

import org.apache.spark.sql.types._
import org.apache.spark.ml.param._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{SQLContext, DataFrame}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.mllib.linalg._

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object CustomTransform extends DefaultParamsReadable[CustomTransform] {
  /* Companion object for deserialisation */
  override def load(path: String): CustomTransform = super.load(path)
}

class CustomTransform(override val uid: String)
  extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("customThing"))

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol(): String = getOrDefault(outputCol)

  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")

  override def transform(dataset: DataFrame): DataFrame = {
    val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
    import sqlContext.implicits._

    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")
    val transformUDF = udf({ vector: SparseVector =>
      // WHATEVER YOUR TRANSFORMER NEEDS TO DO GOES HERE
      vector.values.map( _ * 10 )
    })

    dataset.withColumn(outCol, transformUDF(col(inCol)))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    // declare the output column as what the UDF above actually produces (an Array[Double]),
    // and use the same default output column name as transform
    val outCol = extractParamMap.getOrElse(outputCol, "output")
    StructType(schema.fields :+ StructField(outCol, ArrayType(DoubleType, containsNull = false), nullable = false))
  }
}

Then we need all the utilities from org.apache.spark.ml.util.ReadWrite (https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala):

trait DefaultParamsWritable extends MLWritable { self: Params =>
  override def write: MLWriter = new DefaultParamsWriter(this)
}

trait DefaultParamsReadable[T] extends MLReadable[T] {
  override def read: MLReader[T] = new DefaultParamsReader
}

class DefaultParamsWriter(instance: Params) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}

object DefaultParamsWriter {

  /**
    * Saves metadata + Params to: path + "/metadata"
    *  - class
    *  - timestamp
    *  - sparkVersion
    *  - uid
    *  - paramMap
    *  - (optionally, extra metadata)
    * @param extraMetadata  Extra metadata to be saved at same level as uid, paramMap, etc.
    * @param paramMap  If given, this is saved in the "paramMap" field.
    *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
    *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
    */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = compact(render(metadata))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }
}

class DefaultParamsReader[T] extends MLReader[T] {
  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true, Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    DefaultParamsReader.getAndSetParams(instance, metadata)
    instance.asInstanceOf[T]
  }
}

object DefaultParamsReader {

  /**
    * All info from metadata file.
    *
    * @param params       paramMap, as a [[JValue]]
    * @param metadata     All metadata, including the other fields
    * @param metadataJson Full metadata file String (for debugging)
    */
  case class Metadata(
                       className: String,
                       uid: String,
                       timestamp: Long,
                       sparkVersion: String,
                       params: JValue,
                       metadata: JValue,
                       metadataJson: String)

  /**
    * Load metadata from file.
    *
    * @param expectedClassName If non empty, this is checked against the loaded metadata.
    * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
    */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }

  /**
    * Extract Params from metadata, and set them in the instance.
    * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
    */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /**
    * Load a [[Params]] instance from the given path, and return it.
    * This assumes the instance implements [[MLReadable]].
    */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true, Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}

With that in place you can use CustomTransform in a Pipeline and save/load the pipeline. I tested that fairly quickly in the spark shell and it seems to work, but it certainly isn't pretty.
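
For reference, a quick round trip along those lines could look like this in the spark shell (the paths and column names are invented, and it assumes CustomTransform and the copied utilities above are already defined):

import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.mllib.linalg.Vectors

// a tiny DataFrame with a vector column for the custom transformer to work on
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))),
  (1, Vectors.sparse(3, Array(1), Array(2.0)))
)).toDF("id", "input")

val custom = new CustomTransform()
  .setInputCol("input")
  .setOutputCol("output")

val pipeline = new Pipeline().setStages(Array[PipelineStage](custom))
val model = pipeline.fit(df)

// save the fitted pipeline, load it back, and apply the reloaded copy
model.write.overwrite().save("/tmp/custom-pipeline")
val reloaded = PipelineModel.load("/tmp/custom-pipeline")
reloaded.transform(df).show()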

Taliped answered 12/3, 2017 at 12:24
