ray.rllib.utils.torch_utils.flatten_inputs_to_1d_tensor

ray.rllib.utils.torch_utils.flatten_inputs_to_1d_tensor(inputs: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | dict | tuple, spaces_struct: gymnasium.spaces.Space | dict | tuple | None = None, time_axis: bool = False) -> numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor

Flattens arbitrary input structs according to the given spaces struct.

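Returns a single tensor that concatenates the flattened values of all input components, i.e. one 1D vector per batch item (or per batch item and timestep if a time axis is present).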

Thereby:

  • Boxes (any shape) get flattened to (B, [T]?, -1). Note that image boxes are not treated differently from other types of Boxes and get flattened as well.

  • Discrete (int) values are one-hot’d, e.g. a batch of [1, 0, 3] (B=3 with Discrete(4) space) results in [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].

  • MultiDiscrete values are multi-one-hot’d, e.g. a batch of [[0, 2], [1, 4]] (B=2 with MultiDiscrete([2, 5]) space) results in [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 1]].
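The one-hot rules above can be reproduced outside of RLlib. The following is a minimal NumPy sketch, not RLlib's implementation; the helper names one_hot and multi_one_hot are hypothetical and exist only for this illustration.

```python
import numpy as np


def one_hot(values, n):
    """One-hot encode an int batch of shape (B,) into shape (B, n)."""
    # Row i of the identity matrix is the one-hot vector for value i.
    return np.eye(n, dtype=np.float32)[np.asarray(values)]


def multi_one_hot(values, nvec):
    """One-hot each column of a (B, len(nvec)) int batch, then concatenate."""
    values = np.asarray(values)
    parts = [one_hot(values[:, i], n) for i, n in enumerate(nvec)]
    return np.concatenate(parts, axis=-1)


# Discrete(4) example from the text: batch [1, 0, 3].
discrete_out = one_hot([1, 0, 3], 4)
print(discrete_out)

# MultiDiscrete([2, 5]) example from the text: batch [[0, 2], [1, 4]].
multi_out = multi_one_hot([[0, 2], [1, 4]], [2, 5])
print(multi_out)
```

Each MultiDiscrete sub-space contributes its own one-hot block, so the output width is the sum of the sub-space sizes (2 + 5 = 7 here).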

Parameters:
  • inputs – The inputs to be flattened.

  • spaces_struct – The structure of the spaces that underlie the inputs.

  • time_axis – Whether all inputs have a time axis (after the batch axis). If True, both the batch axis (0th) and the time axis (1st) are kept as-is, and everything from the 2nd axis onward is flattened.

Returns:

A single 1D tensor resulting from concatenating all flattened/one-hot’d input components. Depending on the time_axis flag, the shape is (B, n) or (B, T, n).
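To make the effect of time_axis on the output shape concrete, here is a small NumPy sketch of the Box flattening rule. It is an illustration under stated assumptions rather than the library code: flatten_box is a hypothetical helper, and the batch dimensions (B=2, T=3, per-step shape (4, 5)) are made up for the example.

```python
import numpy as np


def flatten_box(values, time_axis=False):
    """Flatten a Box batch, keeping axis 0 (batch) and optionally axis 1 (time)."""
    values = np.asarray(values, dtype=np.float32)
    keep = 2 if time_axis else 1  # number of leading axes left untouched
    return values.reshape(values.shape[:keep] + (-1,))


# Hypothetical batch: B=2, T=3, with a (4, 5) Box value per timestep.
x = np.zeros((2, 3, 4, 5))

flat_bt = flatten_box(x, time_axis=True)
flat_b = flatten_box(x, time_axis=False)

print(flat_bt.shape)  # (2, 3, 20)
print(flat_b.shape)   # (2, 60)
```

With time_axis=False, the time dimension is folded into the feature dimension, which is why the flag must match how the inputs were actually batched.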

import gymnasium as gym
import numpy as np
import torch
import tree  # provided by the dm-tree package

from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor

# A batch (B=2) of nested inputs.
struct = {
    "a": np.array([1, 3]),
    "b": (
        np.array([[1.0, 2.0], [4.0, 5.0]]),
        np.array([[[8.0], [7.0]], [[5.0], [4.0]]]),
    ),
    "c": {
        "cb": np.array([1.0, 2.0]),
    },
}
# Convert every NumPy leaf to a torch tensor, keeping the nesting.
struct_torch = tree.map_structure(lambda s: torch.from_numpy(s), struct)
# Spaces struct mirroring the nesting of `struct`.
spaces = {
    "a": gym.spaces.Discrete(4),
    "b": (
        gym.spaces.Box(-1.0, 10.0, (2,)),
        gym.spaces.Box(-1.0, 1.0, (2, 1)),
    ),
    "c": {
        "cb": gym.spaces.Box(-1.0, 1.0, ()),
    },
}
print(flatten_inputs_to_1d_tensor(struct_torch, spaces_struct=spaces))
tensor([[0., 1., 0., 0., 1., 2., 8., 7., 1.],
        [0., 0., 0., 1., 4., 5., 5., 4., 2.]])