Simple Parallel Model Selection

In this example, we’ll demonstrate how to quickly write a hyperparameter tuning script that evaluates a set of hyperparameters in parallel.

This script demonstrates two important parts of the Ray API: `ray.remote`, which defines remote functions, and `ray.wait`, which waits until their results are ready.


For a production-grade implementation of distributed hyperparameter tuning, use Tune, a scalable hyperparameter tuning library built using Ray’s Actor API.

import os
import numpy as np
from filelock import FileLock

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import ray


# The number of sets of random hyperparameters to try.
num_evaluations = 10


# A function for generating random hyperparameters.
def generate_hyperparameters(rng=None):
    """Sample one random hyperparameter configuration.

    Args:
        rng: Optional source of randomness exposing ``uniform`` and
            ``randint`` (e.g. ``np.random.RandomState(seed)``), so that a
            search can be reproduced. Defaults to NumPy's global random
            state.

    Returns:
        A dict with keys ``"learning_rate"`` (log-uniform in ``[1e-5, 10)``),
        ``"batch_size"`` (an int in ``[1, 100)``) and ``"momentum"``
        (uniform in ``[0, 1)``).
    """
    if rng is None:
        rng = np.random
    return {
        # Sample the exponent uniformly so every order of magnitude between
        # 1e-5 and 10 is equally likely.
        "learning_rate": 10**rng.uniform(-5, 1),
        # Cast so the batch size is a plain Python int whichever RNG made it.
        "batch_size": int(rng.randint(1, 100)),
        "momentum": rng.uniform(0, 1),
    }

def get_data_loaders(batch_size):
    """Build shuffled MNIST train and test data loaders.

    Args:
        batch_size: Number of samples per batch, used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of
        ``torch.utils.data.DataLoader`` objects reading MNIST from ``~/data``.
    """
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    data_dir = os.path.expanduser("~/data")

    # Several Ray workers may try to download MNIST into the same directory
    # at the same time, and concurrent downloads can overwrite each other's
    # files. The file lock serializes the workers. Both splits are opened
    # (and downloaded if missing) under the lock, so neither is ever read
    # while another worker is still writing it.
    with FileLock(os.path.expanduser("~/data.lock")):
        train_dataset = datasets.MNIST(
            data_dir, train=True, download=True, transform=mnist_transforms)
        test_dataset = datasets.MNIST(
            data_dir, train=False, download=True, transform=mnist_transforms)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)
    # Shuffled because test() only scores a prefix of this loader.
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader

class ConvNet(nn.Module):
    """Tiny classifier: one 3x3 convolution followed by one linear layer.

    The input goes through a 3x3 conv (1 -> 3 channels), then a 3x3 max-pool,
    then ReLU, then a flatten to 192 features. A linear layer maps these to
    10 class scores, and the result is returned as log-probabilities.
    192 = 3 channels x 8 x 8, which is what a 28x28 single-channel image
    (e.g. MNIST) produces.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so keep them stable.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        pooled = F.max_pool2d(self.conv1(x), 3)
        features = pooled.relu().view(-1, 192)
        logits = self.fc(features)
        return logits.log_softmax(dim=1)

def train(model, optimizer, train_loader, device=torch.device("cpu"),
          max_samples=1024):
    """Optimize the model with one pass over (a prefix of) the data.

    Training stops once ``max_samples`` samples have been used, to keep the
    example fast.

    Args:
        model: Module whose output is fed to ``F.nll_loss``, so it should
            return log-probabilities per class.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of ``(data, target)`` batches.
        device: Device each batch is moved to before the forward pass.
        max_samples: Stop before the next batch once at least this many
            samples have been trained on.
    """
    model.train()
    seen = 0
    for data, target in train_loader:
        # Count the samples actually consumed, so the cutoff is exact even
        # when batch sizes vary (e.g. a short final batch).
        if seen >= max_samples:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        seen += len(data)

def test(model, test_loader, device=torch.device("cpu")):
    """Checks the validation accuracy of the model.

    Cuts off at 512 samples for simplicity.
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            if batch_idx * len(data) > 512:
            data, target =,
            outputs = model(data)
            _, predicted = torch.max(, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total

@ray.remote
def evaluate_hyperparameters(config):
    """Train a fresh ConvNet with ``config`` and return its test accuracy.

    Decorated with ``@ray.remote``: calling
    ``evaluate_hyperparameters.remote(config)`` schedules the work as a Ray
    task and immediately returns an object ref for the accuracy.

    Args:
        config: Dict with ``"learning_rate"``, ``"batch_size"`` and
            ``"momentum"`` entries, as produced by
            ``generate_hyperparameters``.

    Returns:
        Accuracy (in ``[0, 1]``) of the trained model, as computed by
        ``test``.
    """
    model = ConvNet()
    train_loader, test_loader = get_data_loaders(config["batch_size"])
    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"])
    train(model, optimizer, train_loader)
    return test(model, test_loader)

# Start Ray unless this process already has it running (a second plain
# ray.init() call would raise).
if not ray.is_initialized():
    ray.init()

# Keep track of the best hyperparameters and the best accuracy.
best_hyperparameters = None
best_accuracy = 0
# A list holding the object refs for all of the experiments that we have
# launched but have not yet been processed.
remaining_ids = []
# A dictionary mapping an experiment's object ref to the hyperparameters
# used for that experiment.
hyperparameters_mapping = {}

# Randomly generate sets of hyperparameters and launch a task to test each set.
for _ in range(num_evaluations):
    hyperparameters = generate_hyperparameters()
    accuracy_id = evaluate_hyperparameters.remote(hyperparameters)
    # Track the pending ref so the result loop below waits on it.
    remaining_ids.append(accuracy_id)
    hyperparameters_mapping[accuracy_id] = hyperparameters

# Fetch and print the results of the tasks in the order that they complete.
while remaining_ids:
    # Use ray.wait to get the object ref of the first task that completes;
    # the refs that are still pending are returned as the second list.
    done_ids, remaining_ids = ray.wait(remaining_ids)
    # ray.wait returns one ready ref by default.
    result_id = done_ids[0]

    hyperparameters = hyperparameters_mapping[result_id]
    accuracy = ray.get(result_id)
    print("""We achieve accuracy {:.3}% with
        learning_rate: {:.2}
        batch_size: {}
        momentum: {:.2}
      """.format(100 * accuracy, hyperparameters["learning_rate"],
                 hyperparameters["batch_size"], hyperparameters["momentum"]))
    # The None check records a winner even if every trial scores 0.
    if best_hyperparameters is None or accuracy > best_accuracy:
        best_hyperparameters = hyperparameters
        best_accuracy = accuracy

# Record the best performing set of hyperparameters.
if best_hyperparameters is None:
    print("No trials were run (num_evaluations = {}).".format(num_evaluations))
else:
    print("""Best accuracy over {} trials was {:.3}% with
      learning_rate: {:.2}
      batch_size: {}
      momentum: {:.2}
      """.format(num_evaluations, 100 * best_accuracy,
                 best_hyperparameters["learning_rate"],
                 best_hyperparameters["batch_size"],
                 best_hyperparameters["momentum"]))

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery