火花2.4+
SPARK-21088 CrossValidator、TrainValidationSplit 拟合时应收集所有模型- 添加了对收集子模型的支持。
默认情况下,此行为被禁用,但可以使用控制CollectSubModels
Param
(setCollectSubModels
).
valid = TrainValidationSplit(
estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=evaluator,
collectSubModels=True)
model = valid.fit(df)
model.subModels
火花
长话短说,您根本无法获得所有模型的参数,因为,类似于CrossValidator, TrainValidationSplitModel
只保留最好的模型。这些类是为半自动模型选择而不是探索或实验而设计的。
各个型号的参数是多少?
虽然您无法检索实际模型validationMetrics
对应输入Params
所以你应该能够简单地zip
both:
from typing import Dict, Tuple, List, Any
from pyspark.ml.param import Param
from pyspark.ml.tuning import TrainValidationSplitModel
EvalParam = List[Tuple[float, Dict[Param, Any]]]
def get_metrics_and_params(model: TrainValidationSplitModel) -> EvalParam:
return list(zip(model.validationMetrics, model.getEstimatorParamMaps()))
了解指标和参数之间的关系。
如果您需要更多信息,您应该使用管道Params。它将保留可用于进一步处理的所有模型:
models = pipeline.fit(df, params=paramGrid)
它将生成一个列表PipelineModels
对应于params
争论:
zip(models, params)