Condividi tramite


Interpretare le stime del modello usando l'importanza della funzionalità permutazione

Impara come interpretare le previsioni dei modelli di Machine Learning di ML.NET usando l'importanza delle funzionalità di permutazione (PFI). PFI fornisce il contributo relativo di ogni funzionalità a una stima.

I modelli di Machine Learning vengono spesso considerati come caselle opache che accettano input e generano un output. I passaggi intermedi o le interazioni tra le funzionalità che influenzano l'output vengono raramente compresi. Poiché l'apprendimento automatico viene introdotto in più aspetti della vita quotidiana, ad esempio il settore sanitario, è di fondamentale importanza comprendere perché un modello di Machine Learning prende le decisioni che esegue. Ad esempio, se le diagnosi vengono effettuate da un modello di Machine Learning, i professionisti del settore sanitario hanno bisogno di un modo per esaminare i fattori che sono andati a fare tale diagnosi. Fornire la diagnosi corretta potrebbe fare una grande differenza sul fatto che un paziente abbia o meno un recupero rapido. Pertanto, maggiore è il livello di spiegabilità in un modello, maggiore è la fiducia dei professionisti sanitari nell'accettare o rifiutare le decisioni prese dal modello.

Vengono usate varie tecniche per spiegare i modelli, uno dei quali è PFI. PFI è una tecnica usata per spiegare i modelli di classificazione e regressione ispirati alla foreste casuali di Breiman (vedere la sezione 10). A livello generale, il modo in cui funziona consiste nel mescolamento casuale dei dati una caratteristica alla volta per l'intero set di dati e nel calcolare quanto diminuisce la metrica di prestazione di interesse. Maggiore è la modifica, più importante è la funzionalità.

Inoltre, evidenziando le funzionalità più importanti, i generatori di modelli possono concentrarsi sull'uso di un subset di caratteristiche più significative, che possono potenzialmente ridurre il rumore e il tempo di training.

Carica i dati

Le funzionalità nel set di dati usato per questo esempio si trovano nelle colonne 1-12. L'obiettivo è prevedere Price.

Colonna Caratteristica Descrizione
1 CrimeRate Tasso di criminalità pro capite
2 Zone residenziali Zone residenziali in città
3 Zone commerciali Zone nonresidenziali in città
4 NearWater Prossimità al corpo dell'acqua
5 Livelli di Rifiuti Tossici Livelli di tossicità (PPM)
6 NumeroMedioStanza Numero medio di camere in casa
7 HomeAge Età della casa
8 BusinessCenterDistance Distanza dal distretto commerciale più vicino
9 Accesso all'Autostrada Prossimità alle autostrade
10 Aliquota fiscale Imposta sulle proprietà
11 Rapporto Studenti-Insegnante Rapporto tra studenti e insegnanti
12 PercentualePopolazioneSottoLaSogliaDiPovertà Percentuale di popolazione che vive al di sotto della povertà
13 Prezzo Prezzo della casa

Di seguito è riportato un esempio del set di dati:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

I dati in questo esempio possono essere modellati da una classe come HousingPriceData e caricati in un IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Addestrare il modello

L'esempio di codice seguente illustra il processo di training di un modello di regressione lineare per stimare i prezzi delle abitazioni.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);

Spiegare il modello con Importanza delle caratteristiche per permutazione (PFI)

In ML.NET usare il metodo PermutationFeatureImportance per la rispettiva attività.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Il risultato dell'uso di PermutationFeatureImportance nel set di dati di addestramento è un ImmutableArray composto da RegressionMetricsStatistics oggetti. RegressionMetricsStatistics fornisce statistiche riepilogative come la media e la deviazione standard per numerose osservazioni di RegressionMetrics, che corrispondono al numero di permutazioni specificato dal parametro permutationCount.

La metrica usata per misurare l'importanza della funzionalità dipende dall'attività di Machine Learning usata per risolvere il problema. Ad esempio, le attività di regressione possono usare una metrica di valutazione comune, ad esempio R quadrato per misurare l'importanza. Per ulteriori informazioni sulle metriche di valutazione del modello, consultare valutare il modello di ML.NET con metriche.

L'importanza, o in questo caso, la diminuzione media assoluta della metrica R-quadrato, calcolata da PermutationFeatureImportance, può quindi essere ordinata dalla più importante alla meno importante.

// Order features by importance.
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

La stampa dei valori per ognuna delle funzionalità in featureImportanceMetrics genera un output simile all'output seguente. Dovresti aspettarti di vedere risultati diversi perché questi valori variano in base ai dati che ricevono.

Caratteristica Passare a R-Squared
HighwayAccess -0.042731
Rapporto studenti-insegnanti -0.012730
BusinessCenterDistance -0.010491
Aliquota Fiscale -0.008545
NumeroMedioCamere -0.003949
CrimeRate -0.003665
Zone commerciali 0.002749
HomeAge -0.002426
Zone residenziali -0.002319
Vicino all'acqua 0.000203
PercentualeDellaPopolazioneCheViveSottoLaSogliaDiPovertà 0.000031
Livelli di rifiuti tossici -0.000019

Se si esaminano le cinque funzionalità più importanti per questo set di dati, il prezzo di una casa stimata da questo modello è influenzato dalla vicinanza alle autostrade, dal rapporto degli studenti delle scuole nell'area, dalla prossimità ai principali centri di occupazione, dall'aliquota fiscale immobiliare e dal numero medio di camere nella casa.

Passaggi successivi