Interpretare le stime del modello usando l'importanza della funzionalità permutazione
Impara come interpretare le previsioni dei modelli di Machine Learning di ML.NET usando l'importanza delle funzionalità di permutazione (PFI). PFI fornisce il contributo relativo di ogni funzionalità a una stima.
I modelli di Machine Learning vengono spesso considerati come caselle opache che accettano input e generano un output. I passaggi intermedi o le interazioni tra le funzionalità che influenzano l'output vengono raramente compresi. Poiché l'apprendimento automatico viene introdotto in più aspetti della vita quotidiana, ad esempio il settore sanitario, è di fondamentale importanza comprendere perché un modello di Machine Learning prende le decisioni che esegue. Ad esempio, se le diagnosi vengono effettuate da un modello di Machine Learning, i professionisti del settore sanitario hanno bisogno di un modo per esaminare i fattori che sono andati a fare tale diagnosi. Fornire la diagnosi corretta potrebbe fare una grande differenza sul fatto che un paziente abbia o meno un recupero rapido. Pertanto, maggiore è il livello di spiegabilità in un modello, maggiore è la fiducia dei professionisti sanitari nell'accettare o rifiutare le decisioni prese dal modello.
Vengono usate varie tecniche per spiegare i modelli, uno dei quali è PFI. PFI è una tecnica usata per spiegare i modelli di classificazione e regressione ispirati
Inoltre, evidenziando le funzionalità più importanti, i generatori di modelli possono concentrarsi sull'uso di un subset di caratteristiche più significative, che possono potenzialmente ridurre il rumore e il tempo di training.
Carica i dati
Le funzionalità nel set di dati usato per questo esempio si trovano nelle colonne 1-12. L'obiettivo è prevedere Price
.
Colonna | Caratteristica | Descrizione |
---|---|---|
1 | CrimeRate | Tasso di criminalità pro capite |
2 | Zone residenziali | Zone residenziali in città |
3 | Zone commerciali | Zone nonresidenziali in città |
4 | NearWater | Prossimità al corpo dell'acqua |
5 | Livelli di Rifiuti Tossici | Livelli di tossicità (PPM) |
6 | NumeroMedioStanza | Numero medio di camere in casa |
7 | HomeAge | Età della casa |
8 | BusinessCenterDistance | Distanza dal distretto commerciale più vicino |
9 | Accesso all'Autostrada | Prossimità alle autostrade |
10 | Aliquota fiscale | Imposta sulle proprietà |
11 | Rapporto Studenti-Insegnante | Rapporto tra studenti e insegnanti |
12 | PercentualePopolazioneSottoLaSogliaDiPovertà | Percentuale di popolazione che vive al di sotto della povertà |
13 | Prezzo | Prezzo della casa |
Di seguito è riportato un esempio del set di dati:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
I dati in questo esempio possono essere modellati da una classe come HousingPriceData
e caricati in un IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Addestrare il modello
L'esempio di codice seguente illustra il processo di training di un modello di regressione lineare per stimare i prezzi delle abitazioni.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline.
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model.
var sdcaModel = sdcaEstimator.Fit(data);
Spiegare il modello con Importanza delle caratteristiche per permutazione (PFI)
In ML.NET usare il metodo PermutationFeatureImportance
per la rispettiva attività.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Il risultato dell'uso di PermutationFeatureImportance
nel set di dati di addestramento è un ImmutableArray
composto da RegressionMetricsStatistics
oggetti.
RegressionMetricsStatistics
fornisce statistiche riepilogative come la media e la deviazione standard per numerose osservazioni di RegressionMetrics
, che corrispondono al numero di permutazioni specificato dal parametro permutationCount
.
La metrica usata per misurare l'importanza della funzionalità dipende dall'attività di Machine Learning usata per risolvere il problema. Ad esempio, le attività di regressione possono usare una metrica di valutazione comune, ad esempio R quadrato per misurare l'importanza. Per ulteriori informazioni sulle metriche di valutazione del modello, consultare valutare il modello di ML.NET con metriche.
L'importanza, o in questo caso, la diminuzione media assoluta della metrica R-quadrato, calcolata da PermutationFeatureImportance
, può quindi essere ordinata dalla più importante alla meno importante.
// Order features by importance.
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
La stampa dei valori per ognuna delle funzionalità in featureImportanceMetrics
genera un output simile all'output seguente. Dovresti aspettarti di vedere risultati diversi perché questi valori variano in base ai dati che ricevono.
Caratteristica | Passare a R-Squared |
---|---|
HighwayAccess | -0.042731 |
Rapporto studenti-insegnanti | -0.012730 |
BusinessCenterDistance | -0.010491 |
Aliquota Fiscale | -0.008545 |
NumeroMedioCamere | -0.003949 |
CrimeRate | -0.003665 |
Zone commerciali | 0.002749 |
HomeAge | -0.002426 |
Zone residenziali | -0.002319 |
Vicino all'acqua | 0.000203 |
PercentualeDellaPopolazioneCheViveSottoLaSogliaDiPovertà | 0.000031 |
Livelli di rifiuti tossici | -0.000019 |
Se si esaminano le cinque funzionalità più importanti per questo set di dati, il prezzo di una casa stimata da questo modello è influenzato dalla vicinanza alle autostrade, dal rapporto degli studenti delle scuole nell'area, dalla prossimità ai principali centri di occupazione, dall'aliquota fiscale immobiliare e dal numero medio di camere nella casa.