超參數微調 - 對抗乳腺癌
本教學課程說明如何使用 SynapseML 來識別所選分類器的最佳超參數組合,最終產生更精確且可靠的模型。 為了進行示範,我們將展示如何執行分散式隨機方格搜尋超參數微調,以建置模型來識別乳腺癌。
1 - 設定相依性
從匯入 Pandas 和設定 Spark 工作階段開始。
import pandas as pd
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
接下來,讀取資料,並將其分割成微調和測試集。
data = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/BreastCancer.parquet"
).cache()
tune, test = data.randomSplit([0.80, 0.20])
tune.limit(10).toPandas()
定義要使用的模型。
from synapse.ml.automl import TuneHyperparameters
from synapse.ml.train import TrainClassifier
from pyspark.ml.classification import (
LogisticRegression,
RandomForestClassifier,
GBTClassifier,
)
logReg = LogisticRegression()
randForest = RandomForestClassifier()
gbt = GBTClassifier()
smlmodels = [logReg, randForest, gbt]
mmlmodels = [TrainClassifier(model=model, labelCol="Label") for model in smlmodels]
2 - 使用 AutoML 尋找最佳模型
從 synapse.ml.automl
匯入 SynapseML 的 AutoML 類別。
使用 HyperparamBuilder
指定超參數。 新增 DiscreteHyperParam
或 RangeHyperParam
超參數。 TuneHyperparameters
會從均勻分佈中隨機選擇值:
from synapse.ml.automl import *
paramBuilder = (
HyperparamBuilder()
.addHyperparam(logReg, logReg.regParam, RangeHyperParam(0.1, 0.3))
.addHyperparam(randForest, randForest.numTrees, DiscreteHyperParam([5, 10]))
.addHyperparam(randForest, randForest.maxDepth, DiscreteHyperParam([3, 5]))
.addHyperparam(gbt, gbt.maxBins, RangeHyperParam(8, 16))
.addHyperparam(gbt, gbt.maxDepth, DiscreteHyperParam([3, 5]))
)
searchSpace = paramBuilder.build()
# The search space is a list of params to tuples of estimator and hyperparam
print(searchSpace)
randomSpace = RandomSpace(searchSpace)
接下來,執行 TuneHyperparameters 以取得最佳模型。
bestModel = TuneHyperparameters(
evaluationMetric="accuracy",
models=mmlmodels,
numFolds=2,
numRuns=len(mmlmodels) * 2,
parallelism=1,
paramSpace=randomSpace.space(),
seed=0,
).fit(tune)
3 - 評估模型
我們可以檢視最佳模型的參數,並擷取基礎最佳模型管線
print(bestModel.getBestModelInfo())
print(bestModel.getBestModel())
我們可以針對測試集評分並檢視計量。
from synapse.ml.train import ComputeModelStatistics
prediction = bestModel.transform(test)
metrics = ComputeModelStatistics().transform(prediction)
metrics.limit(10).toPandas()