# Source code for smarts.waymo.waymo_utils

# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# matplotlib is an optional dependency; fail with an actionable message when
# it is missing, while keeping the original error as the cause.
try:
    from matplotlib import pyplot as plt
    from matplotlib.animation import FuncAnimation
    from matplotlib.lines import Line2D
except ImportError as err:
    raise ImportError(
        "Missing dependencies for Waymo. Install them using the command `pip install -e .[waymo]` at the source directory."
    ) from err

from smarts.core.utils.file import read_tfrecord_file
from smarts.waymo.waymo_open_dataset.protos import scenario_pb2

# Legend entries for the static map features. The line styles and colors
# mirror the ones used by `_plot_map_features`.
MAP_HANDLES = [
    Line2D([0], [0], linestyle=line_style, color=line_color, label=caption)
    for line_style, line_color, caption in (
        (":", "gray", "Lane Polyline"),
        ("-", "yellow", "Single Road Line"),
        ("--", "yellow", "Double Road Line"),
        ("-", "black", "Road Edge"),
        ("--", "black", "Crosswalk"),
        (":", "black", "Speed Bump"),
    )
]
# Stop signs are drawn as markers rather than lines, so their legend entry
# carries no line data.
MAP_HANDLES.append(
    Line2D(
        [],
        [],
        color="red",
        marker="o",
        linestyle="None",
        markersize=5,
        label="Stop Sign",
    )
)

# Legend entries for the (non-interactive) trajectory markers, one per
# object category: (color, marker symbol, legend label).
TRAJECTORY_HANDLES = [
    Line2D(
        [],
        [],
        color=marker_color,
        marker=symbol,
        linestyle="None",
        markersize=5,
        label=caption,
    )
    for marker_color, symbol, caption in (
        ("cyan", "^", "Ego Vehicle"),
        ("black", "^", "Car"),
        ("magenta", "d", "Pedestrian"),
        ("yellow", "*", "Cyclist"),
        ("black", "8", "Other"),
    )
]


def _create_interactive_handle(object_type: int, v_id: int) -> Line2D:
    """Build the legend entry for an interactive track.

    Args:
        object_type: Numeric object type of the track. 1, 2 and 3 are
            labelled "Car", "Pedestrian" and "Cyclist"; any other value is
            labelled "Other".
        v_id: Track id, appended to the legend label.

    Returns:
        A data-less green marker ``Line2D`` suitable for ``plt.legend``.
        A handle is returned for every object type, including "Other".
    """
    marker, kind = {
        1: ("^", "Car"),
        2: ("d", "Pedestrian"),
        3: ("*", "Cyclist"),
    }.get(object_type, ("8", "Other"))
    return Line2D(
        [],
        [],
        color="green",
        marker=marker,
        linestyle="None",
        markersize=5,
        label=f"Interactive {kind} {v_id}",
    )


def _feature_xy(points, closed: bool = False) -> Optional[np.ndarray]:
    """Convert points exposing ``.x``/``.y`` into an ``(N, 2)`` array.

    Returns ``None`` when there are no points. With ``closed=True`` the first
    point is repeated at the end so the outline is drawn as a closed loop.
    """
    coords = [[p.x, p.y] for p in points]
    if not coords:
        return None
    if closed:
        coords.append(coords[0])
    return np.array(coords)


def _plot_map_features(map_features: Dict):
    """Draw the static map features onto the current matplotlib axes.

    Args:
        map_features: Mapping from feature kind ("lane", "road_line",
            "road_edge", "crosswalk", "speed_bump", "stop_sign") to a list of
            tuples whose first element is the feature object. Missing kinds
            and features without any geometry points are skipped.
    """
    for lane in map_features.get("lane", ()):
        pts = _feature_xy(lane[0].polyline)
        if pts is not None:
            plt.plot(pts[:, 0], pts[:, 1], linestyle=":", color="gray")

    for road_line in map_features.get("road_line", ()):
        pts = _feature_xy(road_line[0].polyline)
        if pts is None:
            continue
        # NOTE(review): types 1, 4 and 5 are presumably the "broken" road line
        # variants of the Waymo map proto -- confirm against map.proto.
        line_format = "y--" if road_line[0].type in (1, 4, 5) else "y-"
        plt.plot(pts[:, 0], pts[:, 1], line_format)

    for road_edge in map_features.get("road_edge", ()):
        pts = _feature_xy(road_edge[0].polyline)
        if pts is not None:
            plt.plot(pts[:, 0], pts[:, 1], "k-")

    for crosswalk in map_features.get("crosswalk", ()):
        pts = _feature_xy(crosswalk[0].polygon, closed=True)
        if pts is not None:
            plt.plot(pts[:, 0], pts[:, 1], "k--")

    for speed_bump in map_features.get("speed_bump", ()):
        pts = _feature_xy(speed_bump[0].polygon, closed=True)
        if pts is not None:
            plt.plot(pts[:, 0], pts[:, 1], "k:")

    for stop_sign in map_features.get("stop_sign", ()):
        plt.scatter(
            stop_sign[0].position.x,
            stop_sign[0].position.y,
            marker="o",
            c="red",
            alpha=1,
        )


