Best Practices: Ray with TensorFlow

This document describes best practices for using Ray with TensorFlow. Feel free to contribute if you think this document is missing anything.

Use Actors for Parallel Models

If you are training a deep network in the distributed setting, you may need to ship your deep network between processes (or machines). However, shipping the model is not always straightforward.

Tip

Avoid sending the TensorFlow model directly. A straightforward attempt to pickle a TensorFlow graph gives mixed results. Furthermore, creating a TensorFlow graph can take tens of seconds, so serializing a graph and recreating it in another process is inefficient.

It is recommended to replicate the same TensorFlow graph on each worker once at the beginning and then to ship only the weights between the workers.

Suppose we have a simple network definition (this one is modified from the TensorFlow documentation).

import tensorflow as tf
from tensorflow.keras import layers


def create_keras_model():
    model = tf.keras.Sequential()
    # Adds a densely-connected layer with 64 units to the model:
    model.add(layers.Dense(64, activation="relu", input_shape=(32, )))
    # Add another:
    model.add(layers.Dense(64, activation="relu"))
    # Add a softmax layer with 10 output units:
    model.add(layers.Dense(10, activation="softmax"))

    model.compile(
        optimizer=tf.train.RMSPropOptimizer(0.01),
        loss=tf.keras.losses.categorical_crossentropy,
        metrics=[tf.keras.metrics.categorical_accuracy])
    return model

It is strongly recommended that you create actors to handle this. To do this, first initialize Ray and define an actor class:

import ray
import numpy as np

ray.init()

def random_one_hot_labels(shape):
    n, n_class = shape
    classes = np.random.randint(0, n_class, n)
    labels = np.zeros((n, n_class))
    labels[np.arange(n), classes] = 1
    return labels


# To use GPU, replace the decorator below with:
# @ray.remote(num_gpus=1)
@ray.remote
class Network(object):
    def __init__(self):
        self.model = create_keras_model()
        self.dataset = np.random.random((1000, 32))
        self.labels = random_one_hot_labels((1000, 10))

    def train(self):
        history = self.model.fit(self.dataset, self.labels, verbose=False)
        return history.history

    def get_weights(self):
        return self.model.get_weights()

    def set_weights(self, weights):
        # Note that for simplicity this does not handle the optimizer state.
        self.model.set_weights(weights)

Then, we can instantiate this actor and train it in a separate process:

NetworkActor = Network.remote()
result_object_id = NetworkActor.train.remote()
ray.get(result_object_id)

We can then use set_weights and get_weights to move the weights of the neural network around. This allows us to manipulate weights between different models running in parallel without shipping the actual TensorFlow graphs, which are much more complex Python objects.

NetworkActor2 = Network.remote()
NetworkActor2.train.remote()
weights = ray.get(
    [NetworkActor.get_weights.remote(),
     NetworkActor2.get_weights.remote()])

averaged_weights = [(layer1 + layer2) / 2
                    for layer1, layer2 in zip(weights[0], weights[1])]

weight_id = ray.put(averaged_weights)
[
    actor.set_weights.remote(weight_id)
    for actor in [NetworkActor, NetworkActor2]
]
ray.get([actor.train.remote() for actor in [NetworkActor, NetworkActor2]])

Lower-level TF Utilities

Given a low-level TF definition:

import tensorflow as tf
import numpy as np

x_data = tf.placeholder(tf.float32, shape=[100])
y_data = tf.placeholder(tf.float32, shape=[100])

w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = w * x_data + b

loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
grads = optimizer.compute_gradients(loss)
train = optimizer.apply_gradients(grads)

init = tf.global_variables_initializer()
sess = tf.Session()

To extract and set the weights, you can use the following helper class.

import ray.experimental.tf_utils
variables = ray.experimental.tf_utils.TensorFlowVariables(loss, sess)

The TensorFlowVariables object provides methods for getting and setting the weights as well as collecting all of the variables in the model.

Now we can use these methods to extract the weights, and place them back in the network as follows.

# First initialize the weights, using the same session that was passed
# to TensorFlowVariables above.
sess.run(init)
# Get the weights
weights = variables.get_weights()  # Returns a dictionary of numpy arrays
# Set the weights
variables.set_weights(weights)
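
The collected variables themselves are exposed on the helper as well; here is a small sketch (continuing the example above) of inspecting them:

# The variables attribute is a dict mapping variable names to tf.Variable
# objects, which is useful for checking what the helper captured.
for name, var in variables.variables.items():
    print(name, var.get_shape())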

Note: If we were to set the weights using the assign method like below, each call to assign would add a node to the graph, and the graph would grow unmanageably large over time.

w.assign(np.zeros(1))  # This adds a node to the graph every time you call it.
b.assign(np.zeros(1))  # This adds a node to the graph every time you call it.

class ray.experimental.tf_utils.TensorFlowVariables(output, sess=None, input_variables=None)

A class used to set and get weights for TensorFlow networks.

Attributes:

    sess (tf.Session) – The TensorFlow session used to run assignments.
    variables (Dict[str, tf.Variable]) – Variables extracted from the loss, plus any additional variables that are passed in.
    placeholders (Dict[str, tf.placeholder]) – Placeholders for the weights.
    assignment_nodes (Dict[str, tf.Tensor]) – Nodes that assign the weights.

Methods:

    set_session(sess)
        Sets the current session used by the class.
        Parameters: sess (tf.Session) – Session to set the attribute with.

    get_flat_size()
        Returns the total length of all of the flattened variables.
        Returns: The length of all flattened variables concatenated.

    get_flat()
        Gets the weights and returns them as a flat array.
        Returns: 1D array containing the flattened weights.

    set_flat(new_weights)
        Sets the weights to new_weights, converting from a flat array.
        Note that this can only set all of the weights in the network at once, i.e., the length of the array must match get_flat_size.
        Parameters: new_weights (np.ndarray) – Flat array containing the weights.

    get_weights()
        Returns a dictionary containing the weights of the network.
        Returns: Dictionary mapping variable names to their weights.

    set_weights(new_weights)
        Sets the weights to new_weights.
        Note that subsets of the variables can also be set, by passing in only the variables you want to set.
        Parameters: new_weights (Dict) – Dictionary mapping variable names to their weights.

Note

This may not work with tf.keras.
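
For completeness, here is a minimal sketch of the flat-weight helpers in use, continuing the low-level example above (it assumes the variables object and the initialized session defined there):

# A rough sketch of get_flat / set_flat, assuming `variables` and an
# initialized session from the earlier low-level example.
flat_weights = variables.get_flat()  # 1D numpy array of all weights
assert flat_weights.size == variables.get_flat_size()
# Perturb every weight slightly and write the whole array back in one call.
variables.set_flat(flat_weights + 0.01)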

Troubleshooting

Note that TensorFlowVariables uses variable names to determine which variables to set when calling set_weights. One common issue arises when two networks are defined in the same TensorFlow graph. In this case, TensorFlow appends an underscore and an integer to the variable names to disambiguate them, which causes TensorFlowVariables to fail. For example, if we have a class definition Network with a TensorFlowVariables instance:

import ray
import tensorflow as tf

class Network(object):
    def __init__(self):
        a = tf.Variable(1)
        b = tf.Variable(1)
        c = tf.add(a, b)
        sess = tf.Session()
        init = tf.global_variables_initializer()
        sess.run(init)
        self.variables = ray.experimental.tf_utils.TensorFlowVariables(c, sess)

    def set_weights(self, weights):
        self.variables.set_weights(weights)

    def get_weights(self):
        return self.variables.get_weights()

and run the following code:

a = Network()
b = Network()
b.set_weights(a.get_weights())

the code would fail. If we instead defined each network in its own TensorFlow graph, then it would work:

with tf.Graph().as_default():
    a = Network()
with tf.Graph().as_default():
    b = Network()
b.set_weights(a.get_weights())

This issue does not occur between actors that contain a network, as each actor is in its own process, and thus is in its own graph. This also does not occur when using set_flat.
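
For instance, here is a minimal sketch (not part of the original example) of the set_flat workaround in the same-graph case, reusing the Network class defined above:

# set_flat / get_flat operate on a single concatenated array rather than on
# variable names, so they still work when both networks share one graph.
a = Network()
b = Network()
b.variables.set_flat(a.variables.get_flat())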

Another issue to keep in mind is that TensorFlowVariables needs to add new operations to the graph. If you close the graph and make it immutable, e.g., by creating a MonitoredTrainingSession, the initialization will fail. To resolve this, simply create the TensorFlowVariables instance before you close the graph.
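
As a rough sketch of that ordering (assuming the monitored session can be handed to set_session, since it supports run like a regular session):

import tensorflow as tf
import ray.experimental.tf_utils

x = tf.placeholder(tf.float32, shape=[100])
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
loss = tf.reduce_mean(tf.square(w * x))

# Create the TensorFlowVariables instance while the graph is still mutable,
# so that its assignment ops can be added to the graph ...
variables = ray.experimental.tf_utils.TensorFlowVariables(loss)

# ... and only then create the MonitoredTrainingSession, which finalizes the graph.
sess = tf.train.MonitoredTrainingSession()
variables.set_session(sess)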