Population Based Training

Ray Tune includes a distributed implementation of Population Based Training (PBT).

PBT Scheduler

Ray Tune’s PBT scheduler can be plugged in on top of an existing grid or random search experiment. To enable it, pass a PopulationBasedTraining instance as the scheduler parameter of run_experiments, e.g.

import random

from ray.tune import run_experiments
from ray.tune.pbt import PopulationBasedTraining

run_experiments(
    {...},
    scheduler=PopulationBasedTraining(
        time_attr='time_total_s',
        reward_attr='mean_accuracy',
        perturbation_interval=600.0,
        hyperparam_mutations={
            # Discrete values: perturbation moves to an adjacent value in the list.
            "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
            # Continuous values: resampling draws a fresh value from this function.
            "alpha": lambda: random.uniform(0.0, 1.0),
            ...
        }))

When the PBT scheduler is enabled, each trial variant is treated as a member of the population. Periodically, top-performing trials are checkpointed (this requires your Trainable to support checkpointing). Low-performing trials clone the checkpoints of top performers and perturb the configurations in the hope of discovering an even better variation.
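
Below is a minimal sketch of a class-based Trainable with checkpoint support, in the spirit of what PBT requires. The toy state, the checkpoint file format, and the exact hook names and signatures (_setup, _train, _save, _restore) are assumptions that may differ slightly between Ray versions; the essential point is that _save and _restore round-trip the trial state so a low performer can resume from a top performer's checkpoint.

import json
import os

from ray.tune import Trainable

class MyTrainable(Trainable):
    """Toy Trainable sketch with checkpointing (hypothetical model/state)."""

    def _setup(self):
        # self.config holds the (possibly mutated) hyperparameters.
        self.lr = self.config.get("lr", 1e-3)
        self.iteration = 0

    def _train(self):
        # One unit of training; report the metric named by reward_attr.
        self.iteration += 1
        return {"mean_accuracy": min(1.0, self.iteration * self.lr)}  # placeholder metric

    def _save(self, checkpoint_dir):
        # Persist enough state for a clone to resume from this point.
        path = os.path.join(checkpoint_dir, "checkpoint.json")
        with open(path, "w") as f:
            json.dump({"iteration": self.iteration}, f)
        return path

    def _restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.iteration = json.load(f)["iteration"]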

You can run this toy PBT example to get an idea of how PBT operates. When training in PBT mode, a single trial may see many different hyperparameters over its lifetime, which is recorded in its result.json file. The following figure, generated by the example, shows PBT discovering new hyperparameters over the course of a single experiment:

[Figure: pbt.png]
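
Because each logged result records the configuration in effect at that point, you can reconstruct a trial's hyperparameter schedule directly from result.json. The sketch below assumes the file is newline-delimited JSON (one result per line), that each line carries config and training_iteration keys, and a hypothetical trial directory path; adjust these to your setup and Ray version.

import json
import os

# Hypothetical path to one trial's local directory; substitute your own.
trial_dir = os.path.expanduser("~/ray_results/pbt_experiment/my_trial")

schedule = []
with open(os.path.join(trial_dir, "result.json")) as f:
    for line in f:
        result = json.loads(line)
        # The config changes over time as PBT clones and perturbs the trial.
        schedule.append((result["training_iteration"], result["config"]["lr"]))

print(schedule)
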
class ray.tune.pbt.PopulationBasedTraining(time_attr='time_total_s', reward_attr='episode_reward_mean', perturbation_interval=60.0, hyperparam_mutations={}, resample_probability=0.25, custom_explore_fn=None)

Implements the Population Based Training (PBT) algorithm.

https://deepmind.com/blog/population-based-training-neural-networks

PBT trains a group of models (or agents) in parallel. Periodically, poorly performing models clone the state of the top performers, and a random mutation is applied to their hyperparameters in the hopes of outperforming the current top models.

Unlike other hyperparameter search algorithms, PBT mutates hyperparameters during training time. This enables very fast hyperparameter discovery and also automatically discovers good annealing schedules.

This Ray Tune PBT implementation considers all trials added as part of the PBT population. If the number of trials exceeds the cluster capacity, they will be time-multiplexed so as to balance training progress across the population.

Parameters:
  • time_attr (str) – The TrainingResult attr to use for comparing time. Note that you can pass in something non-temporal such as training_iteration as a measure of progress; the only requirement is that the attribute increases monotonically.
  • reward_attr (str) – The TrainingResult objective value attribute. As with time_attr, this may refer to any objective value. Stopping procedures will use this attribute.
  • perturbation_interval (float) – Models will be considered for perturbation at this interval of time_attr. Note that perturbation incurs checkpoint overhead, so you shouldn’t set this to be too frequent.
  • hyperparam_mutations (dict) – Hyperparams to mutate. The format is as follows: for each key, either a list or function can be provided. A list specifies an allowed set of categorical values. A function specifies the distribution of a continuous parameter. You must specify at least one of hyperparam_mutations or custom_explore_fn.
  • resample_probability (float) – The probability of resampling from the original distribution when applying hyperparam_mutations. If not resampled, the value will be perturbed by a factor of 1.2 or 0.8 if continuous, or changed to an adjacent value if discrete.
  • custom_explore_fn (func) – You can also specify a custom exploration function. This function is invoked as f(config) after built-in perturbations from hyperparam_mutations are applied, and should return config updated as needed. You must specify at least one of hyperparam_mutations or custom_explore_fn. A sketch of such a function is shown below.
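
For illustration, here is a minimal sketch of a custom_explore_fn that clamps a perturbed learning rate back into a valid range. The parameter name and bounds are assumptions chosen for this example, not part of Ray Tune.

import random

from ray.tune.pbt import PopulationBasedTraining

def clamp_lr(config):
    # Invoked as f(config) after the built-in hyperparam_mutations are applied.
    # Keep the (hypothetical) "lr" key inside a sane range, then return config.
    config["lr"] = max(1e-5, min(config["lr"], 1e-2))
    return config

pbt = PopulationBasedTraining(
    time_attr="time_total_s",
    reward_attr="mean_accuracy",
    perturbation_interval=600.0,
    hyperparam_mutations={
        "lr": lambda: random.uniform(1e-5, 1e-2),
    },
    custom_explore_fn=clamp_lr)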

Example

>>> import random
>>> pbt = PopulationBasedTraining(
>>>     time_attr="training_iteration",
>>>     reward_attr="episode_reward_mean",
>>>     perturbation_interval=10,  # every 10 `time_attr` units
>>>                                # (training_iterations in this case)
>>>     hyperparam_mutations={
>>>         # Perturb factor_1 by scaling it by 0.8 or 1.2. Resampling
>>>         # resets it to a value sampled from the lambda function.
>>>         "factor_1": lambda: random.uniform(0.0, 20.0),
>>>         # Perturb factor_2 by changing it to an adjacent value, e.g.
>>>         # 10 -> 1 or 10 -> 100. Resampling will choose at random.
>>>         "factor_2": [1, 10, 100, 1000, 10000],
>>>     })
>>> run_experiments({...}, scheduler=pbt)