TF Distributed Training

Ray’s TFTrainer simplifies distributed model training for TensorFlow. The TFTrainer is a wrapper around MultiWorkerMirroredStrategy with a Python API, making it easy to incorporate distributed training into a larger Python application instead of writing custom logic to set up environment variables and launch separate processes.

Important

This API has only been tested with TensorFlow 2.0rc and is still highly experimental. Please file bug reports if you run into any issues - thanks!


With Ray:

Wrap your training code like this:

import ray
from ray.experimental.sgd.tf.tf_trainer import TFTrainer

ray.init(redis_address=args.redis_address)

trainer = TFTrainer(
    model_creator=model_creator,
    data_creator=data_creator,
    num_replicas=4,
    use_gpu=True,
    verbose=True,
    config={
        "fit_config": {
            "steps_per_epoch": num_train_steps,
        },
        "evaluate_config": {
            "steps": num_eval_steps,
        }
    })

Then, start a Ray cluster via the autoscaler or manually:

ray up CLUSTER.yaml
python train.py --redis-address="localhost:<PORT>"

Before, with TensorFlow:

In your training program, insert the following and customize it for each worker:

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})

...
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
  multi_worker_model = model_creator()

Then, on each machine, launch a separate process that sets its own worker index and the addresses of all other nodes in the cluster; the second worker’s configuration is sketched below.
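For instance, in the two-node setup above, the second worker would be launched with the same cluster spec but a different task index (hostnames are carried over from the example):

import json
import os

# Same cluster spec as worker 0; only the task index differs.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 1}  # this process is the second worker
})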

TFTrainer Example

Below is an example of using Ray’s TFTrainer. Under the hood, TFTrainer creates replicas of your model (controlled by num_replicas), each of which is managed by a Ray worker.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np

import ray
from ray import tune
from ray.experimental.sgd.tf.tf_trainer import TFTrainer, TFTrainable

NUM_TRAIN_SAMPLES = 1000
NUM_TEST_SAMPLES = 400


def create_config(batch_size):
    return {
        "batch_size": batch_size,
        "fit_config": {
            "steps_per_epoch": NUM_TRAIN_SAMPLES // batch_size
        },
        "evaluate_config": {
            "steps": NUM_TEST_SAMPLES // batch_size,
        }
    }


def linear_dataset(a=2, size=1000):
    # Simple linear relationship y = x / a (with the default a=2, y = x / 2).
    x = np.random.rand(size)
    y = x / a

    x = x.reshape((-1, 1))
    y = y.reshape((-1, 1))

    return x, y


def simple_dataset(config):
    batch_size = config["batch_size"]
    x_train, y_train = linear_dataset(size=NUM_TRAIN_SAMPLES)
    x_test, y_test = linear_dataset(size=NUM_TEST_SAMPLES)

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(
        batch_size)
    test_dataset = test_dataset.repeat().batch(batch_size)

    return train_dataset, test_dataset


def simple_model(config):
    model = Sequential([Dense(10, input_shape=(1,)), Dense(1)])

    model.compile(
        optimizer="sgd",
        loss="mean_squared_error",
        metrics=["mean_squared_error"])

    return model


def train_example(num_replicas=1, batch_size=128, use_gpu=False):
    trainer = TFTrainer(
        model_creator=simple_model,
        data_creator=simple_dataset,
        num_replicas=num_replicas,
        use_gpu=use_gpu,
        verbose=True,
        config=create_config(batch_size))

    train_stats1 = trainer.train()
    train_stats1.update(trainer.validate())
    print(train_stats1)

    train_stats2 = trainer.train()
    train_stats2.update(trainer.validate())
    print(train_stats2)

    val_stats = trainer.validate()
    print(val_stats)
    print("success!")


def tune_example(num_replicas=1, use_gpu=False):
    config = {
        "model_creator": tune.function(simple_model),
        "data_creator": tune.function(simple_dataset),
        "num_replicas": num_replicas,
        "use_gpu": use_gpu,
        "trainer_config": create_config(batch_size=128)
    }

    analysis = tune.run(
        TFTrainable,
        num_samples=2,
        config=config,
        stop={"training_iteration": 2},
        verbose=1)

    return analysis.get_best_config(metric="validation_loss", mode="min")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--redis-address",
        required=False,
        type=str,
        help="the address to use for Redis")
    parser.add_argument(
        "--num-replicas",
        "-n",
        type=int,
        default=1,
        help="Sets number of replicas for training.")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Enables GPU training")
    parser.add_argument(
        "--tune", action="store_true", default=False, help="Tune training")

    args, _ = parser.parse_known_args()

    ray.init(redis_address=args.redis_address)

    if args.tune:
        tune_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
    else:
        train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)

Package Reference

class ray.experimental.sgd.tf.TFTrainer(model_creator, data_creator, config=None, num_replicas=1, use_gpu=False, verbose=False)
__init__(model_creator, data_creator, config=None, num_replicas=1, use_gpu=False, verbose=False)

Sets up the TensorFlow trainer.

Parameters:
  • model_creator (dict -> Model) – This function takes in the config dict and returns a compiled TF model.
  • data_creator (dict -> tf.Dataset, tf.Dataset) – This function takes in the config dict and returns the training and validation datasets as a tuple.
  • config (dict) – Configuration passed to model_creator and data_creator. May also contain fit_config, which is passed into model.fit(data, **fit_config), and evaluate_config, which is passed into model.evaluate(data, **evaluate_config); see the sketch after this parameter list.
  • num_replicas (int) – Sets number of workers used in distributed training. Workers will be placed arbitrarily across the cluster.
  • use_gpu (bool) – Enables all workers to use GPU.
  • verbose (bool) – Prints the output of one model if True.
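For reference, a config for the example above might look like the following sketch (the values are illustrative; batch_size is consumed by data_creator, while the nested dicts are forwarded to Keras):

config = {
    "batch_size": 128,
    "fit_config": {
        "steps_per_epoch": 1000 // 128,  # e.g. NUM_TRAIN_SAMPLES // batch_size
    },
    "evaluate_config": {
        "steps": 400 // 128,  # e.g. NUM_TEST_SAMPLES // batch_size
    },
}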
train()

Runs a training epoch.

validate()

Evaluates the model on the validation data set.

get_model()

Returns the learned model.

save(checkpoint)

Saves the model at the provided checkpoint.

Parameters: checkpoint (str) – Path to the target checkpoint file.
restore(checkpoint)

Restores the model from the provided checkpoint.

Parameters: checkpoint (str) – Path to the target checkpoint file.
shutdown()

Shuts down workers and releases resources.
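Putting these methods together, a minimal end-to-end sketch, given a trainer constructed as in the example above (the checkpoint path "my_model_checkpoint" is illustrative, not an API default):

stats = trainer.train()                 # run one training epoch
stats.update(trainer.validate())        # evaluate on the validation set
trainer.save("my_model_checkpoint")     # write weights to a local path

# ... later, restore the weights and pull out the trained Keras model ...
trainer.restore("my_model_checkpoint")
model = trainer.get_model()

trainer.shutdown()                      # release workers and resources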