Best Practices: Ray with PyTorch

This document describes best practices for using Ray with PyTorch. Feel free to contribute if you think this document is missing anything.

Downloading Data

When using Ray with PyTorch, it is common for multiple actors to run the same code that downloads the dataset for training and testing.

# This code runs inside a Ray actor
# ...
torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data", train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
    batch_size=128, shuffle=True, **kwargs)
# ...

If several actors execute this code at the same time, the simultaneous downloads can corrupt the data. One easy workaround is to guard the download with a FileLock from the filelock package:

from filelock import FileLock

with FileLock("./data.lock"):
    torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data", train=True, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=128, shuffle=True, **kwargs)

Use Actors for Parallel Models

A common use case for Ray with PyTorch is parallelizing the training of multiple models.

Tip

Avoid sending the PyTorch model object directly. Send model.state_dict() instead, as PyTorch tensors are natively supported by the Plasma Object Store.
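
For illustration, here is a minimal, self-contained sketch of this pattern (the nn.Linear model is just a stand-in, and Ray is assumed to be initialized already):

import ray
import torch.nn as nn

# Assumes ray.init() has already been called.
model = nn.Linear(10, 2)

# Good: a state_dict is a dict of tensors, which the object store
# serializes efficiently, instead of pickling the whole nn.Module.
weights_ref = ray.put(model.state_dict())

# Elsewhere (e.g., inside an actor), rebuild the model and load the weights.
new_model = nn.Linear(10, 2)
new_model.load_state_dict(ray.get(weights_ref))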

Suppose we have a simple network definition (this one is modified from the PyTorch documentation).

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

Along with these helper training functions:

from filelock import FileLock
from torchvision import datasets, transforms


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Return early to speed up the tutorial.
        if batch_idx * len(data) > 1024:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(
                output, target, reduction="sum").item()
            pred = output.argmax(
                dim=1,
                keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    return {
        "loss": test_loss,
        "accuracy": 100. * correct / len(test_loader.dataset)
    }


def dataset_creators(use_cuda):
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    with FileLock("./data.lock"):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "./data",
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307, ), (0.3081, ))
                ])),
            batch_size=128,
            shuffle=True,
            **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=128,
        shuffle=True,
        **kwargs)

    return train_loader, test_loader

Let’s now define a class that captures the training process.

import torch.optim as optim


class Network:
    def __init__(self, lr=0.01, momentum=0.5):
        use_cuda = torch.cuda.is_available()
        self.device = device = torch.device("cuda" if use_cuda else "cpu")
        self.train_loader, self.test_loader = dataset_creators(use_cuda)

        self.model = Model().to(device)
        self.optimizer = optim.SGD(
            self.model.parameters(), lr=lr, momentum=momentum)

    def train(self):
        train(self.model, self.device, self.train_loader, self.optimizer)
        return test(self.model, self.device, self.test_loader)

    def get_weights(self):
        return self.model.state_dict()

    def set_weights(self, weights):
        self.model.load_state_dict(weights)

    def save(self):
        torch.save(self.model.state_dict(), "mnist_cnn.pt")


net = Network()
net.train()

To train multiple models, you can convert the above class into a Ray Actor class.

import ray
ray.init()

RemoteNetwork = ray.remote(Network)
# Use the below instead of `ray.remote(Network)` to leverage the GPU.
# RemoteNetwork = ray.remote(num_gpus=1)(Network)

Then, we can instantiate multiple copies of the model, each running in a separate process. If num_gpus=1 is set, each copy runs on its own GPU. See the GPU guide for more information.

NetworkActor = RemoteNetwork.remote()
NetworkActor2 = RemoteNetwork.remote()

ray.get([NetworkActor.train.remote(), NetworkActor2.train.remote()])

We can then use set_weights and get_weights to move the network weights between actors. The example below averages the weights of the two networks and sends the result back to update both actors. Note that ray.put stores the averaged weights in the object store once, so each actor receives a reference to the same copy.

from collections import OrderedDict

weights = ray.get(
    [NetworkActor.get_weights.remote(),
     NetworkActor2.get_weights.remote()])

averaged_weights = OrderedDict(
    [(k, (weights[0][k] + weights[1][k]) / 2) for k in weights[0]])

# Put the averaged weights in the object store once so that both
# actors receive a reference to the same copy.
weights_ref = ray.put(averaged_weights)
for actor in [NetworkActor, NetworkActor2]:
    actor.set_weights.remote(weights_ref)

ray.get([actor.train.remote() for actor in [NetworkActor, NetworkActor2]])
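
Since Network also defines a save method, you can persist an actor's current weights afterwards. A small usage sketch (the checkpoint file is written on the node where the actor runs):

# Blocks until the save completes; writes mnist_cnn.pt in the
# actor's working directory.
ray.get(NetworkActor.save.remote())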