# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
try:
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.lines import Line2D
except:
raise ImportError(
"Missing dependencies for Waymo. Install them using the command `pip install -e .[waymo]` at the source directory."
)
from smarts.core.utils.file import read_tfrecord_file
from smarts.waymo.waymo_open_dataset.protos import scenario_pb2
# Legend proxies for the map styles drawn by ``_plot_map_features``. Each
# entry's linestyle/colour mirrors the corresponding ``plt.plot`` call there.
# Road lines are split into solid vs. broken (dashed) markings, matching how
# ``_plot_map_features`` chooses "y-" or "y--" from the road-line type.
MAP_HANDLES = [
    Line2D([0], [0], linestyle=linestyle, color=color, label=label)
    for linestyle, color, label in (
        (":", "gray", "Lane Polyline"),
        ("-", "yellow", "Solid Road Line"),
        ("--", "yellow", "Broken Road Line"),
        ("-", "black", "Road Edge"),
        ("--", "black", "Crosswalk"),
        (":", "black", "Speed Bump"),
    )
] + [
    # Stop signs are drawn with plt.scatter, so this proxy is marker-only.
    Line2D(
        [],
        [],
        color="red",
        marker="o",
        linestyle="None",
        markersize=5,
        label="Stop Sign",
    ),
]
# Marker-only legend proxies for the agent markers drawn during animation.
# Each row is (colour, marker, legend label); all proxies share an empty data
# set, no connecting line and a marker size of 5.
TRAJECTORY_HANDLES = [
    Line2D(
        [],
        [],
        color=colour,
        marker=symbol,
        linestyle="None",
        markersize=5,
        label=caption,
    )
    for colour, symbol, caption in (
        ("cyan", "^", "Ego Vehicle"),
        ("black", "^", "Car"),
        ("magenta", "d", "Pedestrian"),
        ("yellow", "*", "Cyclist"),
        ("black", "8", "Other"),
    )
]
def _create_interactive_handle(object_type: int, v_id: int) -> Line2D:
    """Build a green, marker-only legend proxy for one interactive agent.

    Args:
        object_type: Track object type. 1, 2 and 3 are labelled car,
            pedestrian and cyclist; any other value is labelled "Other".
        v_id: Track id, appended to the legend label.

    Returns:
        A ``Line2D`` proxy whose marker matches the one ``_plot_trajectories``
        uses for an interactive agent of the same type. A handle is always
        returned; the previous version returned ``None`` for "Other" types,
        which broke ``plt.legend``.
    """
    # object_type -> (marker, label noun); unknown types fall back to "Other".
    marker, kind = {
        1: ("^", "Car"),
        2: ("d", "Pedestrian"),
        3: ("*", "Cyclist"),
    }.get(object_type, ("8", "Other"))
    return Line2D(
        [],
        [],
        color="green",
        marker=marker,
        linestyle="None",
        markersize=5,
        label=f"Interactive {kind} {v_id}",
    )
def _plot_point_sequence(points, fmt: str, closed: bool = False, **kwargs) -> None:
    """Plot objects exposing ``.x``/``.y`` as one connected line on the current axes.

    Empty sequences are skipped. When ``closed`` is true, the first point is
    repeated at the end so that polygons are drawn as closed outlines.
    Extra keyword arguments are forwarded to ``plt.plot``.
    """
    xs = [p.x for p in points]
    ys = [p.y for p in points]
    if not xs:
        return
    if closed:
        xs.append(xs[0])
        ys.append(ys[0])
    plt.plot(xs, ys, fmt, **kwargs)


def _plot_map_features(map_features: Dict):
    """Draw a scenario's static map features onto the current matplotlib axes.

    Args:
        map_features: Mapping from feature kind (``"lane"``, ``"road_line"``,
            ``"road_edge"``, ``"crosswalk"``, ``"speed_bump"``,
            ``"stop_sign"``) to lists of tuples whose first element is the
            feature message, as returned by ``_get_map_features``. Missing
            kinds are treated as empty, and features without any points are
            skipped instead of raising ``IndexError``.
    """
    # Road-line types drawn dashed. Per the Waymo Open Dataset map.proto these
    # are the broken markings (BROKEN_SINGLE_WHITE, BROKEN_SINGLE_YELLOW,
    # BROKEN_DOUBLE_YELLOW); every other type is drawn solid.
    broken_road_line_types = (1, 4, 5)

    for entry in map_features.get("lane", ()):
        _plot_point_sequence(entry[0].polyline, ":", color="gray")
    for entry in map_features.get("road_line", ()):
        road_line = entry[0]
        fmt = "y--" if road_line.type in broken_road_line_types else "y-"
        _plot_point_sequence(road_line.polyline, fmt)
    for entry in map_features.get("road_edge", ()):
        _plot_point_sequence(entry[0].polyline, "k-")
    for entry in map_features.get("crosswalk", ()):
        _plot_point_sequence(entry[0].polygon, "k--", closed=True)
    for entry in map_features.get("speed_bump", ()):
        _plot_point_sequence(entry[0].polygon, "k:", closed=True)
    for entry in map_features.get("stop_sign", ()):
        position = entry[0].position
        plt.scatter(position.x, position.y, marker="o", c="red", alpha=1)
