Поделиться через


PyTorch

Проект PyTorch — это пакет Python, обеспечивающий тензорное вычисление с ускорением GPU и высокоуровневые функции для создания сетей глубокого обучения. Сведения о лицензировании см. в документации по лицензии PyTorch на GitHub.

Для отслеживания и отладки моделей PyTorch рекомендуется использовать TensorBoard.

PyTorch входит в состав Databricks Runtime для Машинного обучения. Если вы используете Databricks Runtime, инструкции по установке PyTorch см. в разделе Установка PyTorch.

Примечание.

Эта статья не является исчерпывающим руководством по PyTorch. Дополнительные сведения см. на веб-сайте PyTorch.

Единый узел и распределенное обучение

Для тестирования и переноса рабочих процессов одного компьютера используйте кластер с одним узлом.

Описание возможностей распределенного обучения в контексте глубокого обучения см. в разделе Распределенное обучение.

Пример записной книжки

Записная книжка PyTorch

Получить записную книжку

Установка PyTorch

Databricks Runtime для машинного обучения

Databricks Runtime для Машинного обучения включает PyTorch, благодаря чему вы можете создать кластер и приступить к работе с PyTorch. Сведения о версии PyTorch, установленной в вашей версии Databricks Runtime ML, см. в заметках о выпуске.

Databricks Runtime

Databricks рекомендует использовать PyTorch, включенный в среду выполнения Databricks для Машинное обучение. Однако если необходимо использовать стандартную среду выполнения Databricks, PyTorch можно установить в качестве библиотеки PyPI Databricks. В следующем примере показано, как установить PyTorch 1.5.0.

  • На кластерах GPU установите pytorch и torchvision, указав следующее:

    • torch==1.5.0
    • torchvision==0.6.0
  • В кластерах ЦП установите pytorch и torchvision с помощью следующих файлов колес Python:

    https://download.pytorch.org/whl/cpu/torch-1.5.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
    
    https://download.pytorch.org/whl/cpu/torchvision-0.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
    

Ошибки и устранение неполадок для распределенного PyTorch

В следующих разделах описаны распространенные сообщения об ошибках и рекомендации по устранению неполадок для классов: PyTorch DataParallel или PyTorch DistributedDataParallel. Большинство этих ошибок, вероятно, можно устранить с помощью TorchDistributor, которая доступна в Databricks Runtime ML 13.0 и выше. Однако, если TorchDistributor это не жизнеспособное решение, рекомендуемые решения также предоставляются в каждом разделе.

Ниже приведен пример использования TorchDistributor:


from pyspark.ml.torch.distributor import TorchDistributor

def train_fn(learning_rate):
        # ...

num_processes=2
distributor = TorchDistributor(num_processes=num_processes, local_mode=True)

distributor.run(train_fn, 1e-3)

процесс 0 завершен с кодом выхода 1

При использовании записных книжек в Databricks или локально возникает следующая ошибка:

process 0 terminated with exit code 1

Чтобы избежать этой ошибки, используйте torch.multiprocessing.start_processes с start_method=fork вместо torch.multiprocessing.spawn.

Например:

import torch

def train_fn(rank, learning_rate):
    # required setup, e.g. setup(rank)
        # ...

num_processes = 2
torch.multiprocessing.start_processes(train_fn, args=(1e-3,), nprocs=num_processes, start_method="fork")

Не удалось привязать сокет сервера к порту

Следующая ошибка возникает при перезапуске распределенного обучения после прерывания ячейки во время обучения:

The server socket has failed to bind to [::]:{PORT NUMBER} (errno: 98 - Address already in use).

Чтобы устранить проблему, перезапустите кластер. Если перезагрузка не решает проблему, в коде функции обучения может возникнуть ошибка.

Вы можете столкнуться с дополнительными проблемами с CUDA, так как start_method=”fork”не совместим с CUDA. Использование любых .cuda команд в любой ячейке может привести к сбоям. Чтобы избежать этих ошибок, добавьте следующую проверку перед вызовом torch.multiprocessing.start_method:

if torch.cuda.is_initialized():
    raise Exception("CUDA was initialized; distributed training will fail.") # or something similar