RLlib Training APIs

Getting Started

At a high level, RLlib provides an Agent class which holds a policy for environment interaction. Through the agent interface, the policy can be trained, checkpointed, or used to compute an action.


You can train a simple DQN agent with the following command:

python ray/python/ray/rllib/train.py --run DQN --env CartPole-v0

By default, the results will be logged to a subdirectory of ~/ray_results. This subdirectory will contain a file params.json, which contains the hyperparameters; a file result.json, which contains a training summary for each training iteration; and a TensorBoard file that can be used to visualize the training process with TensorBoard by running:

tensorboard --logdir=~/ray_results
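Since result.json is plain text, it can also be inspected without TensorBoard. The sketch below assumes the file holds one JSON object per line, with keys such as training_iteration and episode_reward_mean; the helper name load_results is hypothetical, and the demo rows it writes are synthetic placeholders, not real training output.

```python
import json
import os
import tempfile


def load_results(path):
    """Read a result.json file, assuming one JSON object per line."""
    results = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines, e.g. a trailing newline
                results.append(json.loads(line))
    return results


# Synthetic, illustrative rows only (not real training output).
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "result.json")
with open(demo_path, "w") as f:
    for i, reward in enumerate([12.0, 35.5, 80.25], start=1):
        f.write(json.dumps({"training_iteration": i,
                            "episode_reward_mean": reward}) + "\n")

rows = load_results(demo_path)
best = max(rows, key=lambda r: r["episode_reward_mean"])
print(best["training_iteration"], best["episode_reward_mean"])  # 3 80.25
```

To inspect a real run, point load_results() at the result.json in your own subdirectory of ~/ray_results, expanding ~ with os.path.expanduser first.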

The train.py script has a number of options, which you can list by running:

python ray/python/ray/rllib/train.py --help

The most important options are --env, for choosing the environment (any OpenAI Gym environment can be used, including ones registered by the user), and --run, for choosing the algorithm (available options are PPO, PG, A2C, A3C, IMPALA, ES, DDPG, DQN, APEX, and APEX_DDPG).

Evaluating Trained Agents

In order to save checkpoints from which to evaluate agents, set --checkpoint-freq (number of training iterations between checkpoints) when running train.py.

An example of evaluating a previously trained DQN agent is as follows:

python ray/python/ray/rllib/rollout.py \
      ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \
      --run DQN --env CartPole-v0 --steps 10000

The rollout.py helper script reconstructs a DQN agent from the checkpoint located at ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 and renders its behavior in the environment specified by --env.
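When a trial has written several checkpoints, you may want to pass the most recent one to rollout.py. Below is a minimal sketch of locating it. latest_checkpoint is a hypothetical helper, and it assumes the checkpoint_<N>/checkpoint-<N> layout seen in the example path above, comparing N numerically so that checkpoint_10 is treated as newer than checkpoint_9.

```python
import os
import re
import tempfile


def latest_checkpoint(trial_dir):
    """Return the newest checkpoint-<N> file in trial_dir, or None.

    Assumes the checkpoint_<N>/checkpoint-<N> layout from the example path.
    """
    best_n, best_path = -1, None
    for name in os.listdir(trial_dir):
        match = re.fullmatch(r"checkpoint_(\d+)", name)
        if match is None:
            continue
        n = int(match.group(1))
        candidate = os.path.join(trial_dir, name, "checkpoint-{}".format(n))
        # Compare numerically so checkpoint_10 beats checkpoint_9.
        if n > best_n and os.path.isfile(candidate):
            best_n, best_path = n, candidate
    return best_path


# Build a throwaway directory that mimics a trial dir with three checkpoints.
trial_dir = tempfile.mkdtemp()
for n in (1, 2, 10):
    ckpt_dir = os.path.join(trial_dir, "checkpoint_{}".format(n))
    os.makedirs(ckpt_dir)
    open(os.path.join(ckpt_dir, "checkpoint-{}".format(n)), "w").close()

ckpt = latest_checkpoint(trial_dir)
print("python ray/python/ray/rllib/rollout.py {} "
      "--run DQN --env CartPole-v0 --steps 10000".format(ckpt))
```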


Specifying Parameters

Each algorithm has specific hyperparameters that can be set with --config, in addition to a number of common hyperparameters. See the algorithms documentation for more information.
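Values passed with --config are overlaid on the agent's defaults, such as the common hyperparameters listed below. The following is a minimal sketch of such an overlay, assuming nested dicts like tf_session_args are merged key by key and that unknown top-level keys are rejected; merge_config and the DEFAULTS_EXCERPT subset are hypothetical and are not RLlib's own code.

```python
import copy

# A small excerpt of the common defaults (not the full set).
DEFAULTS_EXCERPT = {
    "num_workers": 2,
    "gamma": 0.99,
    "tf_session_args": {
        "intra_op_parallelism_threads": 2,
        "inter_op_parallelism_threads": 2,
    },
}


def merge_config(defaults, overrides, top_level=True):
    """Return a copy of defaults with overrides applied recursively."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        # Catch typos such as "num_worker" instead of silently ignoring them.
        if top_level and key not in merged:
            raise KeyError("Unknown config parameter: {}".format(key))
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value, top_level=False)
        else:
            merged[key] = value
    return merged


config = merge_config(DEFAULTS_EXCERPT, {
    "num_workers": 8,
    "tf_session_args": {"intra_op_parallelism_threads": 4},
})
print(config["num_workers"], config["tf_session_args"])
```

Merging nested dicts rather than replacing them means an override only needs to mention the keys it changes; the defaults dict itself is left untouched because the helper works on a deep copy.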

In the example below, we train A2C with 8 workers, specified through the --config flag.

python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
    --run=A2C --config '{"num_workers": 8}'
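If you launch many runs from a script, it can help to generate the --config argument from a Python dict, so the JSON is always valid (for example, Python True becomes JSON true, as in the "monitor": true setting used elsewhere on this page). This is a small sketch using only the standard library; train_command is a hypothetical helper.

```python
import json
import shlex


def train_command(env, run, config, script="ray/python/ray/rllib/train.py"):
    """Assemble a train.py command line with config JSON-encoded for --config."""
    args = ["python", script, "--env=" + env, "--run=" + run,
            "--config", json.dumps(config)]
    # shlex.quote wraps the JSON in single quotes so the shell keeps it intact.
    return " ".join(shlex.quote(arg) for arg in args)


cmd = train_command("PongDeterministic-v4", "A2C", {"num_workers": 8})
print(cmd)
# python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 --run=A2C --config '{"num_workers": 8}'
```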

Specifying Resources

You can control the degree of parallelism used by setting the num_workers hyperparameter for most agents. The number of GPUs the driver should use can be set via the num_gpus option. Similarly, the resource allocation to workers can be controlled via num_cpus_per_worker, num_gpus_per_worker, and custom_resources_per_worker. GPU counts can be fractional, in which case only part of a GPU is allocated. For example, with DQN you can pack five agents onto one GPU by setting num_gpus: 0.2. Note that in Ray < 0.6.0, fractional GPU support requires setting the environment variable RAY_USE_XRAY=1.
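As a rough way to reason about these settings, the sketch below totals what one agent would request, assuming the driver's share (num_cpus_for_driver, num_gpus) plus num_workers times the per-worker share. The helper names are hypothetical, the defaults are copied from the common config listed below, and custom_resources_per_worker is ignored.

```python
import math

# Defaults copied from the common config (num_workers=2, one CPU per worker, ...).
COMMON_DEFAULTS = {
    "num_workers": 2,
    "num_gpus": 0,
    "num_cpus_per_worker": 1,
    "num_gpus_per_worker": 0,
    "num_cpus_for_driver": 1,
}


def agent_resources(overrides):
    """Return (cpus, gpus) one agent asks for, under the additive assumption."""
    cfg = dict(COMMON_DEFAULTS, **overrides)
    workers = cfg["num_workers"]
    cpus = cfg["num_cpus_for_driver"] + workers * cfg["num_cpus_per_worker"]
    gpus = cfg["num_gpus"] + workers * cfg["num_gpus_per_worker"]
    return cpus, gpus


def agents_that_fit(capacity, per_agent):
    """How many agents fit into `capacity` units of a single resource."""
    if per_agent == 0:
        return math.inf
    # Round first so float noise such as 4.999999999 does not drop an agent.
    return math.floor(round(capacity / per_agent, 9))


print(agent_resources({"num_workers": 1}))  # (2, 0)
print(agents_that_fit(1, agent_resources({"num_gpus": 0.2})[1]))  # 5
```

Under this assumption, an agent with num_workers set to 1 needs 2 CPUs, so a 4-CPU machine can run two such agents at once, which is consistent with the "Resources requested: 4/4 CPUs" line in the Tune status output later on this page.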

Common Parameters

The following is a list of the common agent hyperparameters:

    # === Debugging ===
    # Whether to write episode stats and videos to the agent log dir
    "monitor": False,
    # Set the ray.rllib.* log level for the agent process and its evaluators
    "log_level": "INFO",
    # Callbacks that will be run during various phases of training. These all
    # take a single "info" dict as an argument. For episode callbacks, custom
    # metrics can be attached to the episode by updating the episode object's
    # custom metrics dict (see examples/custom_metrics_and_callbacks.py).
    "callbacks": {
        "on_episode_start": None,  # arg: {"env": .., "episode": ...}
        "on_episode_step": None,   # arg: {"env": .., "episode": ...}
        "on_episode_end": None,    # arg: {"env": .., "episode": ...}
        "on_sample_end": None,     # arg: {"samples": .., "evaluator": ...}
    },

    # === Policy ===
    # Arguments to pass to model. See models/catalog.py for a full list of the
    # available model options.
    "model": MODEL_DEFAULTS,
    # Arguments to pass to the policy optimizer. These vary by optimizer.
    "optimizer": {},

    # === Environment ===
    # Discount factor of the MDP
    "gamma": 0.99,
    # Number of steps after which the episode is forced to terminate
    "horizon": None,
    # Arguments to pass to the env creator
    "env_config": {},
    # Environment name can also be passed via config
    "env": None,
    # Whether to clip rewards prior to experience postprocessing. Setting to
    # None means clip for Atari only.
    "clip_rewards": None,
    # Whether to use rllib or deepmind preprocessors by default
    "preprocessor_pref": "deepmind",

    # === Resources ===
    # Number of actors used for parallelism
    "num_workers": 2,
    # Number of GPUs to allocate to the driver. Note that not all algorithms
    # can take advantage of driver GPUs. This can be fractional (e.g., 0.3 GPUs).
    "num_gpus": 0,
    # Number of CPUs to allocate per worker.
    "num_cpus_per_worker": 1,
    # Number of GPUs to allocate per worker. This can be fractional.
    "num_gpus_per_worker": 0,
    # Any custom resources to allocate per worker.
    "custom_resources_per_worker": {},
    # Number of CPUs to allocate for the driver. Note: this only takes effect
    # when running in Tune.
    "num_cpus_for_driver": 1,

    # === Execution ===
    # Number of environments to evaluate vectorwise per worker.
    "num_envs_per_worker": 1,
    # Default sample batch size
    "sample_batch_size": 200,
    # Training batch size, if applicable. Should be >= sample_batch_size.
    # Sample batches will be concatenated together to this size for training.
    "train_batch_size": 200,
    # Whether to rollout "complete_episodes" or "truncate_episodes"
    "batch_mode": "truncate_episodes",
    # Whether to use a background thread for sampling (slightly off-policy)
    "sample_async": False,
    # Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
    "observation_filter": "NoFilter",
    # Whether to synchronize the statistics of remote filters.
    "synchronize_filters": True,
    # Configure TF for single-process operation by default
    "tf_session_args": {
        # note: overridden by `local_evaluator_tf_session_args`
        "intra_op_parallelism_threads": 2,
        "inter_op_parallelism_threads": 2,
        "gpu_options": {
            "allow_growth": True,
        },
        "log_device_placement": False,
        "device_count": {
            "CPU": 1
        },
        "allow_soft_placement": True,  # required by PPO multi-gpu
    },
    # Override the following tf session args on the local evaluator
    "local_evaluator_tf_session_args": {
        # Allow a higher level of parallelism by default, but not unlimited
        # since that can cause crashes with many concurrent drivers.
        "intra_op_parallelism_threads": 8,
        "inter_op_parallelism_threads": 8,
    },
    # Whether to LZ4 compress observations
    "compress_observations": False,
    # Drop metric batches from unresponsive workers after this many seconds
    "collect_metrics_timeout": 180,

    # === Multiagent ===
    "multiagent": {
        # Map from policy ids to tuples of (policy_graph_cls, obs_space,
        # act_space, config). See policy_evaluator.py for more info.
        "policy_graphs": {},
        # Function mapping agent ids to policy ids.
        "policy_mapping_fn": None,
        # Optional whitelist of policies to train, or None for all policies.
        "policies_to_train": None,
    },

Tuned Examples

Some good hyperparameters and settings are available in the repository (some of them are tuned to run on GPUs). If you find better settings or tune an algorithm on a different domain, consider submitting a Pull Request!

You can run these with the train.py script as follows:

python ray/python/ray/rllib/train.py -f /path/to/tuned/example.yaml

Python API

The Python API provides the needed flexibility for applying RLlib to new problems. You will need to use this API if you wish to use custom environments, preprocessors, or models with RLlib.

Here is an example of the basic usage:

import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print

ray.init()
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["num_workers"] = 1
agent = ppo.PPOAgent(config=config, env="CartPole-v0")

# Can optionally call agent.restore(path) to load a checkpoint.

for i in range(1000):
    # Perform one iteration of training the policy with PPO
    result = agent.train()
    print(pretty_print(result))

    if i % 100 == 0:
        checkpoint = agent.save()
        print("checkpoint saved at", checkpoint)


It’s recommended that you run RLlib agents with Tune, for easy experiment management and visualization of results. Just set "run": AGENT_NAME, "env": ENV_NAME in the experiment config.

All RLlib agents are compatible with the Tune API, which makes them easy to use in Tune experiments. For example, the following code performs a simple hyperparameter sweep of PPO:

import ray
import ray.tune as tune

ray.init()
tune.run_experiments({
    "my_experiment": {
        "run": "PPO",
        "env": "CartPole-v0",
        "stop": {"episode_reward_mean": 200},
        "config": {
            "num_gpus": 0,
            "num_workers": 1,
            "sgd_stepsize": tune.grid_search([0.01, 0.001, 0.0001]),
        },
    },
})

Tune will schedule the trials to run in parallel on your Ray cluster:

== Status ==
Using FIFO scheduling algorithm.
Resources requested: 4/4 CPUs, 0/0 GPUs
Result logdir: ~/ray_results/my_experiment
PENDING trials:
 - PPO_CartPole-v0_2_sgd_stepsize=0.0001:   PENDING
RUNNING trials:
 - PPO_CartPole-v0_0_sgd_stepsize=0.01:     RUNNING [pid=21940], 16 s, 4013 ts, 22 rew
 - PPO_CartPole-v0_1_sgd_stepsize=0.001:    RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew
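The three trials above come from the three sgd_stepsize values passed to tune.grid_search. As a simplified sketch of that expansion (not Tune's implementation; expand_grid and trial_name are hypothetical helpers that only mimic the naming shown in the status output), each combination of grid values produces one trial config.

```python
import itertools


def expand_grid(base_config, grid):
    """Yield (config, params) for every combination of the grid values."""
    keys = sorted(grid)
    for values in itertools.product(*(grid[key] for key in keys)):
        params = dict(zip(keys, values))
        config = dict(base_config, **params)
        yield config, params


def trial_name(run, env, index, params):
    """Format a name like PPO_CartPole-v0_0_sgd_stepsize=0.01."""
    tag = ",".join("{}={}".format(k, v) for k, v in sorted(params.items()))
    return "{}_{}_{}_{}".format(run, env, index, tag)


base = {"num_gpus": 0, "num_workers": 1}
grid = {"sgd_stepsize": [0.01, 0.001, 0.0001]}
trials = list(expand_grid(base, grid))
for i, (config, params) in enumerate(trials):
    print(trial_name("PPO", "CartPole-v0", i, params))
```

Because itertools.product takes the Cartesian product, grid-searching two keys with three values each would yield nine trials.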

Accessing Policy State

It is common to need to access an agent’s internal state, e.g., to set or get internal weights. In RLlib, an agent’s state is replicated across multiple policy evaluators (Ray actors) in the cluster. However, you can easily get and update this state between calls to train() via agent.optimizer.foreach_evaluator() or agent.optimizer.foreach_evaluator_with_index(). These functions take a lambda that is applied to each evaluator (foreach_evaluator_with_index() also passes the evaluator's index). Any values the lambda returns are collected and returned as a list.

You can also access just the “master” copy of the agent state through agent.local_evaluator, but note that updates here may not be immediately reflected in remote replicas if you have configured num_workers > 0. For example, to access the weights of a local TF policy, you can run agent.local_evaluator.policy_map["default"].get_weights(). This is also equivalent to agent.local_evaluator.for_policy(lambda p: p.get_weights()):

# Get weights of the local policy
agent.local_evaluator.policy_map["default"].get_weights()

# Same as above
agent.local_evaluator.for_policy(lambda p: p.get_weights())

# Get list of weights of each evaluator, including remote replicas
agent.optimizer.foreach_evaluator(
    lambda ev: ev.for_policy(lambda p: p.get_weights()))

# Same as above
agent.optimizer.foreach_evaluator_with_index(
    lambda ev, i: ev.for_policy(lambda p: p.get_weights()))
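To make the calling convention concrete without a cluster, here is a toy, in-process sketch. ToyPolicy, ToyEvaluator, and ToyOptimizer are hypothetical stand-ins (real remote evaluators are Ray actors), and the ordering used here, local evaluator first with index 0 followed by the remote replicas, is an assumption.

```python
class ToyPolicy:
    """Hypothetical policy holding a list of weights."""

    def __init__(self, weights):
        self.weights = list(weights)

    def get_weights(self):
        return list(self.weights)

    def set_weights(self, weights):
        self.weights = list(weights)


class ToyEvaluator:
    """Hypothetical evaluator that owns a single policy."""

    def __init__(self, weights):
        self.policy = ToyPolicy(weights)

    def for_policy(self, func):
        return func(self.policy)


class ToyOptimizer:
    """Hypothetical optimizer: local evaluator first, then the replicas."""

    def __init__(self, local_evaluator, remote_evaluators):
        self.local_evaluator = local_evaluator
        self.remote_evaluators = remote_evaluators

    def foreach_evaluator(self, func):
        return [func(self.local_evaluator)] + [
            func(ev) for ev in self.remote_evaluators]

    def foreach_evaluator_with_index(self, func):
        return [func(self.local_evaluator, 0)] + [
            func(ev, i + 1) for i, ev in enumerate(self.remote_evaluators)]


optimizer = ToyOptimizer(ToyEvaluator([1.0, 2.0]),
                         [ToyEvaluator([0.0, 0.0]), ToyEvaluator([0.0, 0.0])])

# Broadcast the local weights to every evaluator, then read them back.
local_weights = optimizer.local_evaluator.for_policy(lambda p: p.get_weights())
optimizer.foreach_evaluator(
    lambda ev: ev.for_policy(lambda p: p.set_weights(local_weights)))
all_weights = optimizer.foreach_evaluator(
    lambda ev: ev.for_policy(lambda p: p.get_weights()))
print(all_weights)  # [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
print(optimizer.foreach_evaluator_with_index(lambda ev, i: i))  # [0, 1, 2]
```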

Global Coordination

Sometimes, it is necessary to coordinate between pieces of code that live in different processes managed by RLlib. For example, it can be useful to maintain a global average of a certain variable, or centrally control a hyperparameter used by policies. Ray provides a general way to achieve this through named actors (see the Ray documentation on actors). As an example, consider maintaining a shared global counter that is incremented by environments and read periodically from your driver program:

import ray
from ray.experimental import named_actors

@ray.remote
class Counter:
    def __init__(self):
        self.count = 0
    def inc(self, n):
        self.count += n
    def get(self):
        return self.count

# on the driver
counter = Counter.remote()
named_actors.register_actor("global_counter", counter)
print(ray.get(counter.get.remote()))  # get the latest count

# in your envs
counter = named_actors.get_actor("global_counter")
counter.inc.remote(1)  # async call to increment the global count

Ray actors provide high levels of performance, so in more complex cases they can be used to implement communication patterns such as parameter servers and allreduce.


Gym Monitor

The "monitor": true config can be used to save Gym episode videos to the result dir. For example:

python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
    --run=A2C --config '{"num_workers": 2, "monitor": true}'

# videos will be saved in the ~/ray_results/<experiment> dir
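To locate the recorded videos afterwards, a small glob-based sketch can search a results directory recursively. It assumes only that the videos are .mp4 files somewhere under that directory; find_videos is a hypothetical helper, and the demo uses a throwaway directory instead of ~/ray_results.

```python
import glob
import os
import tempfile


def find_videos(result_dir):
    """Return all .mp4 files below result_dir, sorted by path."""
    result_dir = os.path.expanduser(result_dir)  # allow paths like ~/ray_results
    pattern = os.path.join(result_dir, "**", "*.mp4")
    return sorted(glob.glob(pattern, recursive=True))


# Throwaway directory standing in for a results dir with one video.
demo = tempfile.mkdtemp()
trial = os.path.join(demo, "my_experiment", "trial_0")
os.makedirs(trial)
for name in ("episode_a.mp4", "result.json"):
    open(os.path.join(trial, name), "w").close()

print(find_videos(demo))
```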

Log Verbosity

You can control the agent log level via the "log_level" flag. Valid values are “INFO” (default), “DEBUG”, “WARN”, and “ERROR”. This can be used to increase or decrease the verbosity of internal logging. For example:

python ray/python/ray/rllib/train.py --env=PongDeterministic-v4 \
    --run=A2C --config '{"num_workers": 2, "log_level": "DEBUG"}'

Callbacks and Custom Metrics

You can provide callback functions to be called at points during policy evaluation. These functions have access to an info dict containing state for the current episode. Custom state can be stored for the episode in the info["episode"].user_data dict, and custom scalar metrics can be reported by saving values to the info["episode"].custom_metrics dict. These custom metrics will be averaged and reported as part of training results. The following example (full code in examples/custom_metrics_and_callbacks.py) logs a custom metric from the environment:

import numpy as np
import ray
from ray import tune

def on_episode_start(info):
    print(info.keys())  # -> "env", "episode"
    episode = info["episode"]
    print("episode {} started".format(episode.episode_id))
    episode.user_data["pole_angles"] = []

def on_episode_step(info):
    episode = info["episode"]
    pole_angle = abs(episode.last_observation_for()[2])
    episode.user_data["pole_angles"].append(pole_angle)

def on_episode_end(info):
    episode = info["episode"]
    mean_pole_angle = np.mean(episode.user_data["pole_angles"])
    print("episode {} ended with length {} and pole angles {}".format(
        episode.episode_id, episode.length, mean_pole_angle))
    episode.custom_metrics["mean_pole_angle"] = mean_pole_angle

ray.init()
trials = tune.run_experiments({
    "test": {
        "env": "CartPole-v0",
        "run": "PG",
        "config": {
            "callbacks": {
                "on_episode_start": tune.function(on_episode_start),
                "on_episode_step": tune.function(on_episode_step),
                "on_episode_end": tune.function(on_episode_end),
            },
        },
    },
})
Custom metrics can be accessed and visualized like any other training result.
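The callback logic can be exercised without a simulator. The sketch below drives mirror versions of the three callbacks (using statistics.mean instead of np.mean, and without the prints) with a hypothetical FakeEpisode that provides only the fields they touch, then averages the custom metric across episodes. The plain mean is an assumption (RLlib's own aggregation and the key names it reports may differ), and the observations are made-up values.

```python
import statistics


class FakeEpisode:
    """Hypothetical stand-in exposing only the fields the callbacks use."""

    def __init__(self, episode_id, observations):
        self.episode_id = episode_id
        self.user_data = {}
        self.custom_metrics = {}
        self.length = 0
        self._observations = observations

    def last_observation_for(self):
        return self._observations[self.length - 1]


def on_episode_start(info):
    info["episode"].user_data["pole_angles"] = []


def on_episode_step(info):
    episode = info["episode"]
    pole_angle = abs(episode.last_observation_for()[2])
    episode.user_data["pole_angles"].append(pole_angle)


def on_episode_end(info):
    episode = info["episode"]
    episode.custom_metrics["mean_pole_angle"] = statistics.mean(
        episode.user_data["pole_angles"])


def run_episode(episode):
    """Call the callbacks in order: start, one step per observation, end."""
    on_episode_start({"env": None, "episode": episode})
    for _ in episode._observations:
        episode.length += 1
        on_episode_step({"env": None, "episode": episode})
    on_episode_end({"env": None, "episode": episode})
    return episode


episodes = [
    run_episode(FakeEpisode(1, [[0.0, 0.0, 0.1, 0.0], [0.0, 0.0, -0.3, 0.0]])),
    run_episode(FakeEpisode(2, [[0.0, 0.0, 0.5, 0.0]])),
]
summary = statistics.mean(ep.custom_metrics["mean_pole_angle"] for ep in episodes)
print("mean_pole_angle averaged over episodes: {:.2f}".format(summary))  # 0.35
```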



In some cases (e.g., when interacting with an externally hosted simulator or production environment), it makes more sense to interact with RLlib as if it were an independently running service, rather than having RLlib host the simulations itself. This is possible via RLlib’s external agents interface.

class ray.rllib.utils.policy_client.PolicyClient(address)

REST client to interact with an RLlib policy server.

start_episode(episode_id=None, training_enabled=True)

Record the start of an episode.

Parameters:
  • episode_id (str) – Unique string id for the episode or None for it to be auto-assigned.
  • training_enabled (bool) – Whether to use experiences for this episode to improve the policy.

Returns:

Unique string id for the episode.

Return type:

episode_id (str)

get_action(episode_id, observation)

Record an observation and get the on-policy action.

Parameters:
  • episode_id (str) – Episode id returned from start_episode().
  • observation (obj) – Current environment observation.

Returns:

Action from the env action space.

Return type:

action (obj)

log_action(episode_id, observation, action)

Record an observation and (off-policy) action taken.

Parameters:
  • episode_id (str) – Episode id returned from start_episode().
  • observation (obj) – Current environment observation.
  • action (obj) – Action for the observation.

log_returns(episode_id, reward, info=None)

Record returns from the environment.

The reward will be attributed to the previous action taken by the episode. Rewards accumulate until the next action. If no reward is logged before the next action, a reward of 0.0 is assumed.

Parameters:
  • episode_id (str) – Episode id returned from start_episode().
  • reward (float) – Reward from the environment.

end_episode(episode_id, observation)

Record the end of an episode.

Parameters:
  • episode_id (str) – Episode id returned from start_episode().
  • observation (obj) – Current environment observation.

class ray.rllib.utils.policy_server.PolicyServer(external_env, address, port)

REST server that can be launched from an ExternalEnv.

This launches a multi-threaded server that listens on the specified host and port to serve policy requests and forward experiences to RLlib.


>>> class CartpoleServing(ExternalEnv):
       def __init__(self):
           ExternalEnv.__init__(
               self, spaces.Discrete(2),
               spaces.Box(
                   low=-10,
                   high=10,
                   shape=(4,),
                   dtype=np.float32))
       def run(self):
           server = PolicyServer(self, "localhost", 8900)
           server.serve_forever()
>>> register_env("srv", lambda _: CartpoleServing())
>>> pg = PGAgent(env="srv", config={"num_workers": 0})
>>> while True:
        pg.train()

>>> client = PolicyClient("localhost:8900")
>>> eps_id = client.start_episode()
>>> action = client.get_action(eps_id, obs)
>>> ...
>>> client.log_returns(eps_id, reward)
>>> ...
>>> client.log_returns(eps_id, reward)

For a full client / server example that you can run, see the example client script and also the corresponding server script, here configured to serve a policy for the toy CartPole-v0 environment.
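The reward-attribution rule described for log_returns() can also be illustrated offline. The sketch below is a hypothetical, in-memory stand-in (ToyEpisodeLog is not part of RLlib and makes no network calls) that follows the documented call order of start_episode(), log_action(), log_returns(), and end_episode(): rewards are added to the most recent action until the next action is logged, and an action with no logged reward keeps 0.0. Raising an error when a reward arrives before any action is this sketch's own choice.

```python
import uuid


class ToyEpisodeLog:
    """Hypothetical in-memory stand-in for the PolicyClient call sequence."""

    def __init__(self):
        self.episodes = {}

    def start_episode(self, episode_id=None):
        episode_id = episode_id or uuid.uuid4().hex  # auto-assign if None
        self.episodes[episode_id] = {"steps": [], "done": False}
        return episode_id

    def log_action(self, episode_id, observation, action):
        # Each new action starts with a reward of 0.0.
        self.episodes[episode_id]["steps"].append(
            {"obs": observation, "action": action, "reward": 0.0})

    def log_returns(self, episode_id, reward):
        steps = self.episodes[episode_id]["steps"]
        if not steps:
            raise ValueError("no action logged yet for this episode")
        # Rewards accumulate on the previous action until the next one.
        steps[-1]["reward"] += reward

    def end_episode(self, episode_id, observation):
        self.episodes[episode_id]["final_obs"] = observation
        self.episodes[episode_id]["done"] = True

    def rewards(self, episode_id):
        return [step["reward"] for step in self.episodes[episode_id]["steps"]]


log = ToyEpisodeLog()
eps_id = log.start_episode()
log.log_action(eps_id, [0.0, 0.0, 0.0, 0.0], 1)
log.log_returns(eps_id, 1.0)
log.log_returns(eps_id, 0.5)  # still attributed to the first action
log.log_action(eps_id, [0.1, 0.0, 0.0, 0.0], 0)  # no reward logged afterwards
log.end_episode(eps_id, [0.2, 0.0, 0.0, 0.0])
print(log.rewards(eps_id))  # [1.5, 0.0]
```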