def _get_map_features(scenario) -> Dict[str, List]:
    """Group a scenario's map features by their populated ``feature_data`` field.

    Returns:
        A ``defaultdict`` keyed by the field name reported by
        ``WhichOneof("feature_data")``. Each value is a list of
        ``(feature_message, str(feature.id))`` tuples. Features with no
        populated field are ignored, and unseen kinds read back as ``[]``.
    """
    grouped = defaultdict(list)
    for feature in scenario.map_features:
        field_name = feature.WhichOneof("feature_data")
        if field_name is None:
            continue
        grouped[field_name].append((getattr(feature, field_name), str(feature.id)))
    return grouped
def _get_trajectories(scenario) -> Dict[int, Dict[str, Any]]:
    """Collect each track's per-timestep positions and basic properties.

    Returns:
        A ``defaultdict`` keyed by track id. Each value holds ``"positions"``
        (one ``(x, y)`` tuple per scenario timestamp, ``(None, None)`` where
        the state is not valid), ``"is_ego"`` (whether the track sits at
        ``scenario.sdc_track_index``) and ``"object_type"``.
    """
    step_count = len(scenario.timestamps_seconds)
    trajectories = defaultdict(lambda: {"positions": [None] * step_count})
    for track_index, track in enumerate(scenario.tracks):
        record = trajectories[track.id]
        record["is_ego"] = track_index == scenario.sdc_track_index
        record["object_type"] = track.object_type
        positions = record["positions"]
        for step in range(step_count):
            state = track.states[step]
            if state.valid:
                positions[step] = (state.center_x, state.center_y)
            else:
                positions[step] = (None, None)
    return trajectories
def _plot_trajectories(
    trajectories: Dict[int, Dict[str, Any]],
    interactive_ids: List[int],
) -> Tuple[List[Line2D], List[Optional[Tuple[list, list]]], List[Line2D]]:
    """Create one marker artist per trajectory, ready to be animated.

    Every trajectory, including the ego vehicle, gets a marker. Markers start
    with no data, so nothing is drawn until an animation frame moves a marker
    to a valid position.

    Args:
        trajectories: Output of ``_get_trajectories``: track id -> dict with
            ``"positions"`` (list of ``(x, y)``, ``(None, None)`` when
            invalid), ``"is_ego"`` and ``"object_type"``.
        interactive_ids: Ids of interactive agents. Non-ego interactive
            agents are drawn green and get a legend proxy in ``handles``.

    Returns:
        ``(points, data, handles)``. ``points[k]`` is the marker for the k-th
        trajectory (in ``trajectories`` iteration order). ``data[k]`` is its
        ``(xs, ys)`` position lists. ``handles`` holds the legend proxies for
        interactive agents.
    """
    # object_type -> (marker, colour when not interactive); others use "k8".
    styles = {1: ("^", "k"), 2: ("d", "m"), 3: ("*", "y")}
    points, data, handles = [], [], []
    for v_id, props in trajectories.items():
        positions = props["positions"]
        xs = [p[0] for p in positions]
        ys = [p[1] for p in positions]
        if props["is_ego"]:
            # The ego is always cyan and never listed as interactive.
            fmt = "c^"
        else:
            object_type = props["object_type"]
            marker, color = styles.get(object_type, ("8", "k"))
            if int(v_id) in interactive_ids:
                color = "g"
                handles.append(_create_interactive_handle(object_type, v_id))
            fmt = color + marker
        # Plot empty data only to obtain the Line2D handle. Seeding markers at
        # some other agent's position would leave stray markers on the map.
        (point,) = plt.plot([], [], fmt)
        points.append(point)
        data.append((xs, ys))
    return points, data, handles
def get_tfrecord_info(tfrecord_file: str) -> Dict[str, Dict[str, Any]]:
    """Extract info about each scenario in the TFRecord file.

    Returns:
        A dict keyed by scenario id. Each value holds ``"timestamps"`` (number
        of timestamps), ``"vehicles"`` (tracks with object type 1) and
        ``"pedestrians"`` (tracks with object type 2).
    """
    summary: Dict[str, Dict[str, Any]] = {}
    for raw_record in read_tfrecord_file(tfrecord_file):
        parsed = scenario_pb2.Scenario()
        parsed.ParseFromString(bytes(raw_record))
        object_types = [track.object_type for track in parsed.tracks]
        summary[parsed.scenario_id] = {
            "timestamps": len(parsed.timestamps_seconds),
            "vehicles": object_types.count(1),
            "pedestrians": object_types.count(2),
        }
    return summary