def _get_map_features(scenario) -> Dict[str, List]:
    """Group a scenario's map features by their populated ``feature_data`` field.

    Returns:
        A ``defaultdict`` mapping the field name reported by
        ``WhichOneof("feature_data")`` to a list of
        ``(feature_payload, feature_id_as_str)`` tuples. Features for which
        no field is set are skipped.
    """
    grouped = defaultdict(list)
    for feature in scenario.map_features:
        kind = feature.WhichOneof("feature_data")
        if kind is None:
            continue
        grouped[kind].append((getattr(feature, kind), str(feature.id)))
    return grouped


def _get_trajectories(scenario) -> Dict[int, Dict[str, Any]]:
    """Collect per-track positions and metadata from a scenario.

    Returns:
        A ``defaultdict`` keyed by track id. Each value holds:
        ``"positions"`` -- one ``(x, y)`` tuple per scenario timestamp, with
        ``(None, None)`` where the track state is not valid;
        ``"is_ego"`` -- whether the track's index equals
        ``scenario.sdc_track_index``;
        ``"object_type"`` -- the track's ``object_type`` value.
    """
    step_count = len(scenario.timestamps_seconds)
    by_track_id = defaultdict(lambda: {"positions": [None] * step_count})
    for index, track in enumerate(scenario.tracks):
        entry = by_track_id[track.id]
        entry["is_ego"] = index == scenario.sdc_track_index
        entry["object_type"] = track.object_type
        for step in range(step_count):
            state = track.states[step]
            if state.valid:
                entry["positions"][step] = (state.center_x, state.center_y)
            else:
                entry["positions"][step] = (None, None)
    return by_track_id


def _plot_trajectories(
    trajectories: Dict[int, Dict[str, Any]],
    interactive_ids: List[int],
) -> Tuple[List[Line2D], List[Optional[Tuple[list, list]]], List[Line2D]]:
    """Create one marker artist per trajectory, ready to be animated.

    Args:
        trajectories: Mapping from track id to a dict with the keys
            ``"positions"`` (list of ``(x, y)`` tuples, ``(None, None)`` when
            invalid), ``"is_ego"`` and ``"object_type"``.
        interactive_ids: Ids of the tracks to highlight in green.

    Returns:
        ``(points, data, handles)`` where ``points[k]`` is the marker artist
        whose per-step coordinates are ``data[k] == (xs, ys)``, and
        ``handles`` holds the extra legend entries for interactive tracks.
        The ego vehicle is included in ``points``/``data`` so that its marker
        moves along its own trajectory. All three lists are empty when
        ``trajectories`` is empty.
    """
    points, data, handles = [], [], []
    if not trajectories:
        return points, data, handles

    # Artists must be created with some initial position before they can be
    # updated, so seed them all with the first valid point found in any
    # trajectory. If no valid point exists, create them without data.
    x0, y0 = next(
        (
            (x, y)
            for props in trajectories.values()
            for x, y in props["positions"]
            if x is not None
        ),
        ([], []),
    )

    regular_formats = {1: "k^", 2: "md", 3: "y*"}
    interactive_formats = {1: "g^", 2: "gd", 3: "g*"}
    interactive = set(interactive_ids)

    for v_id, props in trajectories.items():
        xs = [p[0] for p in props["positions"]]
        ys = [p[1] for p in props["positions"]]
        object_type = props["object_type"]

        if props["is_ego"]:
            # The ego vehicle always uses the cyan marker and gets no
            # interactive legend entry.
            marker_format = "c^"
        elif int(v_id) in interactive:
            handle = _create_interactive_handle(object_type, v_id)
            if handle is not None:  # never pass an empty entry to the legend
                handles.append(handle)
            marker_format = interactive_formats.get(object_type, "g8")
        else:
            marker_format = regular_formats.get(object_type, "k8")

        (point,) = plt.plot(x0, y0, marker_format)
        data.append((xs, ys))
        points.append(point)
    return points, data, handles


def get_tfrecord_info(tfrecord_file: str) -> Dict[str, Dict[str, Any]]:
    """Extract info about each scenario in the TFRecord file.

    Args:
        tfrecord_file: Path of the TFRecord file to read.

    Returns:
        A dict keyed by scenario id. Each value has the keys ``"timestamps"``
        (number of entries in the scenario's ``timestamps_seconds``),
        ``"vehicles"`` (number of tracks with object type 1) and
        ``"pedestrians"`` (number of tracks with object type 2).
    """
    scenarios = {}
    for record in read_tfrecord_file(tfrecord_file):
        scenario = scenario_pb2.Scenario()
        scenario.ParseFromString(bytes(record))

        num_vehicles = 0
        num_pedestrians = 0
        for track in scenario.tracks:
            if track.object_type == 1:
                num_vehicles += 1
            elif track.object_type == 2:
                num_pedestrians += 1

        scenarios[scenario.scenario_id] = {
            "timestamps": len(scenario.timestamps_seconds),
            "vehicles": num_vehicles,
            "pedestrians": num_pedestrians,
        }
    return scenarios
