Source code for smarts.core.utils.episodes

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional, Union

import tableprint as tp


class EpisodeLogs:
    """An episode logging utility.

    Writes one ``tableprint`` table row per finished episode. Entering and
    exiting this object delegates to the underlying table context.
    """

    def __init__(self, col_width, total_episodes: Union[str, int] = "?") -> None:
        self._total_episodes = total_episodes
        self._current_episode_num = 0
        self._current_episode: Optional[EpisodeLog] = None
        self._col_width = col_width
        self._table = self.context(col_width)

    def reset(self) -> "EpisodeLog":
        """Record an episode reset.

        If an episode is in progress, its row is written and the episode
        counter advances. A fresh ``EpisodeLog`` is then created, stored as
        the current episode and returned.
        """
        finished = self._current_episode
        if finished is not None:
            self._write_row()
            self._current_episode_num += 1
        fresh = EpisodeLog(self._current_episode_num)
        self._current_episode = fresh
        return fresh

    def _write_row(self):
        """Write the current episode as a table row.

        Each agent score after the first goes on its own extra row, with the
        other columns left blank.
        """
        assert isinstance(self._current_episode, EpisodeLog)
        episode = self._current_episode
        width = self._col_width
        base_row = (
            f"{episode.index + 1}/{self._total_episodes}",
            f"{episode.sim2wall_ratio:.2f}",
            episode.steps,
            f"{episode.steps_per_second:.2f}",
            os.path.basename(episode.scenario_map)[:width],
            episode.scenario_traffic[:width],
            episode.mission_hash[:width],
        )
        summaries = [
            f"{score:.2f} - {agent}" for agent, score in episode.scores.items()
        ]
        # With no scores, the "Scores" column of the single row is left empty.
        first, *rest = summaries or [""]
        self._table(base_row + (first,))
        blank_columns = ("",) * len(base_row)
        for summary in rest:
            self._table(blank_columns + (summary,))

    def __enter__(self):
        self._table.__enter__()
        return self

    def __exit__(self, *exc):
        self._table.__exit__(*exc)

    @staticmethod
    def context(col_width):
        """Generate a formatted table context object."""
        headers = [
            "Episode",
            "Sim T / Wall T",
            "Total Steps",
            "Steps / Sec",
            "Scenario Map",
            "Scenario Routes",
            "Mission (Hash)",
            "Scores",
        ]
        return tp.TableContext(headers, width=col_width, style="round")
@dataclass
class EpisodeLog:
    """An episode logging tool.

    Accumulates the step count, the latest per-agent scores and the scenario
    metadata of a single episode.
    """

    # Episode number (rendered by the episode table as ``index + 1``).
    index: int = 0
    # Wall-clock timestamp from ``time.time()`` taken when the log is created.
    start_time: float = field(default_factory=lambda: time.time())
    # Fixed simulation step length in seconds, set by ``record_scenario``.
    fixed_timestep_sec: float = 0
    # Latest score per agent id; unseen agents read as 0. ``int`` (rather than
    # a lambda) as the factory keeps the log picklable.
    scores: dict = field(default_factory=lambda: defaultdict(int))
    # Number of ``record_step`` calls so far.
    steps: int = 0
    scenario_map: str = ""
    scenario_traffic: str = ""
    mission_hash: str = ""

    @property
    def wall_time(self):
        """Time elapsed since instantiation."""
        return time.time() - self.start_time

    @property
    def sim_time(self):
        """An estimation of the total fixed-time-step simulation performed."""
        return self.fixed_timestep_sec * self.steps

    @property
    def sim2wall_ratio(self):
        """The ratio of sim time to wall time. Above 1 is hyper-real-time.

        Returns 0.0 when no positive wall time has elapsed. This happens with
        a coarse clock or a clock that stepped back, and would otherwise raise
        ``ZeroDivisionError`` or yield a negative ratio.
        """
        elapsed = self.wall_time
        return self.sim_time / elapsed if elapsed > 0 else 0.0

    @property
    def steps_per_second(self):
        """The rate of steps performed since instantiation.

        Returns 0.0 when no positive wall time has elapsed.
        """
        elapsed = self.wall_time
        return self.steps / elapsed if elapsed > 0 else 0.0

    def record_scenario(self, scenario_log):
        """Record a scenario end.

        Args:
            scenario_log: Mapping with ``fixed_timestep_sec``,
                ``scenario_map`` and ``mission_hash`` keys. Traffic is read
                from ``scenario_traffic``, falling back to ``scenario_routes``
                and then to ``""``.
        """
        self.fixed_timestep_sec = scenario_log["fixed_timestep_sec"]
        self.scenario_map = scenario_log["scenario_map"]
        self.scenario_traffic = scenario_log.get(
            "scenario_traffic", scenario_log.get("scenario_routes", "")
        )
        self.mission_hash = scenario_log["mission_hash"]

    def record_step(self, observations, rewards, terminateds, truncateds, infos):
        """Record a step end.

        Accepts either multi-agent dicts or single-agent values; the latter
        are wrapped under the ``"SingleAgent"`` key. Scores are read from
        ``infos[agent]["score"]``:

        - for every agent when ``terminateds["__all__"]`` is truthy;
        - otherwise only for agents whose terminated flag is truthy.

        When ``infos`` is ``None`` only the step count is updated.
        """
        self.steps += 1

        if infos is None:
            # Without an info payload there are no scores to record.
            return

        if not isinstance(terminateds, dict):
            (
                observations,
                rewards,
                terminateds,
                truncateds,
                infos,
            ) = self._convert_to_dict(
                observations, rewards, terminateds, truncateds, infos
            )

        if terminateds.get("__all__", False):
            for agent_id, info in infos.items():
                self.scores[agent_id] = info["score"]
        else:
            # "__all__" is falsy or absent here, so it is never selected.
            for agent_id, terminated in terminateds.items():
                if terminated:
                    self.scores[agent_id] = infos[agent_id]["score"]

    def _convert_to_dict(self, observations, rewards, terminateds, truncateds, infos):
        """Wrap single-agent step values into multi-agent style dicts."""
        observations, rewards, infos = [
            {"SingleAgent": obj} for obj in [observations, rewards, infos]
        ]
        terminateds = {"SingleAgent": terminateds, "__all__": terminateds}
        truncateds = {"SingleAgent": truncateds, "__all__": truncateds}
        return observations, rewards, terminateds, truncateds, infos