def plot_scenario(
    tfrecord_file: str,
    scenario_id: str,
    animate: bool,
    label_vehicles: bool,
):
    """Plot the map features of a Waymo scenario,
    and optionally plot/animate the vehicle trajectories.

    Args:
        tfrecord_file: Path to the TFRecord file containing the scenario.
        scenario_id: Id of the scenario inside ``tfrecord_file``.
        animate: Animate every track's position over the scenario's
            timesteps. Ignored when ``label_vehicles`` is true.
        label_vehicles: Mark each track's first valid position with a blue
            dot annotated with the track id. Takes precedence over ``animate``.
    """
    from smarts.core.waymo_map import WaymoMap

    source = f"{tfrecord_file}#{scenario_id}"
    scenario = WaymoMap.parse_source_to_scenario(source)

    fig = plt.figure()
    mng = plt.get_current_fig_manager()
    mng.resize(1000, 1000)

    map_features = _get_map_features(scenario)
    # Copy: extending MAP_HANDLES in place would mutate the module-level list
    # and duplicate legend entries on every subsequent call.
    handles = list(MAP_HANDLES)
    # The animation must stay referenced until plt.show() returns, otherwise
    # it can be garbage-collected and stop updating.
    anim = None

    if label_vehicles:
        handles.extend(TRAJECTORY_HANDLES)
        trajectories = _get_trajectories(scenario)
        bbox_props = dict(boxstyle="square,pad=0.1", fc="white", ec=None)
        for v_id, props in trajectories.items():
            first_valid = next(
                (p for p in props["positions"] if p[0] is not None), None
            )
            if first_valid is None:
                continue
            x, y = first_valid
            plt.scatter(x, y, marker="o", c="blue")
            plt.text(x + 1, y + 1, f"{v_id}", bbox=bbox_props)
    elif animate:
        trajectories = _get_trajectories(scenario)
        interactive_ids = list(scenario.objects_of_interest)
        points, data, interactive_handles = _plot_trajectories(
            trajectories, interactive_ids
        )
        handles.extend(TRAJECTORY_HANDLES + interactive_handles)

        def update(i):
            # Only markers with a valid position at step i are returned, so
            # with blitting the other markers are not drawn in this frame.
            drawn_pts = []
            for (xs, ys), point in zip(data, points):
                if i < len(xs) and xs[i] is not None and ys[i] is not None:
                    # set_data needs sequences; scalars are deprecated in
                    # matplotlib 3.7 and rejected in newer releases.
                    point.set_data([xs[i]], [ys[i]])
                    drawn_pts.append(point)
            return drawn_pts

        num_steps = len(scenario.timestamps_seconds)
        # Start at step 0 so the first timestep is shown as well.
        anim = FuncAnimation(
            fig, update, frames=range(num_steps), blit=True, interval=100
        )

    _plot_map_features(map_features)
    plt.title(f"Scenario {scenario_id}")
    plt.legend(handles=handles)
    plt.axis("equal")
    plt.show()
def gen_smarts_scenario_code(dataset_path: str, scenario_id: str) -> str:
    """Generate source code for the ``scenario.py`` of a SMARTS scenario for a Waymo scenario.

    Args:
        dataset_path: Path to the Waymo TFRecord dataset file.
        scenario_id: Id of the scenario within the dataset.

    Returns:
        Python source that, when run, calls ``gen_scenario`` for the given
        scenario with its output directory set to the script's own folder.
    """
    import json

    # json.dumps produces a double-quoted literal that is also valid Python.
    # Backslashes (e.g. Windows paths such as C:\Users\...) and quotes are
    # escaped, while plain values still render exactly as "value".
    # ensure_ascii=False keeps non-ASCII characters as-is instead of emitting
    # \uXXXX surrogate pairs that Python would decode differently.
    dataset_path_literal = json.dumps(str(dataset_path), ensure_ascii=False)
    scenario_id_literal = json.dumps(str(scenario_id), ensure_ascii=False)
    return f"""from pathlib import Path

from smarts.sstudio import gen_scenario
from smarts.sstudio import types as t

dataset_path = {dataset_path_literal}
scenario_id = {scenario_id_literal}

traffic_histories = [
    t.TrafficHistoryDataset(
        name=f"waymo",
        source_type="Waymo",
        input_path=dataset_path,
        scenario_id=scenario_id,
    )
]

gen_scenario(
    t.Scenario(
        map_spec=t.MapSpec(
            source=f"{{dataset_path}}#{{scenario_id}}", lanepoint_spacing=1.0
        ),
        traffic_histories=traffic_histories,
    ),
    output_dir=Path(__file__).parent,
)
"""