RaySGD API Documentation

TorchTrainer

class ray.util.sgd.torch.TorchTrainer(*, model_creator=None, data_creator=None, optimizer_creator=None, loss_creator=None, scheduler_creator=None, training_operator_cls=None, initialization_hook=None, config=None, num_workers=1, use_gpu='auto', backend='auto', wrap_ddp=True, use_fp16=False, use_tqdm=False, apex_args=None, add_dist_sampler=True, scheduler_step_freq='batch', num_replicas=None, batch_size=None, data_loader_args=None)[source]

Train a PyTorch model using distributed PyTorch.

Launches a set of actors which connect via distributed PyTorch and coordinate gradient updates to train the provided model. If Ray is not initialized, TorchTrainer will automatically initialize a local Ray cluster for you. Be sure to run ray.init(address="auto") to leverage multi-node training.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from ray.util.sgd.torch import TorchTrainer

# LinearDataset is the small synthetic dataset used throughout the RaySGD examples.


def model_creator(config):
    return nn.Linear(1, 1)


def optimizer_creator(model, config):
    return torch.optim.SGD(
        model.parameters(), lr=config.get("lr", 1e-4))


def data_creator(config):
    batch_size = config["batch_size"]
    train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
    train_loader = DataLoader(train_data, batch_size=batch_size)
    val_loader = DataLoader(val_data, batch_size=batch_size)
    return train_loader, val_loader


trainer = TorchTrainer(
    model_creator=model_creator,
    data_creator=data_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=nn.MSELoss,
    config={"batch_size": 32},
    use_gpu=True
)
for i in range(4):
    trainer.train()
Parameters
  • model_creator (dict -> Model(s)) – Constructor function that takes in config and returns the model(s) to be optimized. These must be torch.nn.Module objects. If multiple models are returned, a training_operator_cls must be specified. You do not need to handle GPU/devices in this function; RaySGD will do that under the hood.

  • data_creator (dict -> Iterable(s)) – Constructor function that takes in the passed config and returns one or two Iterable objects. If two Iterables are returned, the first is used for training and the second for validation. If not provided, you must provide a custom TrainingOperator.

  • optimizer_creator ((models, dict) -> optimizers) – Constructor function that takes in the return values from model_creator and the passed config and returns one or more Torch optimizer objects. You do not need to handle GPU/devices in this function; RaySGD will do that for you.

  • loss_creator (torch.nn.*Loss class | dict -> loss) – A constructor function for the training loss. This can be either a function that takes in the provided config for customization or a subclass of torch.nn.modules.loss._Loss, which covers most PyTorch loss classes. For example, loss_creator=torch.nn.BCELoss. If not provided, you must provide a custom TrainingOperator.

  • scheduler_creator ((optimizers, dict) -> scheduler) – A constructor function for the torch scheduler. This is a function that takes in the generated optimizers (from optimizer_creator) and the provided config for customization. Be sure to set scheduler_step_freq to increment the scheduler correctly; see the sketch following this parameter list.

  • training_operator_cls (type) – Custom training operator class that subclasses the TrainingOperator class. This class will be copied onto all remote workers and used to specify custom training and validation operations. Defaults to TrainingOperator.

  • config (dict) – Custom configuration value to be passed to all creator and operator constructors.

  • num_workers (int) – The number of workers used in distributed training. If 1, the worker will not be wrapped with DistributedDataParallel.

  • use_gpu (bool) – Sets resource allocation for workers to 1 GPU if True, and automatically moves both the model and optimizer to the available CUDA device.

  • backend (string) – Backend used by distributed PyTorch. Currently supports "nccl", "gloo", and "auto". If "auto", RaySGD will automatically use "nccl" if use_gpu is True, and "gloo" otherwise.

  • wrap_ddp (bool) – Whether to automatically wrap DistributedDataParallel over each model. If False, you are expected to wrap the models yourself.

  • add_dist_sampler (bool) – Whether to automatically add a DistributedSampler to all created dataloaders. Only applicable if num_workers > 1.

  • use_fp16 (bool) – Enables mixed precision training via apex if apex is installed. This is automatically done after the model and optimizers are constructed and will work for multi-model training. Please see https://github.com/NVIDIA/apex for more details.

  • apex_args (dict|None) – Dict containing keyword args for amp.initialize. See https://nvidia.github.io/apex/amp.html#module-apex.amp. By default, the models and optimizers are passed in. Consider using "num_losses" if operating over multiple models and optimizers.

  • scheduler_step_freq – "batch", "epoch", or None. This will determine when scheduler.step is called. If "batch", step will be called after every optimizer step. If "epoch", step will be called after one pass of the DataLoader.
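
A minimal sketch of wiring a scheduler into the trainer (referenced from scheduler_creator above), reusing the imports and creator functions from the example at the top of this page; the StepLR settings are illustrative:

def scheduler_creator(optimizer, config):
    # Decay the learning rate by 10x every 5 epochs (illustrative values).
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)


trainer = TorchTrainer(
    model_creator=model_creator,
    data_creator=data_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=nn.MSELoss,
    scheduler_creator=scheduler_creator,
    scheduler_step_freq="epoch",  # step the scheduler once per epoch
    config={"batch_size": 32})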

train(num_steps=None, profile=False, reduce_results=True, max_retries=3, info=None)[source]

Runs a training epoch.

Calls operator.train_epoch() on N parallel workers simultaneously under the hood.

Set max_retries to enable fault handling in case of instance preemption.

Parameters
  • num_steps (int) – Number of batches to compute update steps on. This also corresponds to the number of times TrainingOperator.train_batch is called.

  • profile (bool) – If True, returns time stats for the training procedure.

  • reduce_results (bool) – Whether to average all metrics across all workers into one dict. If a metric is a non-numerical value (or a nested dictionary), one value will be randomly selected from among the workers. If False, returns a list of dicts.

  • max_retries (int) – Must be non-negative. If set to N, TorchTrainer will detect and recover from training failure. The recovery process will kill all current workers, query the Ray global state for total available resources, and re-launch up to the available resources. Behavior is not well-defined in case of shared cluster usage. Defaults to 3.

  • info (dict) – Optional dictionary passed to the training operator for train_epoch and train_batch.

Returns

(dict | list) A dictionary of metrics for training.

You can provide custom metrics by passing in a custom training_operator_cls. If reduce_results=False, this will return a list of metric dictionaries whose length will be equal to num_workers.
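
A short usage sketch of these arguments, continuing the trainer constructed in the example at the top of this page:

# Run one full epoch and get metrics averaged across workers.
stats = trainer.train()
print(stats)

# Run only 10 update steps, return per-worker metric dicts, and include
# timing stats for the training procedure.
worker_stats = trainer.train(num_steps=10, profile=True, reduce_results=False)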

apply_all_workers(fn)[source]

Run a function on all workers.

Parameters

fn (Callable) – A function that takes in no arguments.

Returns

A list of objects returned by fn on each worker.

apply_all_operators(fn)[source]

Run a function on all operators on the workers.

Parameters

fn (Callable[TrainingOperator]) – A function that takes in a TrainingOperator.

Returns

A list of objects returned by fn on each operator.
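
A hedged sketch of both helpers, assuming the trainer from the example above; report_hostname and current_lr are illustrative functions:

import socket


def report_hostname():
    # apply_all_workers: the function takes no arguments and runs on each worker.
    return socket.gethostname()


def current_lr(operator):
    # apply_all_operators: the function receives each worker's TrainingOperator.
    return operator.optimizer.param_groups[0]["lr"]


hostnames = trainer.apply_all_workers(report_hostname)
learning_rates = trainer.apply_all_operators(current_lr)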

validate(num_steps=None, profile=False, reduce_results=True, info=None)[source]

Evaluates the model on the validation data set.

Parameters
  • num_steps (int) – Number of batches to evaluate on. This also corresponds to the number of times TrainingOperator.validate_batch is called.

  • profile (bool) – If True, returns time stats for the evaluation procedure.

  • reduce_results (bool) – Whether to average all metrics across all workers into one dict. If a metric is a non-numerical value (or a nested dictionary), one value will be randomly selected from among the workers. If False, returns a list of dicts.

  • info (dict) – Optional dictionary passed to the training operator for validate and validate_batch.

Returns

A dictionary of metrics for validation.

You can provide custom metrics by passing in a custom training_operator_cls.

update_scheduler(metric)[source]

Calls scheduler.step(metric) on all schedulers.

This is useful for lr_schedulers such as ReduceLROnPlateau.
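
A hedged sketch for schedulers driven by validation metrics, assuming the trainer was constructed with a ReduceLROnPlateau scheduler_creator and scheduler_step_freq=None so that RaySGD does not step the scheduler itself; "val_loss" is the default metric key described under TrainingOperator.validate below:

for epoch in range(10):
    trainer.train()
    val_stats = trainer.validate()
    # Pass the validation loss so ReduceLROnPlateau can decide whether to
    # lower the learning rate.
    trainer.update_scheduler(val_stats["val_loss"])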

get_model()[source]

Returns the learned model(s).

get_local_operator()[source]

Returns the local TrainingOperator object.

Be careful not to perturb its state, or else you can cause the system to enter an inconsistent state.

Returns

The local TrainingOperator object.

Return type

TrainingOperator

save(checkpoint)[source]

Saves the Trainer state to the provided checkpoint path.

Parameters

checkpoint (str) – Path to target checkpoint file.

load(checkpoint)[source]

Loads the Trainer and all workers from the provided checkpoint.

Parameters

checkpoint (str) – Path to target checkpoint file.
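
A short save/load sketch; the checkpoint path is illustrative:

# Write the trainer state to disk.
trainer.save("linear_trainer_checkpoint.pt")

# Later, restore a trainer built with the same creator functions.
trainer.load("linear_trainer_checkpoint.pt")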

shutdown(force=False)[source]

Shuts down workers and releases resources.

classmethod as_trainable(*args, **kwargs)[source]

Creates a BaseTorchTrainable class compatible with Tune.

Any configuration parameters will be overridden by the Tune Trial configuration. You can also subclass the provided Trainable to implement your own iterative optimization routine.

from ray import tune
import torch.nn as nn

# ResNet18, cifar_creator, and optimizer_creator are user-defined creator functions.
TorchTrainable = TorchTrainer.as_trainable(
    model_creator=ResNet18,
    data_creator=cifar_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=nn.CrossEntropyLoss,
    num_gpus=2
)
analysis = tune.run(
    TorchTrainable,
    config={"lr": tune.grid_search([0.01, 0.1])}
)

PyTorch TrainingOperator

class ray.util.sgd.torch.TrainingOperator(config, models, optimizers, train_loader, validation_loader, world_rank, criterion=None, schedulers=None, device_ids=None, use_gpu=False, use_fp16=False, use_tqdm=False)[source]

Abstract class for custom training or validation loops.

The scheduler will only be called at a batch or epoch frequency, depending on the user parameter. Be sure to set scheduler_step_freq in TorchTrainer to either "batch" or "epoch" to increment the scheduler correctly during training. If using a learning rate scheduler that depends on validation loss, you can use trainer.update_scheduler.

For both training and validation, there are two granularities at which you can provide customization: per epoch or per batch. You do not need to override both.

Raises

ValueError – If multiple models/optimizers/schedulers are provided. You are expected to subclass this class if you wish to train over multiple models/optimizers/schedulers.

setup(config)[source]

Override this method to implement custom operator setup.

Parameters

config (dict) – Custom configuration value to be passed to all creator and operator constructors. Same as self.config.
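
A minimal sketch of overriding setup to stash extra per-worker state; the grad_clip key is illustrative:

from ray.util.sgd.torch import TrainingOperator


class MyOperator(TrainingOperator):
    def setup(self, config):
        # Called on each worker when the operator is constructed.
        self.grad_clip = config.get("grad_clip", 1.0)

Pass the subclass to TorchTrainer via training_operator_cls=MyOperator.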

train_epoch(iterator, info)[source]

Runs one standard training pass over the training dataloader.

By default, this method will iterate over the given iterator and call self.train_batch over each batch. If scheduler_step_freq is set, this default method will also step the scheduler accordingly.

You do not need to call train_batch in this method if you plan to implement a custom optimization/training routine here.

You may find ray.util.sgd.utils.AverageMeterCollection useful when overriding this method. See example below:

from ray.util.sgd.utils import AverageMeterCollection

def train_epoch(self, iterator, info):
    meter_collection = AverageMeterCollection()
    self.model.train()
    for batch in iterator:
        # do some processing
        metrics = {"metric_1": 1, "metric_2": 3} # dict of metrics

        # This keeps track of all metrics across multiple batches
        meter_collection.update(metrics, n=len(batch))

    # Returns stats of the meters.
    stats = meter_collection.summary()
    return stats
Parameters
  • iterator (iter) – Iterator over the training data for the entire epoch. This iterator is expected to be entirely consumed.

  • info (dict) – Dictionary for information to be used for custom training operations.

Returns

A dict of metrics from training.

train_batch(batch, batch_info)[source]

Computes loss and updates the model over one batch.

This method is responsible for computing the loss and gradient and updating the model.

By default, this method implementation assumes that batches are in (features, labels) format. If using amp/fp16 training, it will also scale the loss automatically.

You can provide custom loss metrics and training operations if you override this method. If overriding this method, you can access model, optimizer, criterion via self.model, self.optimizer, and self.criterion.

You do not need to override this method if you plan to override train_epoch.

Parameters
  • batch – One item of the training iterator.

  • batch_info (dict) – Information dict passed in from train_epoch.

Returns

A dictionary of metrics.

By default, this dictionary contains "loss" and "num_samples". "num_samples" corresponds to the number of datapoints in the batch. However, you can provide any number of other values. Consider returning "num_samples" in the metrics because, by default, train_epoch uses "num_samples" to calculate averages.
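
A hedged sketch of overriding train_batch to add gradient clipping, assuming the default (features, labels) batch format, a single model/optimizer, and fp16 disabled; the grad_clip config key is illustrative:

import torch
from ray.util.sgd.torch import TrainingOperator


class ClippingOperator(TrainingOperator):
    def train_batch(self, batch, batch_info):
        features, labels = batch
        features = features.to(self.device)
        labels = labels.to(self.device)

        self.optimizer.zero_grad()
        output = self.model(features)
        loss = self.criterion(output, labels)
        loss.backward()
        # Clip gradients before the optimizer update.
        max_norm = self.config.get("grad_clip", 1.0)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
        self.optimizer.step()

        return {"loss": loss.item(), "num_samples": labels.size(0)}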

validate(val_iterator, info)[source]

Runs one standard validation pass over the val_iterator.

This will call model.eval() and wrap the iteration in torch.no_grad() when iterating over the validation dataloader.

If overriding this method, you can access model, criterion via self.model and self.criterion. You also do not need to call validate_batch if overriding this method.

Parameters
  • val_iterator (iter) – Iterable constructed from the validation dataloader.

  • info (dict) – Dictionary for information to be used for custom validation operations.

Returns

A dict of metrics from the evaluation.

By default, returns "val_accuracy" and "val_loss", which are computed by aggregating the "loss" and "correct" values from validate_batch and dividing by the sum of "num_samples" from all calls to self.validate_batch.

validate_batch(batch, batch_info)[source]

Calculates the loss and accuracy over a given batch.

You can override this method to provide arbitrary metrics.

Parameters
  • batch – One item of the validation iterator.

  • batch_info (dict) – Contains information per batch from validate().

Returns

A dict of metrics.

By default, returns "val_loss", "val_accuracy", and "num_samples". When overriding, consider returning "num_samples" in the metrics because, by default, validate uses "num_samples" to calculate averages.
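
A hedged sketch of overriding validate_batch to report mean absolute error, assuming the (features, labels) batch format; note that validate() already wraps iteration in model.eval() and torch.no_grad():

import torch
from ray.util.sgd.torch import TrainingOperator


class MAEOperator(TrainingOperator):
    def validate_batch(self, batch, batch_info):
        features, labels = batch
        features = features.to(self.device)
        labels = labels.to(self.device)

        output = self.model(features)
        loss = self.criterion(output, labels)
        mae = torch.mean(torch.abs(output - labels))

        # "num_samples" lets validate() compute correctly weighted averages.
        return {"val_loss": loss.item(),
                "val_mae": mae.item(),
                "num_samples": labels.size(0)}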

state_dict()[source]

Override this to return a representation of the operator state.

Returns

The state dict of the operator.

Return type

dict

load_state_dict(state_dict)[source]

Override this to load the representation of the operator state.

Parameters

state_dict (dict) – State dict as returned by the operator.
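
A minimal sketch of persisting custom operator state through trainer.save and trainer.load; the epoch counter is illustrative:

from ray.util.sgd.torch import TrainingOperator


class StatefulOperator(TrainingOperator):
    def setup(self, config):
        self.epochs_seen = 0

    def state_dict(self):
        # Included in the checkpoint written by trainer.save().
        return {"epochs_seen": self.epochs_seen}

    def load_state_dict(self, state_dict):
        self.epochs_seen = state_dict["epochs_seen"]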

property device

The appropriate torch device, at your convenience.

Type

torch.device

property config

Provided into TorchTrainer.

Type

dict

property model

First or only model created by the provided model_creator.

property models

List of models created by the provided model_creator.

property optimizer

First or only optimizer created by the optimizer_creator.

property optimizers

List of optimizers created by the optimizer_creator.

property train_loader

First DataLoader from data_creator.

Type

Iterable

property validation_loader

Second DataLoader from data_creator.

Type

Iterable

property world_rank

The rank of the parent runner. Always 0 if not distributed.

Type

int

property criterion

Criterion created by the provided loss_creator.

property scheduler

First or only scheduler created by the scheduler_creator.

property schedulers

List of schedulers created by the scheduler_creator.

property use_gpu

Returns True if CUDA is available and use_gpu is True.

property use_fp16

Whether FP16 has been enabled for the model and optimizer.

Type

bool

property use_tqdm

Whether tqdm progress bars are enabled.

Type

bool

property device_ids

Device IDs for the model.

This is useful for using batch norm with DistributedDataParallel.

Type

List[int]

BaseTorchTrainable

class ray.util.sgd.torch.BaseTorchTrainable(config=None, logger_creator=None)[source]

Base class for converting TorchTrainer to a Trainable class.

This class is produced when you call TorchTrainer.as_trainable(...).

You can override the produced Trainable to implement custom iterative training procedures:

from ray import tune
import torch.nn as nn

# As above, ResNet18, cifar_creator, and optimizer_creator are user-defined
# creator functions.
TorchTrainable = TorchTrainer.as_trainable(
    model_creator=ResNet18,
    data_creator=cifar_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=nn.CrossEntropyLoss,
    num_gpus=2
)
# TorchTrainable is subclass of BaseTorchTrainable.

class CustomTrainable(TorchTrainable):
    def _train(self):
        for i in range(5):
            train_stats = self.trainer.train()
        validation_stats = self.trainer.validate()
        train_stats.update(validation_stats)
        return train_stats

analysis = tune.run(
    CustomTrainable,
    config={"lr": tune.grid_search([0.01, 0.1])}
)
_setup(config)[source]

Constructs a TorchTrainer object as self.trainer.

_train()[source]

Calls self.trainer.train() and self.trainer.validate() once.

You may want to override this if using a custom LR scheduler.

_save(checkpoint_dir)[source]

Returns a path containing the trainer state.

_restore(checkpoint_path)[source]

Restores the trainer state.

Override this if you have state external to the Trainer object.

_stop()[source]

Shuts down the trainer.

property trainer

An instantiated TorchTrainer object.

Use this when specifying custom training procedures for Tune.

TFTrainer

class ray.util.sgd.tf.TFTrainer(model_creator, data_creator, config=None, num_replicas=1, use_gpu=False, verbose=False)[source]

__init__(model_creator, data_creator, config=None, num_replicas=1, use_gpu=False, verbose=False)[source]

Sets up the TensorFlow trainer.

Parameters
  • model_creator (dict -> Model) – This function takes in the config dict and returns a compiled TF model.

  • data_creator (dict -> tf.Dataset, tf.Dataset) – Creates the training and validation data sets using the config. The config dict is passed into the function.

  • config (dict) – Configuration passed to model_creator and data_creator. Also contains fit_config, which is passed into model.fit(data, **fit_config), and evaluate_config, which is passed into model.evaluate. See the constructor sketch after this parameter list.

  • num_replicas (int) – Sets number of workers used in distributed training. Workers will be placed arbitrarily across the cluster.

  • use_gpu (bool) – Enables all workers to use GPU.

  • verbose (bool) – Prints the output of one model if True.
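
A hedged constructor sketch (referenced above); the Keras model, synthetic dataset, and fit_config values are illustrative:

import tensorflow as tf
from ray.util.sgd.tf import TFTrainer


def model_creator(config):
    # Return a compiled tf.keras model.
    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(optimizer="sgd", loss="mse")
    return model


def data_creator(config):
    # Return (train_dataset, validation_dataset).
    x = tf.random.uniform((256, 1))
    y = 2 * x + 1
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
    return dataset, dataset


trainer = TFTrainer(
    model_creator=model_creator,
    data_creator=data_creator,
    num_replicas=2,
    config={"fit_config": {"epochs": 1}})

print(trainer.train())
print(trainer.validate())
trainer.shutdown()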

train()[source]

Runs a training epoch.

validate()[source]

Evaluates the model on the validation data set.

get_model()[source]

Returns the learned model.

save(checkpoint)[source]

Saves the model at the provided checkpoint.

Parameters

checkpoint (str) – Path to target checkpoint file.

restore(checkpoint)[source]

Restores the model from the provided checkpoint.

Parameters

checkpoint (str) – Path to target checkpoint file.

shutdown()[source]

Shuts down workers and releases resources.