def episodes(n):
    """An iteration method that provides numbered episodes.

    Works like ``range(n)`` but yields a fresh ``EpisodeLog`` per iteration.
    Each reset writes the previous episode's table row. One extra reset after
    the loop writes the row for the final episode.
    """
    table_col_width = 18
    with EpisodeLogs(table_col_width, n) as logs:
        yield from (logs.reset() for _ in range(n))
        logs.reset()
@dataclass
class Episodes:
    """An episode counter utility.

    Holds the step budget and the running step count shared by ``Episode``
    objects. Usable as a context manager: entering returns the counter
    itself, and exiting does nothing and does not suppress exceptions.
    """

    # Total number of steps allowed.
    max_steps: int
    # Steps consumed so far.
    current_step: int = 0

    def __enter__(self):
        return self

    def __exit__(self, *exception):
        return None
class Episode:
    """An episode recording object.

    Each call to ``continues`` advances the step counter of the ``Episodes``
    object passed in, which may be shared with other episodes.
    """

    def __init__(self, episodes: "Episodes"):
        # String annotation: the class definition no longer requires
        # ``Episodes`` to be defined first.
        self._episodes = episodes

    @staticmethod
    def _is_done(flag) -> bool:
        """Collapse a single-agent flag or a multi-agent flag dict into a bool.

        For dicts, ``"__all__"`` wins when present; otherwise every agent
        must be done.
        """
        if isinstance(flag, dict):
            if "__all__" in flag:
                return bool(flag["__all__"])
            return all(flag.values())
        return bool(flag)

    def continues(self, observation, reward, terminated, truncated, info) -> bool:
        """Determine if the current episode can continue.

        Returns False once the shared step budget is used up, or when the
        episode is either terminated or truncated. ``truncated`` was
        previously ignored, so truncated episodes kept running.
        """
        self._episodes.current_step += 1
        if self._episodes.current_step >= self._episodes.max_steps:
            return False
        return not (self._is_done(terminated) or self._is_done(truncated))
def episode_range(max_steps):
    """An iteration method that provides a range of episodes that meets the given max steps.

    Every yielded ``Episode`` shares one ``Episodes`` counter. The budget
    therefore covers all episodes combined: iteration stops as soon as the
    shared counter is no longer below ``max_steps``.
    """
    with Episodes(max_steps=max_steps) as shared:
        while True:
            if not shared.current_step < shared.max_steps:
                return
            yield Episode(episodes=shared)