Source code for ray.rllib.optimizers.async_gradients_optimizer

import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free


class AsyncGradientsOptimizer(PolicyOptimizer):
    """An asynchronous RL optimizer, e.g. for implementing A3C.

    This optimizer asynchronously pulls and applies gradients from remote
    workers, sending updated weights back as needed. This pipelines the
    gradient computations on the remote workers.
    """

    def __init__(self, workers, grads_per_step=100):
        """Initialize an async gradients optimizer.

        Arguments:
            grads_per_step (int): The number of gradients to collect and apply
                per call to step(). This number should be sufficiently high to
                amortize the overhead of calling step().
        """
        PolicyOptimizer.__init__(self, workers)

        self.apply_timer = TimerStat()
        self.wait_timer = TimerStat()
        self.dispatch_timer = TimerStat()
        self.grads_per_step = grads_per_step
        self.learner_stats = {}
        if not self.workers.remote_workers():
            raise ValueError(
                "Async optimizer requires at least 1 remote worker")
    @override(PolicyOptimizer)
    def step(self):
        weights = ray.put(self.workers.local_worker().get_weights())
        pending_gradients = {}
        num_gradients = 0

        # Kick off the first wave of async tasks
        for e in self.workers.remote_workers():
            e.set_weights.remote(weights)
            future = e.compute_gradients.remote(e.sample.remote())
            pending_gradients[future] = e
            num_gradients += 1

        # Apply gradients as they arrive; re-dispatch work to the worker that
        # just finished until grads_per_step gradients have been collected.
        while pending_gradients:
            with self.wait_timer:
                wait_results = ray.wait(
                    list(pending_gradients.keys()), num_returns=1)
                ready_list = wait_results[0]
                future = ready_list[0]

                gradient, info = ray_get_and_free(future)
                e = pending_gradients.pop(future)
                self.learner_stats = get_learner_stats(info)

            if gradient is not None:
                with self.apply_timer:
                    self.workers.local_worker().apply_gradients(gradient)
                self.num_steps_sampled += info["batch_count"]
                self.num_steps_trained += info["batch_count"]

            if num_gradients < self.grads_per_step:
                with self.dispatch_timer:
                    e.set_weights.remote(
                        self.workers.local_worker().get_weights())
                    future = e.compute_gradients.remote(e.sample.remote())
                    pending_gradients[future] = e
                    num_gradients += 1
    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "wait_time_ms": round(1000 * self.wait_timer.mean, 3),
                "apply_time_ms": round(1000 * self.apply_timer.mean, 3),
                "dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
                "learner": self.learner_stats,
            })
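
A minimal usage sketch, assuming an RLlib trainer has already built a WorkerSet with at least one remote worker; make_worker_set below is a hypothetical stand-in for that setup, and the step count is illustrative:

    workers = make_worker_set()  # hypothetical helper returning a WorkerSet
    optimizer = AsyncGradientsOptimizer(workers, grads_per_step=100)
    for _ in range(10):
        optimizer.step()  # collects and applies up to grads_per_step gradients
    print(optimizer.stats())  # wait/apply/dispatch timings plus learner stats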