Package Reference

PyTorchTrainer

class ray.experimental.sgd.pytorch.PyTorchTrainer(model_creator, data_creator, optimizer_creator, loss_creator, scheduler_creator=None, train_function=None, validation_function=None, initialization_hook=None, config=None, dataloader_config=None, num_replicas=1, use_gpu=False, batch_size=16, backend='auto', use_fp16=False, apex_args=None, scheduler_step_freq='batch')[source]

Train a PyTorch model using distributed PyTorch.

Launches a set of actors which connect via distributed PyTorch and coordinate gradient updates to train the provided model.

import torch
import torch.nn as nn

import ray
from ray.experimental.sgd.pytorch import PyTorchTrainer


def model_creator(config):
    return nn.Linear(1, 1)


def optimizer_creator(model, config):
    return torch.optim.SGD(
        model.parameters(), lr=config.get("lr", 1e-4))


def data_creator(config):
    # LinearDataset is a simple torch.utils.data.Dataset from the RaySGD
    # examples; any Dataset works here.
    return LinearDataset(2, 5), LinearDataset(2, 5, size=400)


ray.init()  # start Ray before creating the trainer

trainer = PyTorchTrainer(
    model_creator,
    data_creator,
    optimizer_creator,
    loss_creator=nn.MSELoss,
    use_gpu=True
)
trainer.train()
Parameters:
  • model_creator (dict -> Model(s)) – Constructor function that takes in config and returns the model(s) to be optimized. These must be torch.nn.Module objects. If multiple models are returned, a train_function must be specified. You do not need to handle GPU/devices in this function; RaySGD will do that under the hood.
  • data_creator (dict -> Dataset(s)) – Constructor function that takes in the passed config and returns one or two torch.utils.data.Dataset objects. If two Dataset objects are returned, only the first is used for training; the second is used for validation. RaySGD will automatically wrap the datasets with a DataLoader.
  • optimizer_creator ((models, dict) -> optimizers) – Constructor function that takes in the return values from model_creator and the passed config and returns one or more Torch optimizer objects. You do not need to handle GPU/devices in this function; RaySGD will do that for you.
  • loss_creator (torch.nn.*Loss class | dict -> loss) – A constructor function for the training loss. This can be either a function that takes in the provided config for customization, or a subclass of torch.nn.modules.loss._Loss, which covers most PyTorch loss classes. For example, loss_creator=torch.nn.BCELoss.
  • scheduler_creator ((optimizers, dict) -> scheduler(s)) – A constructor function for the learning rate scheduler(s). This is a function that takes in the optimizers generated by optimizer_creator and the passed config for customization. Be sure to set scheduler_step_freq so the scheduler is incremented correctly.
  • train_function – Custom function for training. This function will be executed in parallel across all workers at once. The function needs to take in (models, train_dataloader, criterion, optimizers, config), and return a dict of training stats; see the sketch following this parameter list.
  • validation_function – Custom function for validation. This function will be executed in parallel across all workers at once. This takes in (model, val_dataloader, criterion, config) and returns a dict of validation stats.
  • config (dict) – Custom configuration values to be passed to “model_creator”, “data_creator”, “optimizer_creator”, “loss_creator”, and “scheduler_creator”.
  • dataloader_config (dict) – Configuration values to be passed into the torch.utils.data.DataLoader object that wraps the dataset on each parallel worker for both training and validation. Note that if num_replicas is greater than 1, shuffle and sampler will be set automatically. See https://pytorch.org/docs/stable/data.html for the available arguments.
  • num_replicas (int) – The number of workers used in distributed training.
  • use_gpu (bool) – Sets resource allocation for workers to 1 GPU if True, and automatically moves both the model and optimizer to the available CUDA device.
  • batch_size (int) – Total batch size for each minibatch. This value is divided among all workers and rounded.
  • backend (string) – Backend used by distributed PyTorch. Currently supports “nccl”, “gloo”, and “auto”. If “auto”, RaySGD will automatically use “nccl” if use_gpu is True, and “gloo” otherwise.
  • use_fp16 (bool) – Enables mixed precision training via apex if apex is installed. This is automatically done after the model and optimizers are constructed and will work for multi-model training. Please see https://github.com/NVIDIA/apex for more details.
  • apex_args (dict|None) – Dict containing keyword args for amp.initialize. See https://nvidia.github.io/apex/amp.html#module-apex.amp. By default, the models and optimizers are passed in. Consider using “num_losses” if operating over multiple models and optimizers.
  • scheduler_step_freq – “batch”, “epoch”, or None. This will determine when scheduler.step is called. If “batch”, step will be called after every optimizer step. If “epoch”, step will be called after one pass of the DataLoader.
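The custom hooks above can be combined roughly as follows. This is a minimal sketch, not the library's reference implementation: the loop bodies, the MultiStepLR milestones, and the single-model unpacking are assumptions layered on the signatures stated in this list, and model_creator, data_creator, and optimizer_creator are the ones from the example above.

import torch
import torch.nn as nn


def scheduler_creator(optimizer, config):
    # Receives the optimizer(s) from optimizer_creator plus the config;
    # the milestones below are illustrative defaults.
    return torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=config.get("milestones", [5, 10]))


def train_function(models, train_dataloader, criterion, optimizers, config):
    # Assumption: a single model/optimizer pair; unpack if lists are passed.
    model = models[0] if isinstance(models, (list, tuple)) else models
    optimizer = optimizers[0] if isinstance(optimizers, (list, tuple)) else optimizers
    model.train()
    total_loss, batches = 0.0, 0
    for features, target in train_dataloader:
        optimizer.zero_grad()
        loss = criterion(model(features), target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        batches += 1
    return {"train_loss": total_loss / batches}


def validation_function(model, val_dataloader, criterion, config):
    # Must return a dict of validation stats.
    model.eval()
    total_loss, batches = 0.0, 0
    with torch.no_grad():
        for features, target in val_dataloader:
            total_loss += criterion(model(features), target).item()
            batches += 1
    return {"val_loss": total_loss / batches}


trainer = PyTorchTrainer(
    model_creator,
    data_creator,
    optimizer_creator,
    loss_creator=nn.MSELoss,
    scheduler_creator=scheduler_creator,
    train_function=train_function,
    validation_function=validation_function,
    scheduler_step_freq="epoch")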
__init__(model_creator, data_creator, optimizer_creator, loss_creator, scheduler_creator=None, train_function=None, validation_function=None, initialization_hook=None, config=None, dataloader_config=None, num_replicas=1, use_gpu=False, batch_size=16, backend='auto', use_fp16=False, apex_args=None, scheduler_step_freq='batch')[source]

Initialize self. See help(type(self)) for accurate signature.

train(max_retries=0, checkpoint='auto')[source]

Runs a training epoch.

The values returned from the workers are averaged. Set max_retries to enable fault handling in case of instance preemption.

Parameters:
  • max_retries (int) – Must be non-negative. If set to N, on a worker failure the trainer will kill all current workers, query the Ray global state for the total available resources, and re-launch as many workers as those resources allow, retrying at most N times. Behavior is not well-defined in case of shared cluster usage.
  • checkpoint (str) – Path to checkpoint to restore from if retrying. If max_retries is set and checkpoint == “auto”, PyTorchTrainer will save a checkpoint before starting to train.
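For example, a preemption-tolerant call might look like the following; the retry count is arbitrary.

stats = trainer.train(max_retries=2, checkpoint="auto")
print(stats)  # averaged stats reported by the workers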
validate()[source]

Evaluates the model on the validation data set.

update_scheduler(metric)[source]

Calls scheduler.step(metric) on all schedulers.

This is useful for lr_schedulers such as ReduceLROnPlateau.
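A short sketch of the intended pattern, assuming scheduler_creator built a ReduceLROnPlateau scheduler and that validate() reports a "val_loss" key (the key name is an assumption carried over from the validation_function sketch above):

val_stats = trainer.validate()
# Step ReduceLROnPlateau (and similar metric-driven schedulers) manually.
trainer.update_scheduler(val_stats["val_loss"])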

get_model()[source]

Returns the learned model(s).

save(checkpoint)[source]

Saves the model(s) to the provided checkpoint.

Parameters: checkpoint (str) – Path to target checkpoint file.
restore(checkpoint)[source]

Restores the model from the provided checkpoint.

Parameters: checkpoint (str) – Path to target checkpoint file.
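A typical save/restore round trip looks like this; the path is illustrative.

checkpoint_path = "/tmp/pytorch_trainer.ckpt"  # illustrative path
trainer.save(checkpoint_path)

# Later, possibly from a freshly constructed trainer with the same creators:
trainer.restore(checkpoint_path)
model = trainer.get_model()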
shutdown(force=False)[source]

Shuts down workers and releases resources.

PyTorchTrainable

class ray.experimental.sgd.pytorch.PyTorchTrainable(config=None, logger_creator=None)[source]
classmethod default_resource_request(config)[source]

Returns the resource requirement for the given configuration.

This can be overridden by subclasses to set the correct trial resource allocation, so the user does not need to.

Example

>>> def default_resource_request(cls, config):
...     return Resources(
...         cpu=0,
...         gpu=0,
...         extra_cpu=config["workers"],
...         extra_gpu=int(config["use_gpu"]) * config["workers"])
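PyTorchTrainable is the Tune-compatible counterpart of PyTorchTrainer. The sketch below is one plausible way to launch it with tune.run; the config keys simply mirror the PyTorchTrainer constructor arguments and are an assumption, not a documented schema, and the creator functions are the ones from the PyTorchTrainer example above.

from ray import tune
from ray.experimental.sgd.pytorch import PyTorchTrainable

# Assumed config schema mirroring PyTorchTrainer's arguments; check the
# trainable's source for the exact keys it expects.
config = {
    "model_creator": model_creator,
    "data_creator": data_creator,
    "optimizer_creator": optimizer_creator,
    "loss_creator": nn.MSELoss,
    "num_replicas": 2,
    "use_gpu": False,
    "batch_size": 64,
    "config": {"lr": tune.grid_search([1e-2, 1e-3])},
}

analysis = tune.run(
    PyTorchTrainable,
    config=config,
    stop={"training_iteration": 2})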

TFTrainer

class ray.experimental.sgd.tf.TFTrainer(model_creator, data_creator, config=None, num_replicas=1, use_gpu=False, verbose=False)[source]
__init__(model_creator, data_creator, config=None, num_replicas=1, use_gpu=False, verbose=False)[source]

Sets up the TensorFlow trainer.

Parameters:
  • model_creator (dict -> Model) – This function takes in the config dict and returns a compiled TF model.
  • data_creator (dict -> tf.Dataset, tf.Dataset) – Creates the training and validation datasets. The config dict is passed into this function.
  • config (dict) – Configuration passed to ‘model_creator’ and ‘data_creator’. Also contains fit_config, which is passed into model.fit(data, **fit_config), and evaluate_config, which is passed into model.evaluate; see the sketch after this list.
  • num_replicas (int) – Sets number of workers used in distributed training. Workers will be placed arbitrarily across the cluster.
  • use_gpu (bool) – Enables all workers to use GPU.
  • verbose (bool) – Prints the output of one model if True.
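A minimal usage sketch, assuming a small Keras regression model and in-memory tf.data pipelines; the dataset contents and the fit_config/evaluate_config values are illustrative.

import numpy as np
import tensorflow as tf

from ray.experimental.sgd.tf import TFTrainer


def model_creator(config):
    # Must return a *compiled* Keras model.
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(optimizer="sgd", loss="mse")
    return model


def data_creator(config):
    # Returns (train_dataset, validation_dataset) as tf.data.Dataset objects.
    x = np.random.rand(256, 1).astype("float32")
    y = 2 * x
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
    return dataset, dataset


trainer = TFTrainer(
    model_creator,
    data_creator,
    config={"fit_config": {"steps_per_epoch": 8},
            "evaluate_config": {"steps": 8}},
    num_replicas=2)
stats = trainer.train()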
train()[source]

Runs a training epoch.

validate()[source]

Evaluates the model on the validation data set.

get_model()[source]

Returns the learned model.

save(checkpoint)[source]

Saves the model at the provided checkpoint.

Parameters: checkpoint (str) – Path to target checkpoint file.
restore(checkpoint)[source]

Restores the model from the provided checkpoint.

Parameters: checkpoint (str) – Path to target checkpoint file.
shutdown()[source]

Shuts down workers and releases resources.