Source code for zoo.policies.replay_agent

"""
This file contains an agent used for replaying an agent.
"""
import logging
import os
import pickle
from pathlib import Path

from smarts.core.agent import Agent
from smarts.zoo.agent_spec import AgentSpec

agent_index = 0


[docs]class ReplayAgent(Agent): """A helper agent that wraps another agent to allow replay of the agent inputs and actions Look at `examples/replay/README.md` on how to use this agent.""" def __init__(self, save_directory, id, read: bool, internal_spec: AgentSpec): import smarts.core if smarts.core.current_seed() is None: smarts.core.seed(42) self.save_directory = save_directory self._base_agent = internal_spec.build_agent() self._logger = logging.getLogger(self.__class__.__name__) global agent_index self.id = f"{id}_{agent_index}" agent_index += 1 abs_path = os.path.abspath(save_directory) self._read = read file_mode = "wb" if not read else "rb" path = Path(f"{abs_path}/{self.id}") os.makedirs(abs_path, exist_ok=True) try: self._file = path.open(mode=file_mode) except FileNotFoundError as e: assert self._read self._logger.error( f"The file which you are trying to be read does not exist. " f"Make sure the {save_directory} directory passed is correct and has the agent file which is being read" ) raise e def __del__(self): if self._file: self._file.close()
[docs] def act(self, obs): if self._read: base_action = self._base_agent.act(obs) try: action = pickle.load(self._file) assert action == base_action except AssertionError as e: self._logger.debug("The Base Agent's action and new action don't match") raise e except Exception as e: self._logger.error( "Comparing the new action with the base agent action raise an unknown error" ) print(e) action = base_action else: action = self._base_agent.act(obs) pickle.dump(action, self._file, protocol=1) return action