RaySGD TensorFlow

RaySGD’s TFTrainer simplifies distributed model training for TensorFlow. The TFTrainer is a wrapper around MultiWorkerMirroredStrategy with a Python API that makes it easy to incorporate distributed training into a larger Python application, as opposed to writing custom logic for setting environment variables and starting separate processes.

Important

This API has only been tested with TensorFlow 2.0rc and is still highly experimental. Please file bug reports if you run into any issues - thanks!

Tip

We need your feedback! RaySGD is currently early in its development, and we’re hoping to get feedback from people using or considering it. We’d love to get in touch!


With Ray:

Wrap your training with this:

import ray
from ray.experimental.sgd.tf.tf_trainer import TFTrainer

ray.init(address=args.address)

trainer = TFTrainer(
    model_creator=model_creator,
    data_creator=data_creator,
    num_replicas=4,
    use_gpu=True,
    verbose=True,
    config={
        "fit_config": {
            "steps_per_epoch": num_train_steps,
        },
        "evaluate_config": {
            "steps": num_eval_steps,
        }
    })
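model_creator and data_creator are plain Python functions that each worker calls to build its compiled Keras model and its tf.data input pipelines; the full definitions used here appear in the example below. After construction, the trainer is driven from the same script. A minimal sketch of the driving loop (num_epochs is a placeholder; get_model() and shutdown() fetch the trained model and stop the workers):

for epoch in range(num_epochs):
    stats = trainer.train()        # runs one epoch across all replicas
    print(stats)

print(trainer.validate())          # aggregated evaluation metrics
model = trainer.get_model()        # local copy of the trained Keras model
trainer.shutdown()                 # release the remote workers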

Then, start a Ray cluster via the autoscaler or manually, and run your training script against it:

ray up CLUSTER.yaml
python train.py --address="localhost:<PORT>"

Before, with TensorFlow:

In your training program, insert the following, and customize for each worker:

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})

...
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
  multi_worker_model = model_creator()

Then, on each machine, you have to launch a separate process that is given the worker’s index and information about all other nodes in the cluster.
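To make "customize for each worker" concrete, here is a sketch of the second worker’s setup; the cluster spec is identical on every machine and only the task index changes (the host:port pairs are hypothetical):

import json
import os

# Same cluster spec on every machine; only the task index differs.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["host1:12345", "host2:23456"]   # hypothetical host:port pairs
    },
    'task': {'type': 'worker', 'index': 1}         # 0 on the first machine, 1 on the second
})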

TFTrainer Example

Below is an example of using Ray’s TFTrainer. Under the hood, TFTrainer creates replicas of your model (controlled by num_replicas), each of which is managed by a worker.

import argparse
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np

import ray
from ray import tune
from ray.experimental.sgd.tf.tf_trainer import TFTrainer, TFTrainable

NUM_TRAIN_SAMPLES = 1000
NUM_TEST_SAMPLES = 400


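# The "fit_config" and "evaluate_config" entries are forwarded as keyword
# arguments to model.fit() and model.evaluate() on each worker, so the step
# counts here bound how much of the (repeated) datasets each call consumes.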
def create_config(batch_size):
    return {
        # TODO: batch size needs to scale with the number of workers
        "batch_size": batch_size,
        "fit_config": {
            "steps_per_epoch": NUM_TRAIN_SAMPLES // batch_size
        },
        "evaluate_config": {
            "steps": NUM_TEST_SAMPLES // batch_size,
        }
    }


def linear_dataset(a=2, size=1000):
    """Builds a simple y = a * x linear regression dataset."""
    x = np.random.rand(size)
    y = a * x

    x = x.reshape((-1, 1))
    y = y.reshape((-1, 1))

    return x, y


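# data_creator: called on each worker to build its input pipelines. It must
# return a (train_dataset, test_dataset) pair of tf.data.Dataset objects;
# repeat() + batch() let steps_per_epoch / steps above control epoch length.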
def simple_dataset(config):
    batch_size = config["batch_size"]
    x_train, y_train = linear_dataset(size=NUM_TRAIN_SAMPLES)
    x_test, y_test = linear_dataset(size=NUM_TEST_SAMPLES)

    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(
        batch_size)
    test_dataset = test_dataset.repeat().batch(batch_size)

    return train_dataset, test_dataset


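# model_creator: called on each worker inside the MultiWorkerMirroredStrategy
# scope; it must return a compiled Keras model.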
def simple_model(config):
    model = Sequential([Dense(10, input_shape=(1, )), Dense(1)])

    model.compile(
        optimizer="sgd",
        loss="mean_squared_error",
        metrics=["mean_squared_error"])

    return model


def train_example(num_replicas=1, batch_size=128, use_gpu=False):
    trainer = TFTrainer(
        model_creator=simple_model,
        data_creator=simple_dataset,
        num_replicas=num_replicas,
        use_gpu=use_gpu,
        verbose=True,
        config=create_config(batch_size))

    # model baseline performance
    start_stats = trainer.validate()
    print(start_stats)

    # train for 2 epochs
    trainer.train()
    trainer.train()

    # model performance after training (should improve)
    end_stats = trainer.validate()
    print(end_stats)

    # sanity check that training worked
    dloss = end_stats["validation_loss"] - start_stats["validation_loss"]
    dmse = (end_stats["validation_mean_squared_error"] -
            start_stats["validation_mean_squared_error"])
    print(f"dLoss: {dloss}, dMSE: {dmse}")

    if dloss > 0 or dmse > 0:
        print("training sanity check failed. loss increased!")
    else:
        print("success!")


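# TFTrainable is the Tune Trainable counterpart of TFTrainer: each trial
# builds a TFTrainer from this config and reports its stats to Tune.
# tune.function() wraps the creator callables so they pass through the Tune
# config unchanged.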
def tune_example(num_replicas=1, use_gpu=False):
    config = {
        "model_creator": tune.function(simple_model),
        "data_creator": tune.function(simple_dataset),
        "num_replicas": num_replicas,
        "use_gpu": use_gpu,
        "trainer_config": create_config(batch_size=128)
    }

    analysis = tune.run(
        TFTrainable,
        num_samples=2,
        config=config,
        stop={"training_iteration": 2},
        verbose=1)

    return analysis.get_best_config(metric="validation_loss", mode="min")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    parser.add_argument(
        "--num-replicas",
        "-n",
        type=int,
        default=1,
        help="Sets number of replicas for training.")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Enables GPU training")
    parser.add_argument(
        "--tune", action="store_true", default=False, help="Tune training")

    args, _ = parser.parse_known_args()

    ray.init(address=args.address)

    if args.tune:
        tune_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
    else:
        train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
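To try the example, run the script with --num-replicas set to the number of workers you want (and --use-gpu if GPUs are available). Leaving --address unset starts Ray locally, while passing the address of a running cluster (for example one started with ray up) attaches to it; adding --tune runs the same creators through Tune via TFTrainable instead of calling TFTrainer directly.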