Spark load model and continue training

I'm using Scala with Spark 2.0 to train a model with LinearRegression.

import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setMaxIter(num_iter)
  .setRegParam(reg)
  .setStandardization(true)

val model = lr.fit(data)
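
A minimal self-contained sketch of the same training step, assuming Spark 2.x on the classpath and a local SparkSession. The object name `FitSketch`, the toy data (label = 2 * x + 1) and the fixed values standing in for `num_iter` and `reg` are hypothetical. It shows that `fit` returns a separate `LinearRegressionModel` holding the learned coefficients and the training summary, while `lr` itself only holds parameters.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object FitSketch {
  // Hypothetical toy data where label = 2 * x + 1
  def toyData(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq(
      (1.0, Vectors.dense(0.0)),
      (3.0, Vectors.dense(1.0)),
      (5.0, Vectors.dense(2.0)),
      (7.0, Vectors.dense(3.0))
    ).toDF("label", "features")
  }

  def fit(spark: SparkSession): LinearRegressionModel = {
    val lr = new LinearRegression()
      .setMaxIter(10)    // hypothetical stand-in for num_iter
      .setRegParam(0.0)  // hypothetical stand-in for reg
      .setStandardization(true)
    // fit() builds a brand-new model; the estimator `lr` keeps no weights
    lr.fit(toyData(spark))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("fit-sketch").getOrCreate()
    val model = fit(spark)
    println(s"coefficients=${model.coefficients} intercept=${model.intercept}")
    println(s"training RMSE=${model.summary.rootMeanSquaredError}")
    spark.stop()
  }
}
```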

This works fine and I get good results. I saved the model and loaded it in another class to make some predictions:

import org.apache.spark.ml.regression.LinearRegressionModel

val model = LinearRegressionModel.load("models/LRModel")
val result = model.transform(data).select("prediction")
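
Saving and loading the fitted model can be checked in isolation with a sketch like the following, assuming Spark 2.x and a writable local temp directory in place of `models/LRModel`; `SaveLoadSketch` and its data are hypothetical. The reloaded model should produce the same predictions as the original, because the saved directory contains the coefficients and intercept.

```scala
import java.nio.file.Files
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.SparkSession

object SaveLoadSketch {
  // Returns (predictions before saving, predictions after reloading)
  def roundTrip(spark: SparkSession): (Array[Double], Array[Double]) = {
    import spark.implicits._
    // Hypothetical training data
    val data = Seq(
      (1.0, Vectors.dense(0.0)),
      (3.0, Vectors.dense(1.0)),
      (5.0, Vectors.dense(2.0))
    ).toDF("label", "features")

    val model = new LinearRegression().fit(data)

    // Hypothetical temp directory standing in for "models/LRModel"
    val path = Files.createTempDirectory("lr-model").resolve("LRModel").toString
    model.save(path)
    val loaded = LinearRegressionModel.load(path)

    def preds(m: LinearRegressionModel): Array[Double] =
      m.transform(data).select("prediction").as[Double].collect()

    (preds(model), preds(loaded))
  }
}
```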

Now I want to continue training the model with new data, so I saved the model and the LinearRegression object and loaded them again to continue the training.

Saving:

model.save("models/LRModel")
lr.save("models/LR")

Loading:

val lr = LinearRegression.load("models/LR")
val model = LinearRegressionModel.load("models/LRModel")
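
To see what each of the two save calls persists, here is a sketch, assuming Spark 2.x and a temp directory in place of `models/LR`; `EstimatorParamsSketch` and its data are hypothetical. The reloaded `LinearRegression` comes back with its parameters (such as `maxIter` and `regParam`) but no coefficients, so fitting it again on the same data starts from scratch and yields the same model as the first fit.

```scala
import java.nio.file.Files
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.SparkSession

object EstimatorParamsSketch {
  // Returns (reloaded estimator, model from original estimator, model from reloaded estimator)
  def run(spark: SparkSession): (LinearRegression, LinearRegressionModel, LinearRegressionModel) = {
    import spark.implicits._
    // Hypothetical training data
    val data = Seq(
      (1.0, Vectors.dense(0.0)),
      (3.1, Vectors.dense(1.0)),
      (4.9, Vectors.dense(2.0)),
      (7.0, Vectors.dense(3.0))
    ).toDF("label", "features")

    val lr = new LinearRegression().setMaxIter(5).setRegParam(0.1).setStandardization(true)

    // Hypothetical temp directory standing in for "models/LR"
    val path = Files.createTempDirectory("lr-estimator").resolve("LR").toString
    lr.save(path)                               // persists params only
    val reloaded = LinearRegression.load(path)  // an unfitted estimator again

    // Both fits start from scratch on the same data, so they give the same model
    (reloaded, lr.fit(data), reloaded.fit(data))
  }
}
```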

The problem is that the loaded model has no fit or train method to continue the training. When I load the LinearRegression object, it seems to save only the parameters for the algorithm, not the weights. I tested this by training on the same data for the same number of iterations: the result was exactly the same rootMeanSquaredError, even though the model was definitely not converged at that point of learning. I also cannot load the model into the LinearRegression; it results in an error:

Exception in thread "main" java.lang.NoSuchMethodException: org.apache.spark.ml.regression.LinearRegressionModel.<init>(java.lang.String)
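
Judging by the message, the exception comes from handing a model directory to the estimator's loader: the directory records `LinearRegressionModel` as the class that wrote it, and that class has no constructor taking only a `String`. A sketch of the failing and the working call side by side, assuming Spark 2.x; `LoaderPairingSketch` and its data are hypothetical, and `scala.util.Try` is used only so both outcomes can be inspected.

```scala
import java.nio.file.Files
import scala.util.Try
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.SparkSession

object LoaderPairingSketch {
  def run(spark: SparkSession): (Try[LinearRegression], Try[LinearRegressionModel]) = {
    import spark.implicits._
    // Hypothetical training data
    val data = Seq(
      (1.0, Vectors.dense(0.0)),
      (3.0, Vectors.dense(1.0)),
      (5.0, Vectors.dense(2.0))
    ).toDF("label", "features")

    val modelPath = Files.createTempDirectory("lr-pairing").resolve("LRModel").toString
    new LinearRegression().fit(data).save(modelPath)

    // Wrong loader: the estimator's companion reading a model directory
    val wrong = Try(LinearRegression.load(modelPath))
    // Right loader: the companion of the class that wrote the directory
    val right = Try(LinearRegressionModel.load(modelPath))
    (wrong, right)
  }
}
```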

So the question is, how do I get the LinearRegression object to use the saved LinearRegressionModel?
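
What is being asked for is a warm start: resuming optimization from the saved coefficients instead of from the initial point. The plain-Scala sketch below (no Spark involved) illustrates the difference with a hypothetical one-feature model trained by batch gradient descent; `WarmStartSketch`, the data, and the step size are made up for illustration, and this is not how Spark's solvers work internally. Refitting from the initial point reproduces exactly the same weights, which matches the identical rootMeanSquaredError described in the question, while starting from the saved weights continues where the first run stopped.

```scala
object WarmStartSketch {
  // Hypothetical 1-D data with label = 2 * x + 1
  val xs: Array[Double] = Array(0.0, 1.0, 2.0, 3.0)
  val ys: Array[Double] = xs.map(x => 2.0 * x + 1.0)

  def rmse(w: Double, b: Double): Double =
    math.sqrt(xs.zip(ys).map { case (x, y) => math.pow(w * x + b - y, 2) }.sum / xs.length)

  // Batch gradient descent on mean squared error, starting from (w0, b0)
  def train(w0: Double, b0: Double, iters: Int, stepSize: Double = 0.05): (Double, Double) = {
    var w = w0
    var b = b0
    for (_ <- 1 to iters) {
      val errs = xs.zip(ys).map { case (x, y) => w * x + b - y }
      val gw = 2.0 * errs.zip(xs).map { case (e, x) => e * x }.sum / xs.length
      val gb = 2.0 * errs.sum / xs.length
      w -= stepSize * gw
      b -= stepSize * gb
    }
    (w, b)
  }
}
```

For example, running `train(0.0, 0.0, 20)` and then calling `train` again from its result for another 20 iterations ends at the same weights as a single `train(0.0, 0.0, 40)`, whereas calling `train(0.0, 0.0, 20)` twice just repeats the first run.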

Langan answered 1/9, 2016 at 13:1

You can use a Pipeline to save and load machine learning models.

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setLabelCol("labels")
  .setFeaturesCol("features")
  .setMaxIter(10)
  .setRegParam(1.0)
  .setElasticNetParam(1.0)

val pipeline = new Pipeline().setStages(Array(lr))

// fit() returns a PipelineModel; save that fitted model, not the unfitted Pipeline
val model = pipeline.fit(trainingData)

model.write.overwrite().save("hdfs://.../spark/mllib/models/linearRegression")

val sameModel = PipelineModel.load("hdfs://...")

// `assembler` must be a DataFrame with "features" and "labels" columns
sameModel.transform(assembler).select("features", "labels", "prediction").show()
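
A `PipelineModel` wraps the same fitted `LinearRegressionModel`, so saving the pipeline persists the learned coefficients, but the loaded object still only offers `transform`. A sketch that saves the fitted pipeline and pulls the regression stage back out, assuming Spark 2.x, a local temp directory in place of the elided HDFS path, and hypothetical data; `PipelineStageSketch` is a made-up name.

```scala
import java.nio.file.Files
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.SparkSession

object PipelineStageSketch {
  // Returns (regression stage before saving, regression stage after reloading)
  def run(spark: SparkSession): (LinearRegressionModel, LinearRegressionModel) = {
    import spark.implicits._
    // Hypothetical training data using the answer's column names
    val trainingData = Seq(
      (1.0, Vectors.dense(0.0)),
      (3.0, Vectors.dense(1.0)),
      (5.0, Vectors.dense(2.0))
    ).toDF("labels", "features")

    val lr = new LinearRegression().setLabelCol("labels").setFeaturesCol("features")
    val model = new Pipeline().setStages(Array(lr)).fit(trainingData)

    // Hypothetical temp directory in place of the elided HDFS path
    val path = Files.createTempDirectory("lr-pipeline").resolve("linearRegression").toString
    model.write.overwrite().save(path)
    val sameModel = PipelineModel.load(path)

    // The only stage is the fitted LinearRegressionModel: it holds coefficients,
    // but like any Model it offers transform(), not fit()
    (model.stages(0).asInstanceOf[LinearRegressionModel],
      sameModel.stages(0).asInstanceOf[LinearRegressionModel])
  }
}
```
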
Nashua answered 1/8, 2018 at 1:10
Here, too, you are predicting on (transforming) new data instead of retraining (fitting) the model. – Lactate
