Ray Tune: Hyperparameter Optimization Framework

Ray Tune is a scalable hyperparameter optimization framework for reinforcement learning and deep learning. Go from running one experiment on a single machine to running on a large cluster with efficient search algorithms without changing your code.

Getting Started

Installation

You’ll first need to install Ray before you can import Ray Tune.
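Ray (and with it Ray Tune) is distributed on PyPI; as a sketch, assuming a standard pip setup:

$ pip install ray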

Quick Start

import ray
import ray.tune as tune

ray.init()
tune.register_trainable("train_func", train_func)

all_trials = tune.run_experiments({
    "my_experiment": {
        "run": "train_func",
        "stop": {"mean_accuracy": 99},
        "config": {
            "lr": tune.grid_search([0.2, 0.4, 0.6]),
            "momentum": tune.grid_search([0.1, 0.2]),
        }
    }
})

For the function you wish to tune, add a two-line modification (note that we use PyTorch as an example but Ray Tune works with any deep learning framework):

def train_func(config, reporter):  # add a reporter arg
    model = NeuralNet()
    optimizer = torch.optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])
    dataset = ( ... )

    for idx, (data, target) in enumerate(dataset):
        # ...
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        accuracy = eval_accuracy(...)
        reporter(timesteps_total=idx, mean_accuracy=accuracy)  # report metrics

This PyTorch script runs a small grid search over the train_func function using Ray Tune, reporting status on the command line until the stopping condition of mean_accuracy >= 99 is reached (for metrics like loss that decrease over time, specify neg_mean_loss as a condition instead):

== Status ==
Using FIFO scheduling algorithm.
Resources used: 4/8 CPUs, 0/0 GPUs
Result logdir: ~/ray_results/my_experiment
 - train_func_0_lr=0.2,momentum=0.1:  RUNNING [pid=6778], 209 s, 20604 ts, 7.29 acc
 - train_func_1_lr=0.4,momentum=0.1:  RUNNING [pid=6780], 208 s, 20522 ts, 53.1 acc
 - train_func_2_lr=0.6,momentum=0.1:  TERMINATED [pid=6789], 21 s, 2190 ts, 100 acc
 - train_func_3_lr=0.2,momentum=0.2:  RUNNING [pid=6791], 208 s, 41004 ts, 8.37 acc
 - train_func_4_lr=0.4,momentum=0.2:  RUNNING [pid=6800], 209 s, 41204 ts, 70.1 acc
 - train_func_5_lr=0.6,momentum=0.2:  TERMINATED [pid=6809], 10 s, 2164 ts, 100 acc

In order to report incremental progress, train_func periodically calls the reporter function passed in by Ray Tune to return the current timestep and other metrics as defined in ray.tune.result.TrainingResult. Incremental results will be synced to local disk on the head node of the cluster.

tune.run_experiments returns a list of Trial objects, whose results you can inspect via trial.last_result.
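For instance, a minimal sketch for inspecting the results of the Quick Start run above (the exact fields available in last_result depend on your Ray version):

for trial in all_trials:
    # last_result holds the metrics from the trial's final reporter call,
    # e.g. mean_accuracy and timesteps_total in the example above.
    print(trial, trial.last_result)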

Learn more about specifying experiments.

Features

Ray Tune’s main features are covered in the sections below: pluggable trial schedulers, utilities for handling large datasets, HyperOpt integration, result visualization, trial checkpointing with fault tolerance, and a client API for modifying running experiments.

Concepts

[Image: Ray Tune API overview diagram]

Ray Tune schedules a number of trials in a cluster. Each trial runs a user-defined Python function or class and is parameterized by a config variation passed to the user code.

In order to run any given function, you first need to register it under a name with register_trainable. This makes all Ray workers aware of the function.

ray.tune.register_trainable(name, trainable)

Register a trainable function or class.

Parameters:
  • name (str) – Name to register.
  • trainable (obj) – Function or tune.Trainable class. Functions must take (config, status_reporter) as arguments and will be automatically converted into a class during registration.

Ray Tune provides a run_experiments function that generates and runs the trials described by the experiment specification. The trials are scheduled and managed by a trial scheduler that implements the search algorithm (default is FIFO).

ray.tune.run_experiments(experiments, scheduler=None, with_server=False, server_port=4321, verbose=True, queue_trials=False)

Tunes experiments.

Parameters:
  • experiments (Experiment | list | dict) – Experiments to run.
  • scheduler (TrialScheduler) – Scheduler for executing the experiment. Choose among FIFO (default), MedianStopping, AsyncHyperBand, HyperBand, or HyperOpt.
  • with_server (bool) – Starts a background Tune server. Needed for using the Client API.
  • server_port (int) – Port number for launching TuneServer.
  • verbose (bool) – How much output should be printed for each trial.
  • queue_trials (bool) – Whether to queue trials when the cluster does not currently have enough resources to launch one. This should be set to True when running on an autoscaling cluster to enable automatic scale-up.
Returns:
  List of Trial objects, holding data for each executed trial.

Ray Tune can be used anywhere Ray can, e.g. on your laptop with ray.init() embedded in a Python script, or in an auto-scaling cluster for massive parallelism.
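As a sketch (the argument for connecting to an existing cluster is an assumption here, and its name varies across Ray versions):

import ray

# Run locally: start Ray inside the script.
ray.init()

# Or connect to an existing cluster's head node instead (older Ray versions
# use redis_address, newer ones use address; check your version's docs).
# ray.init(redis_address="<head-node-ip>:6379")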

You can find the code for Ray Tune here on GitHub.

Trial Schedulers

By default, Ray Tune schedules trials in serial order with the FIFOScheduler class. However, you can also specify a custom scheduling algorithm that can early stop trials, perturb parameters, or incorporate suggestions from an external service. Currently implemented trial schedulers include Population Based Training (PBT), Median Stopping Rule, Model Based Optimization (HyperOpt), and HyperBand.

run_experiments({...}, scheduler=AsyncHyperBandScheduler())
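The scheduler classes must be imported before use; the module path below is an assumption that varies across Ray versions (recent releases expose the schedulers under ray.tune.schedulers):

# Module path may differ in older Ray releases.
from ray.tune.schedulers import AsyncHyperBandScheduler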

Handling Large Datasets

You often will want to compute a large object (e.g., training data, model weights) on the driver and use that object within each trial. Ray Tune provides a pin_in_object_store utility function that can be used to broadcast such large objects. Objects pinned in this way will never be evicted from the Ray object store while the driver process is running, and can be efficiently retrieved from any task via get_pinned_object.

import ray
from ray.tune import register_trainable, run_experiments
from ray.tune.util import pin_in_object_store, get_pinned_object

import numpy as np

ray.init()

# X_id can be referenced in closures
X_id = pin_in_object_store(np.random.random(size=100000000))

def f(config, reporter):
    X = get_pinned_object(X_id)
    # use X

register_trainable("f", f)
run_experiments(...)

HyperOpt Integration

The HyperOptScheduler is a Trial Scheduler that is backed by HyperOpt to perform sequential model-based hyperparameter optimization. In order to use this scheduler, you will need to install HyperOpt via the following command:

$ pip install --upgrade git+git://github.com/hyperopt/hyperopt.git

An example of this can be found in hyperopt_example.py.

class ray.tune.hpo_scheduler.HyperOptScheduler(max_concurrent=None, reward_attr='episode_reward_mean')

FIFOScheduler that uses HyperOpt to provide trial suggestions.

Requires HyperOpt to be installed via source. Uses the Tree-structured Parzen Estimators algorithm. Externally added trials will not be tracked by HyperOpt. Also, variant generation will be limited, as the hyperparameter configuration must be specified using HyperOpt primitives.

Parameters:
  • max_concurrent (int | None) – Number of maximum concurrent trials. If None, then trials will be queued only if resources are available.
  • reward_attr (str) – The TrainingResult objective value attribute. This refers to an increasing value, which is internally negated when interacting with HyperOpt. Suggestion procedures will use this attribute.

Examples

>>> from hyperopt import hp
>>> space = {'param': hp.uniform('param', 0, 20)}
>>> config = {"my_exp": {
                  "run": "exp",
                  "repeat": 5,
                  "config": {"space": space}}}
>>> run_experiments(config, scheduler=HyperOptScheduler())

Visualizing Results

Ray Tune logs trial results to a unique directory per experiment, e.g. ~/ray_results/my_experiment in the above example. The log records are compatible with a number of visualization tools:

To visualize learning in tensorboard, install TensorFlow:

$ pip install tensorflow

Then, after you run an experiment, you can visualize it with TensorBoard by specifying the output directory of your results. Note that if you are running Ray on a remote cluster, you can forward the TensorBoard port to your local machine over SSH using ssh -L 6006:localhost:6006 <address>:

$ tensorboard --logdir=~/ray_results/my_experiment
[Image: TensorBoard view of Ray Tune results]

To use rllab’s VisKit (you may have to install some dependencies), run:

$ git clone https://github.com/rll/rllab.git
$ python rllab/rllab/viskit/frontend.py ~/ray_results/my_experiment
[Image: VisKit view of Ray Tune results]

Finally, to view the results with a parallel coordinates visualization, open ParallelCoordinatesVisualization.ipynb as follows and run its cells:

$ cd $RAY_HOME/python/ray/tune
$ jupyter-notebook ParallelCoordinatesVisualization.ipynb
[Image: Parallel coordinates view of Ray Tune results]

Trial Checkpointing

To enable checkpointing, you must implement a Trainable class (Trainable functions are not checkpointable, since they never return control back to their caller). The easiest way to do this is to subclass the pre-defined Trainable class and implement its _train, _save, and _restore abstract methods (example). Implementing this interface is required to support resource multiplexing in schedulers such as HyperBand and PBT.

For TensorFlow model training, this would look something like this (full tensorflow example):

class MyClass(Trainable):
    def _setup(self):
        self.saver = tf.train.Saver()
        self.sess = ...
        self.iteration = 0

    def _train(self):
        # Run one logical iteration of training.
        self.sess.run(...)
        self.iteration += 1

    def _save(self, checkpoint_dir):
        # Write a TensorFlow checkpoint and return its path prefix.
        return self.saver.save(
            self.sess, checkpoint_dir + "/save",
            global_step=self.iteration)

    def _restore(self, path):
        # Restore session state from the given checkpoint path.
        return self.saver.restore(self.sess, path)

Additionally, checkpointing can be used to provide fault-tolerance for experiments. This can be enabled by setting checkpoint_freq: N and max_failures: M to checkpoint trials every N iterations and recover from up to M crashes per trial, e.g.:

run_experiments({
    "my_experiment": {
        ...
        "checkpoint_freq": 10,
        "max_failures": 5,
    },
})

The class interface that must be implemented to enable checkpointing is as follows:

class ray.tune.trainable.Trainable(config=None, logger_creator=None)

Abstract class for trainable models, functions, etc.

A call to train() on a trainable will execute one logical iteration of training. As a rule of thumb, the execution time of one train call should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).

Calling save() should save the training state of a trainable to disk, and restore(path) should restore a trainable to the given state.

Generally you only need to implement _train, _save, and _restore here when subclassing Trainable.

Note that, if you don’t require checkpoint/restore functionality, then instead of implementing this class you can also get away with supplying just a my_train(config, reporter) function and calling:

register_trainable("my_func", train)

to register it for use with Tune. The function will be automatically converted to this interface (sans checkpoint functionality).

config

obj – The hyperparam configuration for this trial.

logdir

str – Directory in which training outputs should be placed.

_train()

Subclasses should override this to implement train().

_save(checkpoint_dir)

Subclasses should override this to implement save().

_restore(checkpoint_path)

Subclasses should override this to implement restore().

_setup()

Subclasses should override this for custom initialization.

_stop()

Subclasses should override this for any cleanup on stop.

Client API

You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, verify that you have the requests library installed:

$ pip install requests

To use the Client API, you can start your experiment with with_server=True:

run_experiments({...}, with_server=True, server_port=4321)

Then, on the client side, you can use the following class. The server address defaults to localhost:4321. If on a cluster, you may want to forward this port (e.g. ssh -L <local_port>:localhost:<remote_port> <address>) so that you can use the Client on your local machine.

class ray.tune.web_server.TuneClient(tune_address)

Client to interact with ongoing Tune experiment.

Requires server to have started running.

get_all_trials()

Returns a list of all trials (trial_id, config, status).

get_trial(trial_id)

Returns the last result for queried trial.

add_trial(name, trial_spec)

Adds a trial of name with configurations.

stop_trial(trial_id)

Requests to stop trial.
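A minimal usage sketch based on the methods listed above (the trial id value is illustrative):

from ray.tune.web_server import TuneClient

client = TuneClient("localhost:4321")

# List all trials as (trial_id, config, status) entries.
for trial in client.get_all_trials():
    print(trial)

# Request that a particular trial be stopped (illustrative id).
# client.stop_trial("<trial_id>")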

For an example notebook for using the Client API, see the Client API Example.

Examples

You can find a list of examples using Ray Tune and its various features here.