Distributed Training (Experimental)

Ray includes abstractions for distributed model training that integrate with deep learning frameworks, such as PyTorch.

Ray Train is built on top of the Ray task and actor abstractions to provide seamless integration into existing Ray applications.

PyTorch Interface

To use Ray Train with PyTorch, pass model and data creator functions to the ray.experimental.sgd.pytorch.PyTorchTrainer class, then call trainer.train() repeatedly to drive the distributed training:

import ray
from ray.experimental.sgd.pytorch import PyTorchTrainer, utils

ray.init()

model_creator = lambda config: YourPyTorchModel()
# Parenthesize so the lambda returns a (train, validation) tuple.
data_creator = lambda config: (YourTrainingSet(), YourValidationSet())

trainer = PyTorchTrainer(
    model_creator,
    data_creator,
    optimizer_creator=utils.sgd_mse_optimizer,  # the default optimizer_creator
    config={"lr": 1e-4},
    num_replicas=2,
    use_gpu=True,
    batch_size=16,
    backend="auto")

for i in range(NUM_EPOCHS):
    trainer.train()

Under the hood, Ray Train creates replicas of your model (controlled by num_replicas), each managed by a worker process. Setting use_gpu=True allocates a GPU to each worker, enabling data-parallel training across multiple GPUs. The PyTorchTrainer class coordinates the distributed computation and gradient updates to improve the model.
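
For concreteness, minimal creator functions might look like the sketch below; the one-layer linear model and synthetic tensors are illustrative placeholders, not part of the Ray API:

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

def model_creator(config):
    # Build the model from the config; a single linear layer as a placeholder.
    return nn.Linear(1, 1)

def data_creator(config):
    # Return (training set, validation set) as torch Datasets.
    x = torch.randn(1000, 1)
    y = 2 * x + 1
    return TensorDataset(x[:800], y[:800]), TensorDataset(x[800:], y[800:])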

The full documentation for PyTorchTrainer is as follows:

class ray.experimental.sgd.pytorch.PyTorchTrainer(model_creator, data_creator, optimizer_creator=<function sgd_mse_optimizer>, config=None, num_replicas=1, use_gpu=False, batch_size=16, backend='auto')

Train a PyTorch model using distributed PyTorch.

Launches a set of actors which connect via distributed PyTorch and coordinate gradient updates to train the provided model.

__init__(model_creator, data_creator, optimizer_creator=<function sgd_mse_optimizer>, config=None, num_replicas=1, use_gpu=False, batch_size=16, backend='auto')

Sets up the PyTorch trainer.

Parameters:
  • model_creator (dict -> torch.nn.Module) – creates the model using the config.
  • data_creator (dict -> (Dataset, Dataset)) – creates the training and validation data sets using the config.
  • optimizer_creator ((torch.nn.Module, dict) -> (loss, optimizer)) – creates the loss and optimizer using the model and the config.
  • config (dict) – configuration passed to model_creator, data_creator, and optimizer_creator.
  • num_replicas (int) – the number of workers used in distributed training.
  • use_gpu (bool) – sets resource allocation for workers to 1 GPU if true.
  • batch_size (int) – batch size for each update.
  • backend (string) – backend used by distributed PyTorch.
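
As an illustration of the optimizer_creator signature, a custom creator could be written as the sketch below; the MSE loss and SGD optimizer here are assumptions mirroring the sgd_mse_optimizer default:

import torch

def optimizer_creator(model, config):
    # Build a (loss, optimizer) pair from the model and the config dict.
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))
    return criterion, optimizer
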
train()

Runs a training epoch.

validate()

Evaluates the model on the validation data set.

get_model()

Returns the learned model.

save(checkpoint)

Saves the model at the provided checkpoint.

Parameters: checkpoint (str) – Path to target checkpoint file.

restore(checkpoint)

Restores the model from the provided checkpoint.

Parameters: checkpoint (str) – Path to target checkpoint file.

shutdown()

Shuts down workers and releases resources.
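
Putting these methods together, a typical trainer lifecycle might look like the following sketch; the epoch count and checkpoint path are illustrative, and the creator functions reuse the sketches above:

trainer = PyTorchTrainer(
    model_creator,
    data_creator,
    optimizer_creator=optimizer_creator,
    config={"lr": 1e-4},
    num_replicas=2)

for epoch in range(10):
    trainer.train()       # run one distributed training epoch
    trainer.validate()    # evaluate on the validation set

trainer.save("/tmp/model.checkpoint")     # write a checkpoint file
trainer.restore("/tmp/model.checkpoint")  # reload the saved weights
model = trainer.get_model()               # plain torch.nn.Module for inference
trainer.shutdown()                        # release workers and resources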