Source code for smarts.env.utils.action_conversion

# MIT License
#
# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import json
import math
from dataclasses import dataclass, field
from enum import IntEnum
from functools import cached_property, lru_cache
from typing import Any, Callable, Dict

import gymnasium as gym
import numpy as np

from smarts.core.agent_interface import ActionSpaceType, AgentInterface

LINEAR_ACCELERATION_MINIMUM = -1e10
LINEAR_ACCELERATION_MAXIMUM = 1e10
ANGULAR_VELOCITY_MINIMUM = -1e10
ANGULAR_VELOCITY_MAXIMUM = 1e10
SPEED_MINIMUM = -1e10
SPEED_MAXIMUM = 1e10
POSITION_COORDINATE_MINIMUM = -1e10
POSITION_COORDINATE_MAXIMUM = 1e10
DT_MINIMUM = 1e-10
DT_MAXIMUM = 60.0

TRAJECTORY_LENGTH = 20
MPC_ARRAY_COUNT = 4


def _DEFAULT_PASSTHROUGH(action):
    return action


_throttle_break_steering_space = gym.spaces.Box(
    low=np.array([0.0, 0.0, -1.0], dtype=np.float32),
    high=np.array([1.0, 1.0, 1.0], dtype=np.float32),
    dtype=np.float32,
)

_actuator_dynamic_space = _throttle_break_steering_space


_continuous_space = _throttle_break_steering_space


_direct_space = gym.spaces.Box(
    low=np.array(
        [LINEAR_ACCELERATION_MINIMUM, ANGULAR_VELOCITY_MINIMUM], dtype=np.float32
    ),
    high=np.array(
        [LINEAR_ACCELERATION_MAXIMUM, ANGULAR_VELOCITY_MAXIMUM], dtype=np.float32
    ),
    dtype=np.float32,
)


_lane_space = gym.spaces.Discrete(n=4)


def _format_lane_space(action: int):
    _action_to_str = ["keep_lane", "slow_down", "change_lane_left", "change_lane_right"]
    return _action_to_str[action]


_lane_with_continuous_speed_space = gym.spaces.Tuple(
    spaces=[
        gym.spaces.Box(
            low=SPEED_MINIMUM, high=SPEED_MAXIMUM, shape=(), dtype=np.float32
        ),
        gym.spaces.Box(low=-100, high=100, shape=(), dtype=np.int8),
    ]
)

_base_trajectory_space = gym.spaces.Tuple(
    gym.spaces.Box(
        low=np.array([POSITION_COORDINATE_MINIMUM] * TRAJECTORY_LENGTH),
        high=np.array([POSITION_COORDINATE_MAXIMUM] * TRAJECTORY_LENGTH),
        dtype=np.float64,
    )
    for _ in range(MPC_ARRAY_COUNT)
)
_mpc_space = _base_trajectory_space

_base_target_pose_space = gym.spaces.Box(
    low=np.array(
        [POSITION_COORDINATE_MINIMUM, POSITION_COORDINATE_MINIMUM, -math.pi, DT_MINIMUM]
    ),
    high=np.array(
        [POSITION_COORDINATE_MAXIMUM, POSITION_COORDINATE_MAXIMUM, math.pi, DT_MAXIMUM]
    ),
    dtype=np.float64,
)
_multi_target_pose_space = _base_target_pose_space

_target_pose_space = _base_target_pose_space

_relative_target_pose_space = gym.spaces.Box(
    low=np.array([POSITION_COORDINATE_MINIMUM, POSITION_COORDINATE_MINIMUM, -math.pi]),
    high=np.array([POSITION_COORDINATE_MAXIMUM, POSITION_COORDINATE_MAXIMUM, math.pi]),
    dtype=np.float64,
)

_trajectory_space = _base_trajectory_space

_trajectory_with_time_space = gym.spaces.Tuple(
    [
        gym.spaces.Box(
            low=np.array([POSITION_COORDINATE_MINIMUM] * TRAJECTORY_LENGTH),
            high=np.array([POSITION_COORDINATE_MAXIMUM] * TRAJECTORY_LENGTH),
            dtype=np.float64,
        )
        for _ in range(MPC_ARRAY_COUNT)
    ]
    + [
        gym.spaces.Box(
            low=np.array([DT_MINIMUM] * TRAJECTORY_LENGTH),
            high=np.array([DT_MAXIMUM] * TRAJECTORY_LENGTH),
            dtype=np.float64,
        )
    ]
)


