# Source code for ray.tune.suggest.sigopt

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os
import logging
import pickle
try:
    import sigopt as sgo
except ImportError:
    sgo = None

from ray.tune.suggest.suggestion import SuggestionAlgorithm

logger = logging.getLogger(__name__)


class SigOptSearch(SuggestionAlgorithm):
    """A wrapper around SigOpt to provide trial suggestions.

    Requires SigOpt to be installed. Requires user to store their SigOpt
    API key locally as an environment variable at `SIGOPT_KEY`.

    Parameters:
        space (list of dict): SigOpt configuration. Parameters will be sampled
            from this configuration and will be used to override
            parameters generated in the variant generation process.
        name (str): Name of experiment. Required by SigOpt.
        max_concurrent (int): Number of maximum concurrent trials supported
            based on the user's SigOpt plan. Defaults to 1.
        metric (str): The training result objective value attribute.
        mode (str): One of {min, max}. Determines whether objective is
            minimizing or maximizing the metric attribute.

    Example:
        >>> space = [
        >>>     {
        >>>         'name': 'width',
        >>>         'type': 'int',
        >>>         'bounds': {
        >>>             'min': 0,
        >>>             'max': 20
        >>>         },
        >>>     },
        >>>     {
        >>>         'name': 'height',
        >>>         'type': 'int',
        >>>         'bounds': {
        >>>             'min': -100,
        >>>             'max': 100
        >>>         },
        >>>     },
        >>> ]
        >>> algo = SigOptSearch(
        >>>     space, name="SigOpt Example Experiment",
        >>>     max_concurrent=1, metric="mean_loss", mode="min")
    """

    def __init__(self,
                 space,
                 name="Default Tune Experiment",
                 max_concurrent=1,
                 reward_attr=None,
                 metric="episode_reward_mean",
                 mode="max",
                 **kwargs):
        assert sgo is not None, "SigOpt must be installed!"
        # Use isinstance for the type check instead of `type(x) is int`.
        assert isinstance(max_concurrent, int) and max_concurrent > 0
        assert "SIGOPT_KEY" in os.environ, \
            "SigOpt API key must be stored as environ variable at SIGOPT_KEY"
        assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"

        if reward_attr is not None:
            # Legacy argument: translate to the metric/mode pair and warn.
            mode = "max"
            metric = reward_attr
            logger.warning(
                "`reward_attr` is deprecated and will be removed in a future "
                "version of Tune. "
                "Setting `metric={}` and `mode=max`.".format(reward_attr))

        if "use_early_stopped_trials" in kwargs:
            logger.warning(
                "`use_early_stopped_trials` is not used in SigOptSearch.")

        self._max_concurrent = max_concurrent
        self._metric = metric
        # SigOpt always maximizes; flip the sign of the reported metric
        # when the user asked for minimization.
        if mode == "max":
            self._metric_op = 1.
        elif mode == "min":
            self._metric_op = -1.
        # Maps a Tune trial_id to its in-flight SigOpt Suggestion object.
        self._live_trial_mapping = {}

        # Create a connection with SigOpt API, requires API key
        self.conn = sgo.Connection(client_token=os.environ["SIGOPT_KEY"])

        self.experiment = self.conn.experiments().create(
            name=name,
            parameters=space,
            parallel_bandwidth=self._max_concurrent,
        )

        super(SigOptSearch, self).__init__(**kwargs)

    def _suggest(self, trial_id):
        """Return new parameter assignments for `trial_id`.

        Returns None when the number of outstanding suggestions has
        reached `max_concurrent`.
        """
        if self._num_live_trials() >= self._max_concurrent:
            return None

        # Get new suggestion from SigOpt
        suggestion = self.conn.experiments(
            self.experiment.id).suggestions().create()

        self._live_trial_mapping[trial_id] = suggestion

        # Deep-copy so callers cannot mutate the stored suggestion's
        # assignments out from under us.
        return copy.deepcopy(suggestion.assignments)

    def on_trial_result(self, trial_id, result):
        # Intermediate results are not reported to SigOpt; only the
        # final value in on_trial_complete matters.
        pass

    def on_trial_complete(self,
                          trial_id,
                          result=None,
                          error=False,
                          early_terminated=False):
        """Notification for the completion of trial.

        If a trial fails, it will be reported as a failed Observation, telling
        the optimizer that the Suggestion led to a metric failure, which
        updates the feasible region and improves parameter recommendation.

        Creates SigOpt Observation object for trial.
        """
        if result:
            self.conn.experiments(self.experiment.id).observations().create(
                suggestion=self._live_trial_mapping[trial_id].id,
                value=self._metric_op * result[self._metric],
            )
            # Update the experiment object
            self.experiment = self.conn.experiments(self.experiment.id).fetch()
        elif error or early_terminated:
            # Reports a failed Observation
            self.conn.experiments(self.experiment.id).observations().create(
                failed=True, suggestion=self._live_trial_mapping[trial_id].id)
        del self._live_trial_mapping[trial_id]

    def _num_live_trials(self):
        """Number of trials with an outstanding (unreported) suggestion."""
        return len(self._live_trial_mapping)

    def save(self, checkpoint_dir):
        """Pickle the SigOpt connection and experiment to a file.

        NOTE(review): despite the name, `checkpoint_dir` is used as a
        file path here -- confirm against callers.
        """
        trials_object = (self.conn, self.experiment)
        with open(checkpoint_dir, "wb") as out_file:
            pickle.dump(trials_object, out_file)

    def restore(self, checkpoint_dir):
        """Restore state previously written by `save`.

        NOTE(review): `pickle.load` on an untrusted checkpoint can execute
        arbitrary code -- only restore checkpoints you wrote yourself.
        """
        with open(checkpoint_dir, "rb") as in_file:
            self.conn, self.experiment = pickle.load(in_file)