Source code for smarts.core.simulation_frame

# MIT License
#
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from functools import cached_property, lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional, Set

if TYPE_CHECKING:
    from smarts.core.actor import ActorState
    from smarts.core.agent_interface import AgentInterface
    from smarts.core.provider import ProviderState
    from smarts.core.sensor import Sensor
    from smarts.core.sensors import SensorState
    from smarts.core.vehicle_state import Collision, VehicleState
    from smarts.sstudio.sstypes import MapSpec

logger = logging.getLogger(__name__)


[docs]@dataclass(frozen=True) class SimulationFrame: """This is state that should change per step of the simulation.""" actor_states: List[ActorState] agent_vehicle_controls: Dict[str, str] agent_interfaces: Dict[str, AgentInterface] ego_ids: Set[str] pending_agent_ids: List[str] elapsed_sim_time: float fixed_timestep: float resetting: bool map_spec: MapSpec last_dt: float last_provider_state: ProviderState step_count: int vehicle_collisions: Dict[str, List[Collision]] vehicles_for_agents: Dict[str, List[str]] vehicle_ids: Set[str] vehicle_states: Dict[str, VehicleState] vehicle_sensors: Dict[str, Dict[str, Sensor]] sensor_states: Dict[str, SensorState] interest_filter: re.Pattern # TODO MTA: renderer can be allowed here as long as it is only type information # renderer_type: Any = None _ground_bullet_id: Optional[str] = None @cached_property def agent_ids(self) -> Set[str]: """Get the ids of all agents that currently have vehicles.""" return set(self.vehicles_for_agents.keys()) @cached_property def potential_agent_ids(self) -> Set[str]: """This includes current agent ids and future pending ego agent ids.""" return set(self.vehicles_for_agents.keys()) | set(self.pending_agent_ids) @cached_property def actor_states_by_id(self) -> Dict[str, ActorState]: """Get actor states paired by their ids.""" return {a_s.actor_id: a_s for a_s in self.actor_states}
[docs] @lru_cache(28) def interest_actors( self, extension: Optional[re.Pattern] = None ) -> Dict[str, ActorState]: """Get the actor states of actors that are marked as of interest. Args: extension (re.Pattern): A matching for interest actors not defined in scenario. """ _matchers: List[re.Pattern] = [] if self.interest_filter.pattern: _matchers.append(self.interest_filter) if extension is not None and extension.pattern: _matchers.append(extension) if len(_matchers) > 0: return { a_s.actor_id: a_s for a_s in self.actor_states if any(bool(m.match(a_s.actor_id)) for m in _matchers) } return {}
[docs] def actor_is_interest( self, actor_id: str, extension: Optional[re.Pattern] = None ) -> bool: """Determine if the actor is of interest. Args: actor_id (str): The id of the actor to test. Returns: bool: If the actor is of interest. """ return actor_id in self.interest_actors(extension)
[docs] def vehicle_did_collide(self, vehicle_id: str) -> bool: """Test if the given vehicle had any collisions in the last physics update.""" vehicle_collisions = self.vehicle_collisions.get(vehicle_id, []) for c in vehicle_collisions: if c.collidee_id != self._ground_bullet_id: return True return False
[docs] def filtered_vehicle_collisions(self, vehicle_id: str) -> List[Collision]: """Get a list of all collisions the given vehicle was involved in during the last physics update. """ vehicle_collisions = self.vehicle_collisions.get(vehicle_id, []) return [ c for c in vehicle_collisions if c.collidee_id != self._ground_bullet_id ]
@cached_property def _hash(self): return self.step_count ^ hash(self.fixed_timestep) ^ hash(self.map_spec) def __hash__(self): return self._hash def __post_init__(self): if logger.isEnabledFor(logging.DEBUG): assert isinstance(self.actor_states, list) assert self.agent_ids.union(self.vehicles_for_agents) == self.agent_ids assert len(self.agent_ids - set(self.agent_interfaces)) == 0 assert not len(self.vehicle_ids.symmetric_difference(self.vehicle_states))