[docs]@dataclass(frozen=True) class FormattingGroup: """Describes the conversion necessary to generate the given space.""" space: gym.Space formatting_func: Callable[[Any], Any] = field(default=_DEFAULT_PASSTHROUGH)
[docs]@lru_cache(maxsize=1) def get_formatters() -> Dict[ActionSpaceType, FormattingGroup]: """Get the currently available formatting groups for converting actions from `gym` space standard to SMARTS accepted observations. Returns: Dict[ActionSpaceType, Any]: The currently available formatting groups. """ return { ActionSpaceType.ActuatorDynamic: FormattingGroup( space=_actuator_dynamic_space, ), ActionSpaceType.Continuous: FormattingGroup( space=_continuous_space, ), ActionSpaceType.Direct: FormattingGroup( space=_direct_space, ), ActionSpaceType.Empty: FormattingGroup( space=gym.spaces.Tuple(spaces=()), formatting_func=lambda a: None, ), ActionSpaceType.Lane: FormattingGroup( space=_lane_space, formatting_func=_format_lane_space, ), ActionSpaceType.LaneWithContinuousSpeed: FormattingGroup( space=_lane_with_continuous_speed_space, ), ActionSpaceType.MPC: FormattingGroup( space=_mpc_space, ), ActionSpaceType.MultiTargetPose: FormattingGroup( space=_multi_target_pose_space, ), ActionSpaceType.RelativeTargetPose: FormattingGroup( space=_relative_target_pose_space, ), ActionSpaceType.TargetPose: FormattingGroup( space=_target_pose_space, ), ActionSpaceType.Trajectory: FormattingGroup( space=_trajectory_space, ), ActionSpaceType.TrajectoryWithTime: FormattingGroup( space=_trajectory_with_time_space, ), }
[docs]class ActionOptions(IntEnum): """Defines the options for how the formatting matches the action space.""" multi_agent = 0 """Action must map to partial action space. Only active agents are included.""" full = 1 """Action must map to full action space. Inactive and active agents are included.""" unformatted = 2 """Actions are not reformatted or constrained to action space. Actions must directly map to underlying SMARTS actions.""" default = multi_agent """Defaults to :attr:`multi_agent`."""
[docs]class ActionSpacesFormatter: """Formats actions to adapt SMARTS to `gym` environment requirements. Args: agent_interfaces (Dict[str, AgentInterface]): The agent interfaces needed to determine the shape of the actions. action_options (ActionOptions): Options to configure the end formatting of the actions. """ def __init__( self, agent_interfaces: Dict[str, AgentInterface], action_options: ActionOptions ) -> None: self._agent_interfaces = agent_interfaces self.action_options = action_options for agent_id, agent_interface in agent_interfaces.items(): assert self.supported(agent_interface.action), ( f"Agent `{agent_id}` is using an " f"unsupported `{agent_interface.action}`." f"Available actions:\n{json.dumps(set(agent_interfaces.keys()), indent=2)}" )
[docs] def format(self, actions: Dict[str, Any]): """Format the action to a form that SMARTS can use. Args: actions (Dict[str, Any]): The actions to format. Returns: (Observation, Dict[str, Any]): The formatted actions. """ if self.action_options == ActionOptions.unformatted: return actions out_actions = {} formatting_groups = get_formatters() for agent_id, action in actions.items(): agent_interface = self._agent_interfaces[agent_id] format_ = formatting_groups[agent_interface.action] space: gym.Space = self.space[agent_id] assert space is format_.space dtype = action.dtype if isinstance(action, np.ndarray) else None assert space.contains( action ), f"Action {action} of type `{type(action)}` & {dtype} does not match space {space}!" formatted_action = format_.formatting_func(action) out_actions[agent_id] = formatted_action if self.action_options == ActionOptions.full: assert actions.keys() == self.space.spaces.keys() return out_actions
[docs] @staticmethod def supported(action_type: ActionSpaceType): """Test if the action is in the supported int Args: action_type (ActionSpaceType): The action type to check. Returns: bool: If the action type is supported by the formatter. """ return action_type in get_formatters()
@cached_property def space(self) -> gym.spaces.Dict: """The action space given the current configuration. Returns: gym.spaces.Dict: A description of the action space that this formatter requires. """ if self.action_options is ActionOptions.unformatted: return None return gym.spaces.Dict( { agent_id: get_formatters()[agent_interface.action].space for agent_id, agent_interface in self._agent_interfaces.items() } )