我在 Spark 中定义了一个新的自定义 UnaryTransformer(示例代码中的 cleanText)并在 Pipeline 中使用它。当我保存安装的管道并尝试读回它时,出现以下错误:
java.lang.NoSuchMethodException:test_job$cleanText.read()
当我保存并加载一元变压器时,它工作正常。
重现错误的示例代码(在 Spark 2.2 中测试):
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.DoubleParam
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.types._
import org.apache.spark.ml.{PipelineModel}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DataType, DataTypes}
import org.apache.spark.util.Utils
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.param._
object test_job {
class cleanText(override val uid: String) extends UnaryTransformer[String, String, cleanText] with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("cleantext"))
override protected def validateInputType(inputType: DataType): Unit =
require(inputType == StringType)
protected def createTransformFunc: String => String = {
val regex = "[^a-zA-Z0-9]".r
s => regex.replaceAllIn(s, m => " ")
}
protected def outputDataType: DataType = StringType
}
object cleanText extends DefaultParamsReadable[cleanText]
//{
// override def load(path: String): cleanText = super.load(path)
//}
def main(args: Array[String]) {
val sc: SparkContext = new SparkContext(new SparkConf().setAppName("test_job"))
val sqlc = SparkSession.builder.appName("test_job").getOrCreate()
import sqlc.implicits._
val cleaner = new cleanText()
cleaner.setInputCol("word").setOutputCol("r_clean")
val someDF = sc.parallelize(Seq(
(1, "sample text 1"),
(2, "sample text 2"),
(3, "sample text 3")
)).toDF("number", "word")
val pipeline = new Pipeline().setStages(Array(cleaner))
val pipeline_fitted = pipeline.fit(someDF)
pipeline_fitted.write.overwrite().save("/tmp/model/")
//Saving just the transformer
//cleaner.write.overwrite().save("/tmp/model/")
println("Pipeline saved")
val pl2 = PipelineModel.load("/tmp/model/")
//Loading just the transformer will work
//val cln = cleanText.load("/tmp/model/")
println("Pipeline loaded")
sqlc.stop()
}
}
None
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)