The short answer is you can't, at least not easily.
The devs have gone out of their way to make adding a new transformer/estimator as difficult as possible. Basically everything in org.apache.spark.ml.util.ReadWrite is private (except for MLWritable and MLReadable), so there is no way to reuse any of the utility methods/classes/objects there. There is also (as I'm sure you've already discovered) absolutely no documentation on how this should be done, but hey, good code documents itself, right?!
Digging through the code in org.apache.spark.ml.util.ReadWrite and org.apache.spark.ml.feature.HashingTF, it seems that you need to override MLWritable.write and MLReadable.read. The DefaultParamsWriter and DefaultParamsReader, which seem to contain the actual save/load implementations, save and load a bunch of metadata:
- class
- timestamp
- sparkVersion
- uid
- paramMap
- (optionally, extra metadata)
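For what it's worth, that metadata is written (via saveAsTextFile) as a single compact JSON line under path + "/metadata". Pretty-printed, and with made-up uid/timestamp values (class is whatever instance.getClass.getName returns, and the inputCol/outputCol params are the ones from the transformer below), it would look roughly like this:

{
  "class": "CustomTransform",
  "timestamp": 1484042583000,
  "sparkVersion": "1.6.3",
  "uid": "customThing_4be28c7a91d2",
  "paramMap": {
    "inputCol": "input",
    "outputCol": "output"
  }
}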
Any implementation would at least need to cover those fields, and a transformer that doesn't need to learn any model can probably get away with just that. A model that needs to be fitted also has to save its learned data in its implementation of save/write - for instance that's what LocalLDAModel does (https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala#L523): the learned model is just saved as a parquet file (it seems) and read back with
val data = sqlContext.read.parquet(dataPath)
.select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
"gammaShape")
.head()
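The corresponding write side follows the same pattern: save the param metadata, then dump whatever the model learned as parquet under path + "/data" for the reader to pick up. Here's a rough sketch of that pattern (not the actual LocalLDAModel code; MyModel and its coefficients field are made up, and DefaultParamsWriter is the copied helper shown further down):

class MyModelWriter(instance: MyModel) extends MLWriter {

  // hypothetical container for whatever the model learned
  private case class Data(coefficients: Vector)

  override protected def saveImpl(path: String): Unit = {
    // same metadata as any other transformer/estimator
    DefaultParamsWriter.saveMetadata(instance, path, sc)
    // plus the learned state, written as parquet under <path>/data
    val data = Data(instance.coefficients)
    val dataPath = new Path(path, "data").toString
    sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
  }
}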
As a test I copied everything from org.apache.spark.ml.util.ReadWrite that seemed to be needed and tested the following transformer, which does not do anything useful.

WARNING: this is almost certainly the wrong thing to do and is very likely to break in the future. I sincerely hope I've misunderstood something and someone is going to correct me on how to actually create a transformer that can be serialised/deserialised.

This is for Spark 1.6.3 and may already be broken if you're using 2.x.
import org.apache.spark.sql.types._
import org.apache.spark.ml.param._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{SQLContext, DataFrame}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.mllib.linalg._
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
object CustomTransform extends DefaultParamsReadable[CustomTransform] {
/* Companion object for deserialisation */
override def load(path: String): CustomTransform = super.load(path)
}
class CustomTransform(override val uid: String)
extends Transformer with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("customThing"))
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
def getOutputCol(): String = getOrDefault(outputCol)
val inputCol = new Param[String](this, "inputCol", "input column")
val outputCol = new Param[String](this, "outputCol", "output column")
override def transform(dataset: DataFrame): DataFrame = {
val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
import sqlContext.implicits._
val outCol = extractParamMap.getOrElse(outputCol, "output")
val inCol = extractParamMap.getOrElse(inputCol, "input")
    val transformUDF = udf({ vector: SparseVector =>
      // WHATEVER YOUR TRANSFORMER NEEDS TO DO GOES HERE; this toy example just
      // scales the stored values and returns a vector, matching transformSchema
      Vectors.sparse(vector.size, vector.indices, vector.values.map(_ * 10))
    })
dataset.withColumn(outCol, transformUDF(col(inCol)))
}
override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
override def transformSchema(schema: StructType): StructType = {
    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val outputFields = schema.fields :+ StructField(outCol, new VectorUDT, nullable = false)
    StructType(outputFields)
}
}
Then we need all the utilities from org.apache.spark.ml.util.ReadWrite
https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
trait DefaultParamsWritable extends MLWritable { self: Params =>
override def write: MLWriter = new DefaultParamsWriter(this)
}
trait DefaultParamsReadable[T] extends MLReadable[T] {
override def read: MLReader[T] = new DefaultParamsReader
}
class DefaultParamsWriter(instance: Params) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
}
}
object DefaultParamsWriter {
/**
* Saves metadata + Params to: path + "/metadata"
* - class
* - timestamp
* - sparkVersion
* - uid
* - paramMap
* - (optionally, extra metadata)
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
* [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))
val basicMetadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
val metadata = extraMetadata match {
case Some(jObject) =>
basicMetadata ~ jObject
case None =>
basicMetadata
}
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
}
}
class DefaultParamsReader[T] extends MLReader[T] {
override def load(path: String): T = {
val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(
      metadata.className,
      true,
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
DefaultParamsReader.getAndSetParams(instance, metadata)
instance.asInstanceOf[T]
}
}
object DefaultParamsReader {
/**
* All info from metadata file.
*
* @param params paramMap, as a [[JValue]]
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
*/
case class Metadata(
className: String,
uid: String,
timestamp: Long,
sparkVersion: String,
params: JValue,
metadata: JValue,
metadataJson: String)
/**
* Load metadata from file.
*
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
val metadata = parse(metadataStr)
implicit val format = DefaultFormats
val className = (metadata \ "class").extract[String]
val uid = (metadata \ "uid").extract[String]
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
}
/**
* Extract Params from metadata, and set them in the instance.
* This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
*/
def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
implicit val format = DefaultFormats
metadata.params match {
case JObject(pairs) =>
pairs.foreach { case (paramName, jsonValue) =>
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}
/**
* Load a [[Params]] instance from the given path, and return it.
* This assumes the instance implements [[MLReadable]].
*/
def loadParamsInstance[T](path: String, sc: SparkContext): T = {
val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(
      metadata.className,
      true,
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}
}
With that in place you can use CustomTransform in a Pipeline and save/load the pipeline. I tested that fairly quickly in the spark shell and it seems to work, but it certainly isn't pretty.
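Something like the following is enough to check the round trip in the shell (the DataFrame df, the column names and the save path here are just placeholders):

import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}

// assumes `df` is a DataFrame with a SparseVector column called "features"
val ct = new CustomTransform().setInputCol("features").setOutputCol("scaled")
val pipeline = new Pipeline().setStages(Array[PipelineStage](ct))

val model = pipeline.fit(df)  // nothing to fit, CustomTransform is just a Transformer
model.write.overwrite().save("/tmp/custom-pipeline")

val reloaded = PipelineModel.load("/tmp/custom-pipeline")
reloaded.transform(df).show()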