PyTorch
Il progetto PyTorch è un pacchetto Python che fornisce il calcolo di tensor accelerato della GPU e funzionalità di alto livello per la creazione di reti di Deep Learning. Per informazioni dettagliate sulle licenze, vedere la documentazione sulla licenza di PyTorch su GitHub.
Per monitorare ed eseguire il debug dei modelli PyTorch, è consigliabile usare TensorBoard.
PyTorch è incluso in Databricks Runtime per Machine Learning. Se si usa Databricks Runtime, vedere Installare PyTorch per istruzioni sull'installazione di PyTorch.
Nota
Questa non è una guida completa a PyTorch. Per altre informazioni, vedere il sito Web PyTorch.
Training a nodo singolo e distribuito
Per testare ed eseguire la migrazione di flussi di lavoro a computer singolo, usare un cluster a nodo singolo.
Per le opzioni di training distribuite per il Deep Learning, vedere Training distribuito.
Notebook di esempio
Notebook PyTorch
Installare PyTorch
Databricks Runtime per ML
Databricks Runtime per Machine Learning include PyTorch per poter creare il cluster e iniziare a usare PyTorch. Per la versione di PyTorch installata nella versione di Databricks Runtime ML in uso, vedere le note sulla versione.
Databricks Runtime
Databricks consiglia di usare PyTorch incluso in Databricks Runtime per Machine Learning. Tuttavia, se è necessario usare il runtime di Databricks standard, PyTorch può essere installato come libreria PyPI di Databricks. Nell'esempio seguente viene illustrato come installare PyTorch 1.5.0:
Nei cluster GPU installare
pytorch
etorchvision
specificando quanto segue:torch==1.5.0
torchvision==0.6.0
Nei cluster CPU installare
pytorch
etorchvision
usando i file wheel Python seguenti:https://download.pytorch.org/whl/cpu/torch-1.5.0%2Bcpu-cp37-cp37m-linux_x86_64.whl https://download.pytorch.org/whl/cpu/torchvision-0.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
Errori e risoluzione dei problemi per PyTorch distribuito
Le sezioni seguenti descrivono i messaggi di errore comuni e le linee guida per la risoluzione dei problemi per le classi: PyTorch DataParallel o PyTorch DistributedDataParallel. La maggior parte di questi errori può essere probabilmente risolta con TorchDistributor, disponibile in Databricks Runtime ML 13.0 e versioni successive. Tuttavia, se TorchDistributor
non è una soluzione valida, le soluzioni consigliate vengono fornite anche all'interno di ogni sezione.
Di seguito è riportato un esempio di come usare TorchDistributor:
from pyspark.ml.torch.distributor import TorchDistributor
def train_fn(learning_rate):
# ...
num_processes=2
distributor = TorchDistributor(num_processes=num_processes, local_mode=True)
distributor.run(train_fn, 1e-3)
processo 0 terminato con codice di uscita 1
L'errore seguente può verificarsi quando si utilizzano notebook in Databricks o localmente:
process 0 terminated with exit code 1
Per evitare questo errore, usare torch.multiprocessing.start_processes
con start_method=fork
anziché torch.multiprocessing.spawn
.
Ad esempio:
import torch
def train_fn(rank, learning_rate):
# required setup, e.g. setup(rank)
# ...
num_processes = 2
torch.multiprocessing.start_processes(train_fn, args=(1e-3,), nprocs=num_processes, start_method="fork")
Il socket del server non è riuscito a eseguire l'associazione alla porta
L'errore seguente viene visualizzato quando si riavvia l'allenamento distribuito dopo aver interrotto la cella di codice durante l'allenamento:
The server socket has failed to bind to [::]:{PORT NUMBER} (errno: 98 - Address already in use).
Per risolvere il problema, riavviare il cluster. Se il riavvio non risolve il problema, potrebbe verificarsi un errore nel codice della funzione di training.
Errori correlati a CUDA
È possibile incorrere in problemi aggiuntivi con CUDA perché start_method=”fork”
non è compatibile con CUDA. L'uso di qualsiasi comando .cuda
in qualsiasi cella potrebbe causare errori. Per evitare questi errori, aggiungere il controllo seguente prima di richiamare torch.multiprocessing.start_method
:
if torch.cuda.is_initialized():
raise Exception("CUDA was initialized; distributed training will fail.") # or something similar