RandomForestClassifier
与许多其他 ML 算法一样,需要在标签列上设置特定的元数据,并且标签值是 [0, 1, 2 ..., #classes) 中表示为双精度的整数值。通常这是由上游处理的Transformers
like StringIndexer
。由于您手动转换标签,因此未设置元数据字段,并且分类器无法确认是否满足这些要求。
val df = Seq(
(0.0, Vectors.dense(1, 0, 0, 0)),
(1.0, Vectors.dense(0, 1, 0, 0)),
(2.0, Vectors.dense(0, 0, 1, 0)),
(2.0, Vectors.dense(0, 0, 0, 1))
).toDF("label", "features")
val rf = new RandomForestClassifier()
.setFeaturesCol("features")
.setNumTrees(5)
rf.setLabelCol("label").fit(df)
// java.lang.IllegalArgumentException: RandomForestClassifier was given input ...
您可以使用重新编码标签列StringIndexer
:
import org.apache.spark.ml.feature.StringIndexer
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("label_idx")
.fit(df)
rf.setLabelCol("label_idx").fit(indexer.transform(df))
or 手动设置所需的元数据:
val meta = NominalAttribute
.defaultAttr
.withName("label")
.withValues("0.0", "1.0", "2.0")
.toMetadata
rf.setLabelCol("label_meta").fit(
df.withColumn("label_meta", $"label".as("", meta))
)
Note:
使用创建的标签StringIndexer
取决于频率而不是值:
indexer.labels
// Array[String] = Array(2.0, 0.0, 1.0)
PySpark:
在 Python 中,元数据字段可以直接在模式上设置:
from pyspark.sql.types import StructField, DoubleType
StructField(
"label", DoubleType(), False,
{"ml_attr": {
"name": "label",
"type": "nominal",
"vals": ["0.0", "1.0", "2.0"]
}}
)