Source code for zoo.policies.chase_via_points_agent

import numpy as np

from smarts.core.agent import Agent
from smarts.core.observations import Observation
from smarts.core.sensors import LANE_ID_CONSTANT


class ChaseViaPointsAgent(Agent):
    """Agent that follows its lane and changes lanes toward nearby via points."""

    def act(self, obs: Observation):
        """Choose a target speed and a lane-change offset for this step.

        Args:
            obs: Current observation. Waypoint paths must be enabled in the
                agent interface; ``obs.via_data.near_via_points`` and
                ``obs.ego_vehicle_state`` are also read.

        Returns:
            A 2-tuple ``(speed, lane_change)``:

            * no via point is within reach of the (truncated) waypoint paths:
              ``(speed_limit of the ego path's first waypoint, 0)``;
            * the chosen via point lies on one of the ego's paths:
              ``(via point's required_speed, 0)``;
            * otherwise: ``(via point's required_speed, via-point path index
              minus ego path index)``; positive means the via point is on the
              left, negative on the right.

        Raises:
            AssertionError: if there are no waypoint paths, if any waypoint
                path is empty, or if the ego lane id equals LANE_ID_CONSTANT.
        """
        assert obs.waypoint_paths, (
            f"Waypoint paths = {obs.waypoint_paths}; "
            "cannot be empty or None. Enable waypoint paths in agent interface."
        )

        # Consider at most `lane_change_dist` waypoints per path, and cut every
        # path to a common length so they stack into one (n_paths, min_len, 2)
        # array.
        lane_change_dist = 80
        min_len = min(lane_change_dist, min(map(len, obs.waypoint_paths)))
        # An empty path would otherwise surface later as a cryptic IndexError
        # from `waypoints[:, 0, :]`.
        assert min_len > 0, "Every waypoint path must contain at least one waypoint."
        waypoints = np.array(
            [[waypoint.pos for waypoint in path[:min_len]] for path in obs.waypoint_paths],
            dtype=np.float64,
        )

        # Ego status
        ego_lane_id = obs.ego_vehicle_state.lane_id
        # Compare by value: `is not` only detects the sentinel when both sides
        # happen to be the very same object.
        assert ego_lane_id != LANE_ID_CONSTANT, f"Ego lane cannot be {ego_lane_id}."
        ego_pos = obs.ego_vehicle_state.position[:2]
        # Indices of the paths whose first waypoint is closest to the ego.
        # Exact equality deliberately keeps every path tied for the minimum.
        dist = np.linalg.norm(waypoints[:, 0, :] - ego_pos, axis=-1)
        ego_wp_inds = np.where(dist == dist.min())[0]

        # Get target via point.
        near_via_points = obs.via_data.near_via_points
        via_points = np.array([via_point.position for via_point in near_via_points])
        via_point_wp_ind, via_point_ind = _nearest_waypoint(waypoints, via_points)

        # No nearby via points. Hence, remain in same lane.
        if via_point_ind is None:
            return (obs.waypoint_paths[ego_wp_inds[0]][0].speed_limit, 0)

        target_speed = near_via_points[via_point_ind].required_speed
        # via_point_wp_ind is (path index, waypoint index); only the path matters.
        target_path = via_point_wp_ind[0]

        # Target via point is in the same path. Hence, remain in same lane.
        if target_path in ego_wp_inds:
            return (target_speed, 0)

        # Positive offset: target via point is on the left; negative: on the right.
        # NOTE(review): the offset's magnitude can exceed 1 when the target is
        # several paths away — confirm the action space accepts that.
        return (target_speed, target_path - ego_wp_inds[0])
def _nearest_waypoint(matrix: np.ndarray, points: np.ndarray, radius: float = 2): cur_point_index = ((np.intp(1e10), np.intp(1e10)), None) if points.shape == (0,): return cur_point_index assert len(matrix.shape) == 3 assert matrix.shape[2] == 2 assert len(points.shape) == 2 assert points.shape[1] == 2 points_expanded = np.expand_dims(points, (1, 2)) diff = matrix - points_expanded dist = np.linalg.norm(diff, axis=-1) for ii in range(points.shape[0]): index = np.argmin(dist[ii]) index_unravel = np.unravel_index(index, dist[ii].shape) min_dist = dist[ii][index_unravel] if min_dist <= radius and index_unravel[1] < cur_point_index[0][1]: cur_point_index = (index_unravel, ii) return cur_point_index