Source code for ray.rllib.models.tf.tf_modelv2

import contextlib
import gymnasium as gym
import re
from typing import Dict, List, Union

from ray.util import log_once
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()


@OldAPIStack
class TFModelV2(ModelV2):
    """TensorFlow flavor of ModelV2, intended to hold a tf keras Model.

    This class is not a usable model on its own: a subclass has to provide
    an implementation of forward().
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Sets up a TFModelV2 instance.

        A subclass ``MyModelClass(TFModelV2)`` could be implemented along
        these lines::

            def __init__(self, *args, **kwargs):
                super(MyModelClass, self).__init__(*args, **kwargs)
                input_layer = tf.keras.layers.Input(...)
                hidden_layer = tf.keras.layers.Dense(...)(input_layer)
                output_layer = tf.keras.layers.Dense(...)(hidden_layer)
                value_layer = tf.keras.layers.Dense(...)(hidden_layer)
                self.base_model = tf.keras.Model(
                    input_layer, [output_layer, value_layer])

        Args:
            obs_space: Observation space, forwarded to ModelV2.
            action_space: Action space, forwarded to ModelV2.
            num_outputs: Number of outputs, forwarded to ModelV2.
            model_config: Model config dict, forwarded to ModelV2.
            name: Model name, forwarded to ModelV2.
        """
        super().__init__(
            obs_space, action_space, num_outputs, model_config, name, framework="tf"
        )

        # Deprecated: variables are now discovered automatically (see
        # `variables()`); this list is only filled via `register_variables()`.
        self.var_list = []

        # No graph is stored when running eagerly; otherwise remember the
        # default graph that is active at construction time.
        self.graph = None if tf1.executing_eagerly() else tf1.get_default_graph()

    def context(self) -> contextlib.AbstractContextManager:
        """Returns a context manager for this model's TF graph.

        Falls back to the ModelV2 base implementation when no graph is stored.
        """
        if not self.graph:
            return ModelV2.context(self)
        return self.graph.as_default()

    def update_ops(self) -> List[TensorType]:
        """Returns the update ops of this model (e.g. BatchNorm update ops).

        This base implementation has none and returns an empty list.
        """
        return []

    def register_variables(self, variables: List[TensorType]) -> None:
        """Adds the given variables to this model's (deprecated) var list.

        A deprecation warning is issued the first time `log_once` allows it.

        Args:
            variables: The variables to append to `self.var_list`.
        """
        warn_key = "deprecated_tfmodelv2_register_variables"
        if log_once(warn_key):
            deprecation_warning(old="TFModelV2.register_variables", error=False)
        self.var_list.extend(variables)

    @override(ModelV2)
    def variables(
        self, as_dict: bool = False
    ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Returns this model's variables as a list or as a name->var dict.

        Explicitly registered variables (legacy `register_variables` path)
        take precedence; otherwise the variables are collected by walking the
        instance attributes.
        """
        registered = self.var_list

        # Legacy path: something was registered via `register_variables`.
        if registered:
            if as_dict:
                return {var.name: var for var in registered}
            return list(registered)

        # Automatic path: derive the variable tree from the instance's attributes.
        if as_dict:
            return self._find_sub_modules("", self.__dict__)
        return list(self.variables(as_dict=True).values())

    @override(ModelV2)
    def trainable_variables(
        self, as_dict: bool = False
    ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Returns only those entries of `variables()` flagged as trainable."""
        if not as_dict:
            return [var for var in self.variables() if var.trainable]
        named = self.variables(as_dict=True)
        return {key: var for key, var in named.items() if var.trainable}

    @staticmethod
    def _find_sub_modules(current_key, struct):
        """Recursively collects variables from `struct` into a flat dict.

        Args:
            current_key: Key prefix accumulated so far for `struct`.
            struct: Keras Model / tf.Module, TFModelV2, tf.Variable, or a
                tuple/list/dict nesting any of these. Anything else
                contributes nothing.

        Returns:
            Dict mapping the generated keys to the variables found.
        """
        # Keras Model or tf.Module: key is prefix + "." + var name, with any
        # '/' in the var name replaced by '.'.
        if isinstance(struct, (tf.keras.models.Model, tf.Module)):
            return {
                current_key + "." + re.sub("/", ".", var.name): var
                for var in struct.variables
            }

        # Nested TFModelV2: merge its variables under our prefix.
        if isinstance(struct, TFModelV2):
            nested = struct.variables(as_dict=True)
            return {current_key + "." + key: var for key, var in nested.items()}

        # A single tf.Variable maps directly to the current key.
        if isinstance(struct, tf.Variable):
            return {current_key: struct}

        # Tuple/list: recurse into each item, suffixing the key with "_<index>".
        if isinstance(struct, (tuple, list)):
            found = {}
            for idx, item in enumerate(struct):
                found.update(
                    TFModelV2._find_sub_modules(current_key + f"_{idx}", item)
                )
            return found

        # Dict: recurse into each value; an "_" separator is inserted only
        # when the prefix is non-empty.
        if isinstance(struct, dict):
            prefix = current_key + "_" if current_key else current_key
            found = {}
            for key, item in struct.items():
                found.update(TFModelV2._find_sub_modules(prefix + str(key), item))
            return found

        # Any other type holds no variables.
        return {}