def plot_scenario(
    tfrecord_file: str,
    scenario_id: str,
    animate: bool,
    label_vehicles: bool,
):
    """Plot the map features of a Waymo scenario, and optionally plot/animate the vehicle trajectories.

    Args:
        tfrecord_file: Path of the TFRecord file containing the scenario.
        scenario_id: Id of the scenario to plot.
        animate: Animate the trajectories. Only honored when
            ``label_vehicles`` is false.
        label_vehicles: Mark the first valid position of every track and
            label it with the track id. Takes precedence over ``animate``.
    """
    from smarts.core.waymo_map import WaymoMap

    source = f"{tfrecord_file}#{scenario_id}"
    scenario = WaymoMap.parse_source_to_scenario(source)

    fig = plt.figure()
    mng = plt.get_current_fig_manager()
    mng.resize(1000, 1000)

    map_features = _get_map_features(scenario)
    # Work on a copy: extending the module-level list in place would make
    # legend entries pile up across repeated calls.
    handles = list(MAP_HANDLES)

    if label_vehicles:
        handles.extend(TRAJECTORY_HANDLES)
        trajectories = _get_trajectories(scenario)
        bbox_props = dict(boxstyle="square,pad=0.1", fc="white", ec=None)
        for v_id, props in trajectories.items():
            valid_pts = [p for p in props["positions"] if p[0] is not None]
            if not valid_pts:
                continue
            x, y = valid_pts[0][0], valid_pts[0][1]
            plt.scatter(x, y, marker="o", c="blue")
            plt.text(x + 1, y + 1, f"{v_id}", bbox=bbox_props)
    elif animate:
        trajectories = _get_trajectories(scenario)
        interactive_ids = list(scenario.objects_of_interest)
        points, data, interactive_handles = _plot_trajectories(
            trajectories, interactive_ids
        )
        handles.extend(TRAJECTORY_HANDLES)
        # Never pass an empty entry to the legend.
        handles.extend(h for h in interactive_handles if h is not None)

        def update(i):
            """Move every marker that has a valid position at step ``i``."""
            drawn_pts = []
            for (xs, ys), point in zip(data, points):
                if i < len(xs) and xs[i] is not None and ys[i] is not None:
                    # Line2D.set_data expects sequences, not scalars.
                    point.set_data([xs[i]], [ys[i]])
                    drawn_pts.append(point)
            return drawn_pts

        num_steps = len(scenario.timestamps_seconds)
        # Keep a reference to the animation until plt.show() returns,
        # otherwise it may be garbage collected and stop running.
        anim = FuncAnimation(
            fig, update, frames=range(1, num_steps), blit=True, interval=100
        )

    _plot_map_features(map_features)
    plt.title(f"Scenario {scenario_id}")
    plt.legend(handles=handles)
    plt.axis("equal")
    plt.show()
def gen_smarts_scenario_code(dataset_path: str, scenario_id: str) -> str:
    """Generate source code for the ``scenario.py`` of a SMARTS scenario for a Waymo scenario.

    Args:
        dataset_path: Path of the Waymo dataset file, embedded in the
            generated script as a string literal.
        scenario_id: Id of the Waymo scenario, embedded in the generated
            script as a string literal.

    Returns:
        The text of a Python script that calls ``gen_scenario`` with a Waymo
        traffic history and map spec for the given dataset and scenario.
    """
    import json

    # json.dumps yields a double-quoted literal that is also a valid Python
    # string literal, so backslashes (e.g. Windows paths) and quotes in the
    # inputs cannot break the generated script. Plain values are emitted
    # exactly as `"value"`.
    dataset_literal = json.dumps(str(dataset_path), ensure_ascii=False)
    scenario_literal = json.dumps(str(scenario_id), ensure_ascii=False)

    return f"""from pathlib import Path
from smarts.sstudio import gen_scenario
from smarts.sstudio import types as t

dataset_path = {dataset_literal}
scenario_id = {scenario_literal}
traffic_histories = [
    t.TrafficHistoryDataset(
        name=f"waymo",
        source_type="Waymo",
        input_path=dataset_path,
        scenario_id=scenario_id,
    )
]

gen_scenario(
    t.Scenario(
        map_spec=t.MapSpec(
            source=f"{{dataset_path}}#{{scenario_id}}", lanepoint_spacing=1.0
        ),
        traffic_histories=traffic_histories,
    ),
    output_dir=Path(__file__).parent,
)
"""