# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import annotations

import weakref
from typing import TYPE_CHECKING, Iterable, List, Sequence, Set

from .actor import ActorRole
from .controllers import ActionSpaceType
from .provider import Provider, ProviderManager, ProviderRecoveryFlags, ProviderState
from .scenario import Scenario
from .utils.file import replace

if TYPE_CHECKING:
    from .vehicle_state import VehicleState
class ExternalProvider(Provider):
    """A provider that is intended to be used for external intervention in the simulation.
    Vehicles managed by this provider cannot be hijacked by social agents
    and may have privileged VehicleStates.

    Vehicle states are pushed in from outside via `state_update` and are
    reported back unchanged (apart from source/role tagging) from `setup`
    and `step`.
    """

    def __init__(self, sim: ProviderManager):
        # start with the default recovery flags...
        self._recovery_flags = super().recovery_flags
        self.set_manager(sim)
        self.reset()

    def set_manager(self, manager: ProviderManager):
        """Attach this provider to `manager`, holding only a weak reference to it.

        Args:
            manager: The owning manager. The private helpers `_sim_time` and
                `_vehicle_index` read its `elapsed_sim_time` and
                `vehicle_index` attributes.
        """
        self._sim = weakref.ref(manager)

    @property
    def _sim_time(self) -> float:
        """Elapsed simulation time as reported by the weakly referenced manager."""
        sim = self._sim()
        # The manager is held weakly; it must still be alive when this is read.
        assert sim
        # pytype: disable=attribute-error
        # TAI: consider adding to ProviderManager interface
        return sim.elapsed_sim_time
        # pytype: enable=attribute-error

    @property
    def _vehicle_index(self):
        """The vehicle index of the weakly referenced manager."""
        sim = self._sim()
        # The manager is held weakly; it must still be alive when this is read.
        assert sim
        # pytype: disable=attribute-error
        # TAI: consider adding to ProviderManager interface
        return sim.vehicle_index
        # pytype: enable=attribute-error

    @property
    def recovery_flags(self) -> ProviderRecoveryFlags:
        """Recovery flags for this provider (initialised from the base-class default)."""
        return self._recovery_flags

    @recovery_flags.setter
    def recovery_flags(self, flags: ProviderRecoveryFlags):
        self._recovery_flags = flags

    def reset(self):
        """Forget all externally supplied states and restart the freshness clock."""
        self._ext_vehicle_states = []
        # The batch most recently reported by `_provider_state`; None means
        # nothing has been reported since this reset.
        self._sent_states = None
        self._last_step_delta = None
        self._last_fresh_step = self._sim_time

    def state_update(
        self,
        vehicle_states: Sequence[VehicleState],
        step_delta: float,
    ):
        """Update vehicle states. Use the `all_vehicle_states` property to look at previous states.

        Each incoming state is rebuilt via `replace` with `source` set to this
        provider's `source_str` and `role` set to `ActorRole.External`. Any
        previously supplied states are dropped.

        Args:
            vehicle_states: The new externally supplied vehicle states.
            step_delta: Stored and later passed to the accelerometer sensor in
                `all_vehicle_states`.
        """
        self._ext_vehicle_states = [
            replace(vs, source=self.source_str, role=ActorRole.External)
            for vs in vehicle_states
        ]
        self._last_step_delta = step_delta

    @property
    def actions(self) -> Set[ActionSpaceType]:
        """No action spaces are supported: always an empty set."""
        return set()

    @property
    def _provider_state(self) -> ProviderState:
        """Package the current external states into a `ProviderState`.

        `dt` is the sim time elapsed since `_last_fresh_step`, measured before
        that marker is possibly advanced. The marker moves to the current sim
        time whenever the batch being returned has not been reported before.
        """
        dt = self._sim_time - self._last_fresh_step
        # Freshness is tracked by object identity: `state_update` (and `reset`)
        # always build a new list, so each such call yields a "fresh" batch even
        # if its contents equal the previous one.
        if self._ext_vehicle_states is not self._sent_states:
            self._last_fresh_step = self._sim_time
            self._sent_states = self._ext_vehicle_states
        return ProviderState(actors=self._ext_vehicle_states, dt=dt)

    def setup(self, scenario: Scenario) -> ProviderState:
        """Report the current external states; `scenario` is not used."""
        return self._provider_state

    def step(self, actions, dt: float, elapsed_sim_time: float) -> ProviderState:
        """Report the current external states; the arguments are not used."""
        return self._provider_state

    def sync(self, provider_state: ProviderState):
        """No-op: this provider does not consume other providers' states."""
        pass

    def teardown(self):
        """Drop all external states (delegates to `reset`)."""
        self.reset()

    @property
    def all_vehicle_states(self) -> List[VehicleState]:
        """Get all current vehicle states.

        Returns the `state` of every vehicle in the manager's vehicle index,
        not only the ones supplied through `state_update`. For vehicles
        subscribed to the accelerometer sensor, the sensor is invoked with the
        vehicle's current velocities and the last `step_delta` given to
        `state_update` (None if `state_update` has not been called since reset).
        """
        result = []
        for vehicle in self._vehicle_index.vehicles:
            if vehicle.subscribed_to_accelerometer_sensor:
                # NOTE(review): the computed accelerations are discarded. The call
                # is kept in case the sensor tracks history between calls; confirm
                # whether these values were meant to be attached to the state.
                _linear_acc, _angular_acc, _, _ = vehicle.accelerometer_sensor(
                    vehicle.state.linear_velocity,
                    vehicle.state.angular_velocity,
                    self._last_step_delta,
                )
            result.append(vehicle.state)
        return result

    def manages_actor(self, actor_id: str) -> bool:
        """Whether `actor_id` belongs to one of the externally supplied states."""
        return any(vs.actor_id == actor_id for vs in self._ext_vehicle_states)

    @property
    def actor_ids(self) -> Iterable[str]:
        """A set of actors that this provider manages.

        Returns:
            Iterable[str]: The actors this provider manages.
        """
        return {vs.actor_id for vs in self._ext_vehicle_states}