# Source code for smarts.core.traffic_history

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

# to allow for typing to refer to class being defined (TrafficHistory)
from __future__ import annotations

import logging
import os
import random
import sqlite3
from contextlib import closing, nullcontext
from functools import cached_property, lru_cache
from typing import (
    Dict,
    Generator,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from smarts.core.coordinates import Dimensions
from smarts.core.utils.core_math import radians_to_vec
from smarts.core.vehicle import VEHICLE_CONFIGS

T = TypeVar("T")


class TrafficHistory:
    """Traffic history for use with converted datasets.

    The history lives in a SQLite database; the queries below read its
    ``Spec``, ``Vehicle``, ``Trajectory`` and ``TrafficLightState`` tables.
    """

    # Dataset vehicle type ids (NGSIM / INTERACTION) -> vehicle config types.
    _VEHICLE_TYPE_NAMES = {
        1: "motorcycle",
        2: "passenger",
        3: "truck",
        4: "pedestrian",
    }

    def __init__(self, db):
        """
        Args:
            db: Reference to the SQLite history database. It is handed directly
                to ``sqlite3.connect`` for one-off queries; ``name`` and
                ``connect_for_multiple_queries`` additionally read its ``.name``
                and ``.path`` attributes (presumably an ``os.DirEntry``-like
                object -- TODO confirm against callers).
        """
        self._log = logging.getLogger(self.__class__.__name__)
        self._db = db
        # Shared connection; only set between connect_for_multiple_queries()
        # and disconnect().
        self._db_cnxn = None

    @property
    def name(self) -> str:
        """The name of the traffic history (database file name without extension)."""
        return os.path.splitext(self._db.name)[0]

    def connect_for_multiple_queries(self):
        """Optional optimization to avoid the overhead of parsing the `sqlite` file
        header multiple times for clients that will be performing multiple queries.
        If used, then disconnect() should be called when finished."""
        if not self._db_cnxn:
            # NOTE(review): one-off queries connect with `self._db` itself while
            # this uses `self._db.path`; both agree for `os.DirEntry` -- confirm.
            self._db_cnxn = sqlite3.connect(self._db.path)

    def disconnect(self):
        """End connection with the history database."""
        if self._db_cnxn:
            self._db_cnxn.close()
            self._db_cnxn = None

    def _open_connection(self):
        """Return a context manager yielding a database connection.

        The shared connection opened by ``connect_for_multiple_queries`` is
        reused and left open; otherwise a temporary connection is opened and
        closed when the context exits.
        """
        if self._db_cnxn:
            return nullcontext(self._db_cnxn)
        return closing(sqlite3.connect(self._db))

    def _query_val(
        self, result_type: Type[T], query: str, params: Tuple = ()
    ) -> Optional[T]:
        """Run ``query`` and return (a value from) its first row.

        Args:
            result_type: ``tuple`` to get the whole row unconverted; otherwise a
                callable used to convert the row's first column.
            query: SQL to execute.
            params: Parameters bound to the ``?`` placeholders of ``query``.

        Returns:
            ``None`` if the query yields no row, or -- for scalar results -- if
            the first column is SQL ``NULL`` (e.g. ``MIN``/``MAX`` over no rows).
        """
        with self._open_connection() as dbcnxn, closing(dbcnxn.cursor()) as cur:
            row = cur.execute(query, params).fetchone()
        if row is None:
            return None
        if result_type is tuple:
            return row
        value = row[0]
        # Aggregates always return a row; a NULL value means "nothing matched".
        return None if value is None else result_type(value)

    def _query_list(
        self, query: str, params: Tuple = ()
    ) -> Generator[Tuple, None, None]:
        """Lazily yield every row produced by ``query``.

        The cursor (and any temporary connection) is closed both when the
        generator is exhausted and when it is closed or collected early.
        """
        with self._open_connection() as dbcnxn, closing(dbcnxn.cursor()) as cur:
            yield from cur.execute(query, params)

    @cached_property
    def dataset_source(self) -> Optional[str]:
        """The known source of the history data"""
        query = "SELECT value FROM Spec where key='source_type'"
        return self._query_val(str, query)

    @cached_property
    def lane_width(self) -> Optional[float]:
        """The general lane width in the history data"""
        query = "SELECT value FROM Spec where key='map_net.lane_width'"
        return self._query_val(float, query)

    @cached_property
    def target_speed(self) -> Optional[float]:
        """The general speed limit in the history data."""
        query = "SELECT value FROM Spec where key='speed_limit_mps'"
        return self._query_val(float, query)

    def all_vehicle_ids(self) -> Generator[int, None, None]:
        """Get the ids of all vehicles in the history data"""
        query = "SELECT id FROM Vehicle"
        return (row[0] for row in self._query_list(query))

    @cached_property
    def ego_vehicle_id(self) -> Optional[int]:
        """The id of the ego's actor in the history data."""
        query = "SELECT id FROM Vehicle WHERE is_ego_vehicle = 1"
        return self._query_val(int, query)

    @lru_cache(maxsize=32)
    def vehicle_initial_time(self, vehicle_id: str) -> Optional[float]:
        """Returns the initial time the specified vehicle is seen in the history data,
        or ``None`` if the vehicle has no trajectory."""
        query = "SELECT MIN(sim_time) FROM Trajectory WHERE vehicle_id = ?"
        return self._query_val(float, query, params=(vehicle_id,))

    @lru_cache(maxsize=32)
    def vehicle_final_exit_time(self, vehicle_id: str) -> Optional[float]:
        """Returns the final time the specified vehicle is seen in the history data,
        or ``None`` if the vehicle has no trajectory."""
        query = "SELECT MAX(sim_time) FROM Trajectory WHERE vehicle_id = ?"
        return self._query_val(float, query, params=(vehicle_id,))

    @lru_cache(maxsize=32)
    def vehicle_final_position(self, vehicle_id: str) -> Tuple[float, float]:
        """Returns the final (x,y) position for the specified vehicle in the history data."""
        query = (
            "SELECT position_x, position_y FROM Trajectory WHERE vehicle_id=? AND "
            "sim_time=(SELECT MAX(sim_time) FROM Trajectory WHERE vehicle_id=?)"
        )
        return self._query_val(tuple, query, params=(vehicle_id, vehicle_id))

    def decode_vehicle_type(self, vehicle_type: int) -> str:
        """Convert from the dataset type id to their config type.

        Options from NGSIM and INTERACTION currently include:
        1=motorcycle, 2=auto, 3=truck, 4=pedestrian/bicycle
        This actually returns a ``vehicle_config_type``.
        Unknown ids are logged and decoded as ``"passenger"``.
        """
        config_type = self._VEHICLE_TYPE_NAMES.get(vehicle_type)
        if config_type is None:
            self._log.warning(
                f"unsupported vehicle_type ({vehicle_type}) in history data."
            )
            return "passenger"
        return config_type

    @lru_cache(maxsize=32)
    def vehicle_config_type(self, vehicle_id: str) -> str:
        """Find the configuration type of the specified vehicle."""
        query = "SELECT type FROM Vehicle WHERE id = ?"
        veh_type = self._query_val(int, query, params=(vehicle_id,))
        return self.decode_vehicle_type(veh_type)

    def _resolve_vehicle_dims(
        self, vehicle_type: Union[str, int], length: float, width: float, height: float
    ):
        """Build ``Dimensions``, filling missing (falsy) sizes from the defaults of
        the vehicle config type. Integer types are decoded first."""
        v_type = vehicle_type
        if isinstance(v_type, int):
            v_type = self.decode_vehicle_type(v_type)
        default_dims = VEHICLE_CONFIGS[v_type].dimensions
        if not length:
            length = default_dims.length
        if not width:
            width = default_dims.width
        if not height:
            height = default_dims.height
        return Dimensions(length, width, height)

    @lru_cache(maxsize=32)
    def vehicle_dims(self, vehicle_id: str) -> Dimensions:
        """Get the vehicle dimensions of the specified vehicle."""
        query = "SELECT length, width, height, type FROM Vehicle WHERE id = ?"
        length, width, height, veh_type = self._query_val(
            tuple, query, params=(vehicle_id,)
        )
        return self._resolve_vehicle_dims(veh_type, length, width, height)

    def first_seen_times(self) -> Generator[Tuple[int, float], None, None]:
        """Find the times each vehicle is first seen in the traffic history.

        XXX: For now, limit agent missions to just passenger cars (V.type = 2)
        """
        query = """SELECT T.vehicle_id, MIN(T.sim_time)
            FROM Trajectory AS T INNER JOIN Vehicle AS V ON T.vehicle_id=V.id
            WHERE V.type = 2
            GROUP BY vehicle_id"""
        return self._query_list(query)

    def last_seen_vehicle_time(self) -> Optional[float]:
        """Find the time the last (passenger) vehicle exits the history, or ``None``
        if there is none."""
        query = """SELECT MAX(T.sim_time)
            FROM Trajectory AS T INNER JOIN Vehicle AS V ON T.vehicle_id=V.id
            WHERE V.type = 2"""
        return self._query_val(float, query)

    def vehicle_pose_at_time(
        self, vehicle_id: str, sim_time: float
    ) -> Optional[Tuple[float, float, float, float]]:
        """Get the pose (x, y, heading, speed) of the specified vehicle at the
        specified history time, or ``None`` if there is no exact match."""
        query = """SELECT position_x, position_y, heading_rad, speed
            FROM Trajectory
            WHERE vehicle_id = ? and sim_time = ?"""
        return self._query_val(tuple, query, params=(int(vehicle_id), float(sim_time)))

    def vehicle_ids_active_between(
        self, start_time: float, end_time: float
    ) -> Generator[Tuple, None, None]:
        """Find the ids of all active vehicles between the given history times
        (both bounds inclusive). Each yielded row is a 1-tuple.

        XXX: For now, limited to just passenger cars (V.type = 2)
        XXX: This looks like the wrong level to filter out vehicles
        """
        query = """SELECT DISTINCT T.vehicle_id
            FROM Trajectory AS T INNER JOIN Vehicle AS V ON T.vehicle_id=V.id
            WHERE ? <= T.sim_time AND T.sim_time <= ? AND V.type = 2"""
        return self._query_list(query, (start_time, end_time))

    class VehicleRow(NamedTuple):
        """Vehicle state information"""

        vehicle_id: int
        vehicle_type: int
        vehicle_length: float
        vehicle_width: float
        vehicle_height: float
        position_x: float
        position_y: float
        heading_rad: float
        speed: float

    class TrafficHistoryVehicleWindow(NamedTuple):
        """General information about a vehicle between a time window."""

        vehicle_id: int
        vehicle_type: str
        vehicle_length: float
        vehicle_width: float
        vehicle_height: float
        start_position_x: float
        start_position_y: float
        start_heading: float
        start_speed: float
        average_speed: float
        start_time: float
        end_time: float
        end_position_x: float
        end_position_y: float
        end_heading: float

        @property
        def axle_start_position(self):
            """The start position shifted by half the vehicle length along the
            start heading (assumes `radians_to_vec` returns a unit vector)."""
            hhx, hhy = radians_to_vec(self.start_heading) * (0.5 * self.vehicle_length)
            return [self.start_position_x + hhx, self.start_position_y + hhy]

        @property
        def axle_end_position(self):
            """The end position shifted by half the vehicle length along the
            end heading (assumes `radians_to_vec` returns a unit vector)."""
            hhx, hhy = radians_to_vec(self.end_heading) * (0.5 * self.vehicle_length)
            return [self.end_position_x + hhx, self.end_position_y + hhy]

        @property
        def dimensions(self) -> Dimensions:
            """The known or estimated dimensions of this vehicle."""
            return Dimensions(
                self.vehicle_length, self.vehicle_width, self.vehicle_height
            )

    def vehicles_active_between(
        self, start_time: float, end_time: float
    ) -> Generator[TrafficHistory.VehicleRow, None, None]:
        """Find all vehicle states in ``(start_time, end_time]``, latest first."""
        query = """SELECT V.id, V.type, V.length, V.width, V.height,
                          T.position_x, T.position_y, T.heading_rad, T.speed
                   FROM Vehicle AS V INNER JOIN Trajectory AS T ON V.id = T.vehicle_id
                   WHERE T.sim_time > ? AND T.sim_time <= ?
                   ORDER BY T.sim_time DESC"""
        rows = self._query_list(query, (start_time, end_time))
        return (TrafficHistory.VehicleRow(*row) for row in rows)

    class TrafficLightRow(NamedTuple):
        """Fields in a row from the TrafficLightState table."""

        sim_time: float
        state: int
        stop_point_x: float
        stop_point_y: float
        lane_id: int

    def traffic_light_states_between(
        self, start_time: float, end_time: float
    ) -> Generator[TrafficHistory.TrafficLightRow, None, None]:
        """Find all traffic light states in ``(start_time, end_time]``, earliest first."""
        query = """SELECT sim_time, state, stop_point_x, stop_point_y, lane
                   FROM TrafficLightState
                   WHERE sim_time > ? AND sim_time <= ?
                   ORDER BY sim_time ASC"""
        rows = self._query_list(query, (start_time, end_time))
        return (TrafficHistory.TrafficLightRow(*row) for row in rows)

    @staticmethod
    def _greatest_n_per_group_format(
        select, table, group_by, greatest_of_group, where=None, operation="MAX"
    ):
        """Build SQL selecting, per `group_by` group, the row whose
        `greatest_of_group` equals the group's `operation` (MAX by default).

        This works on sqlite3 versions lower than `3.7.11`, which cannot rely on
        bare columns in aggregate queries. See:
        https://stackoverflow.com/questions/12608025/how-to-construct-a-sqlite-query-to-group-by-order

        An extra trailing column holds the number of rows in each group. The
        optional `where` clause restricts only the rows the aggregate considers.

        e.g. Get a table of the highest speed (`greatest_of_group`) each
        vehicle (`group_by`) was operating at:

        > _greatest_n_per_group_format(
        >     select="vehicle_speed, vehicle_id",
        >     table="Trajectory",
        >     group_by="vehicle_id",
        >     greatest_of_group="vehicle_speed",
        > )
        """
        where = f"{where} AND" if where else ""
        return f"""
            SELECT {select},
            (SELECT COUNT({group_by}) AS count
                FROM {table} m2
                WHERE m1.{group_by} = m2.{group_by})
            FROM {table} m1
            WHERE {greatest_of_group} = (SELECT {operation}({greatest_of_group})
                FROM {table} m3
                WHERE {where} m1.{group_by} = m3.{group_by})
            GROUP BY {group_by}
        """

    def _window_from_row(self, row):
        """Convert a window query row into a ``TrafficHistoryVehicleWindow``,
        decoding the type and filling missing dimensions from defaults."""
        return TrafficHistory.TrafficHistoryVehicleWindow(
            row[0],
            self.decode_vehicle_type(row[1]),
            *self._resolve_vehicle_dims(row[1], *row[2:5]).as_lwh,
            *row[5:],
        )

    def vehicle_window_by_id(
        self,
        vehicle_id: str,
    ) -> Optional[TrafficHistory.TrafficHistoryVehicleWindow]:
        """Find the given vehicle by its id, or ``None`` if it is not in the history."""
        query = """SELECT V.id, V.type, V.length, V.width, V.height,
                          S.position_x, S.position_y, S.heading_rad, S.speed, D.avg_speed,
                          S.sim_time, E.sim_time,
                          E.position_x, E.position_y, E.heading_rad
                FROM Vehicle AS V
                INNER JOIN (
                    SELECT vehicle_id, AVG(speed) as "avg_speed"
                    FROM Trajectory
                    WHERE vehicle_id = ?
                ) AS D ON V.id = D.vehicle_id
                INNER JOIN (
                    SELECT vehicle_id, MIN(sim_time) as sim_time, speed, position_x, position_y, heading_rad
                    FROM Trajectory
                    WHERE vehicle_id = ?
                ) AS S ON V.id = S.vehicle_id
                INNER JOIN (
                    SELECT vehicle_id, MAX(sim_time) as sim_time, speed, position_x, position_y, heading_rad
                    FROM Trajectory
                    WHERE vehicle_id = ?
                ) AS E ON V.id = E.vehicle_id
        """
        # Fetch a single row so the cursor/connection is released immediately.
        row = self._query_val(tuple, query, params=(vehicle_id,) * 3)
        if row is None:
            return None
        return self._window_from_row(row)

    def vehicle_windows_in_range(
        self,
        exists_at_or_after: float,
        ends_before: float,
        minimum_vehicle_window: float,
    ) -> Generator[TrafficHistory.TrafficHistoryVehicleWindow, None, None]:
        """Find all vehicles seen in ``[exists_at_or_after, ends_before)`` for at
        least ``minimum_vehicle_window``, ordered by their first seen time."""
        row_columns = "vehicle_id, sim_time, speed, position_x, position_y, heading_rad"
        # First trajectory row of each vehicle at or after the window start.
        first_rows = self._greatest_n_per_group_format(
            select=row_columns,
            table="Trajectory",
            group_by="vehicle_id",
            greatest_of_group="sim_time",
            where="sim_time >= ?",
            operation="MIN",
        )
        # Last trajectory row of each vehicle before the window end.
        last_rows = self._greatest_n_per_group_format(
            select=row_columns,
            table="Trajectory",
            group_by="vehicle_id",
            greatest_of_group="sim_time",
            where="sim_time < ?",
        )
        query = f"""SELECT V.id, V.type, V.length, V.width, V.height,
                          S.position_x, S.position_y, S.heading_rad, S.speed, D.avg_speed,
                          S.sim_time, E.sim_time,
                          E.position_x, E.position_y, E.heading_rad
                FROM Vehicle AS V
                INNER JOIN (
                    SELECT vehicle_id, AVG(speed) as "avg_speed"
                    FROM (SELECT vehicle_id, sim_time, speed from Trajectory WHERE sim_time >= ? AND sim_time < ?)
                    GROUP BY vehicle_id
                ) AS D ON V.id = D.vehicle_id
                INNER JOIN (
                    {first_rows}
                ) AS S ON V.id = S.vehicle_id
                INNER JOIN (
                    {last_rows}
                ) AS E ON V.id = E.vehicle_id
                WHERE E.sim_time - S.sim_time >= ?
                GROUP BY V.id
                ORDER BY S.sim_time
        """
        rows = self._query_list(
            query,
            (
                exists_at_or_after,  # D: lower bound
                ends_before,  # D: upper bound
                exists_at_or_after,  # S: first row lower bound
                ends_before,  # E: last row upper bound
                minimum_vehicle_window,
            ),
        )
        return (self._window_from_row(row) for row in rows)

    class TrajectoryRow(NamedTuple):
        """An instant in a trajectory"""

        position_x: float
        position_y: float
        heading_rad: float
        speed: float

    def vehicle_trajectory(
        self, vehicle_id: str
    ) -> Generator[TrafficHistory.TrajectoryRow, None, None]:
        """Get the trajectory of the specified vehicle in chronological order."""
        query = """SELECT T.position_x, T.position_y, T.heading_rad, T.speed
                   FROM Trajectory AS T
                   WHERE T.vehicle_id = ?
                   ORDER BY T.sim_time ASC"""
        rows = self._query_list(query, (vehicle_id,))
        return (TrafficHistory.TrajectoryRow(*row) for row in rows)

    def random_overlapping_sample(
        self, vehicle_start_times: Dict[str, float], k: int
    ) -> Set[str]:
        """Grab a sample containing a subset of specified vehicles and ensure overlapping time
        intervals across sample.

        Each added vehicle's ``[start, final exit]`` interval overlaps the
        combined interval of the vehicles already sampled.

        Note: this may return a sample with less than k if we're unable to find k overlapping.
        An empty set is returned for empty input or ``k <= 0``.
        """
        if k <= 0 or not vehicle_start_times:
            return set()

        exit_times = {}

        def _exit_time(vehicle_id):
            # Local cache: the lru_cache on vehicle_final_exit_time only holds 32 ids.
            if vehicle_id not in exit_times:
                exit_time = self.vehicle_final_exit_time(vehicle_id)
                # A vehicle without trajectory rows is treated as instantaneous.
                exit_times[vehicle_id] = (
                    vehicle_start_times[vehicle_id] if exit_time is None else exit_time
                )
            return exit_times[vehicle_id]

        def _overlaps(vehicle_id, window_start, window_end):
            start = vehicle_start_times[vehicle_id]
            if start > window_end:
                return False
            # Only vehicles starting before the window need their exit time.
            return start >= window_start or _exit_time(vehicle_id) >= window_start

        vehicle_ids = list(vehicle_start_times)
        choice = random.choice(vehicle_ids)
        sample = {choice}
        window_start = vehicle_start_times[choice]
        window_end = _exit_time(choice)
        while len(sample) < k:
            candidates = [
                v_id
                for v_id in vehicle_ids
                if v_id not in sample and _overlaps(v_id, window_start, window_end)
            ]
            if not candidates:
                break
            choice = random.choice(candidates)
            sample.add(choice)
            window_start = min(window_start, vehicle_start_times[choice])
            window_end = max(window_end, _exit_time(choice))
        return sample