Source code for ray.rllib.optimizers.sync_replay_optimizer

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import collections
import numpy as np

import ray
from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
    PrioritizedReplayBuffer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.schedules import LinearSchedule
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)


class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    When prioritized replay is enabled, this optimizer requires that the
    per-policy info returned by the local worker's learn_on_batch() contains
    an additional "td_error" array. This error term is used for sample
    prioritization.
    """

    # NOTE(review): this class was recovered from a whitespace-collapsed
    # listing; the nesting of the `with` timer blocks below is reconstructed
    # from the statement order -- confirm against the upstream file.

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 schedule_max_timesteps=100000,
                 beta_annealing_fraction=0.2,
                 final_prioritized_replay_beta=0.4,
                 train_batch_size=32,
                 sample_batch_size=4,
                 before_learn_on_batch=None,
                 synchronize_sampling=False):
        """Initialize a sync replay optimizer.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been
                sampled before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
                (initial value of the beta annealing schedule)
            prioritized_replay_eps (float): replay eps hyperparameter
                (added to |td_error| when computing new priorities)
            schedule_max_timesteps (int): number of timesteps in the schedule
            beta_annealing_fraction (float): fraction of schedule to anneal
                beta over
            final_prioritized_replay_beta (float): final value of beta
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers.
                Accepted for interface compatibility; not read by this class.
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on. It is called with the replayed
                batch, the local worker's policy map and train_batch_size,
                and its return value replaces the batch.
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        # linearly annealing beta used in Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(
                schedule_max_timesteps * beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(
                    buffer_size, alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        # One buffer per policy id, created lazily on first access.
        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        """Run one round of sampling and, once warmed up, one optimization.

        Pushes the local worker's weights to the remote workers (if any),
        collects one batch of samples, stores every row in the replay buffer
        of its policy, and calls _optimize() when num_steps_sampled has
        reached replay_starts. The sampled-steps counter is increased by the
        new batch only after that check.
        """
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({
                    DEFAULT_POLICY_ID: batch
                }, batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        # The threshold is compared against the count *before* this round's
        # samples are added (the increment happens below).
        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        """Return the base optimizer stats extended with timing info.

        Adds mean sample/replay/grad/weight-update times in milliseconds,
        throughput figures taken from the grad timer, and the latest
        per-policy learner stats under the "learner" key.
        """
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(
                    1000 * self.update_weights_timer.mean, 3),
                "opt_peak_throughput": round(
                    self.grad_timer.mean_throughput, 3),
                "opt_samples": round(
                    self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        """Replay one train batch and learn on it with the local worker.

        For every policy whose buffer is a PrioritizedReplayBuffer, the
        priorities of the replayed items are updated to
        |td_error| + prioritized_replay_eps, where td_error is read from that
        policy's learn_on_batch() info (a KeyError is raised if it is
        missing).
        """
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (
                        np.abs(td_error) + self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        """Sample train_batch_size items from every policy's replay buffer.

        With synchronize_sampling enabled, the indices drawn from the first
        buffer iterated are reused for all remaining buffers; otherwise each
        buffer draws its own indices. Non-prioritized buffers get "weights"
        of ones and "batch_indexes" of -1, both shaped like the rewards.

        Returns:
            MultiAgentBatch: one SampleBatch per policy id, with the batch
                count set to train_batch_size.
        """
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    # Beta is annealed on the number of steps trained so far.
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)