# Source code for smarts.zoo.agent_spec

# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional

import cloudpickle

from smarts.core.agent import Agent
from smarts.core.agent_interface import AgentInterface


@dataclass
class AgentSpec(object):
    """A configuration that is used by SMARTS environments.

    .. code-block:: python

        agent_spec = AgentSpec(
            interface=AgentInterface.from_type(AgentType.Laner),
            agent_params={"agent_function": lambda _: "keep_lane"},
            agent_builder=Agent.from_function,
        )

        env = gym.make(
            "smarts.env:hiway-v1",
            scenarios=["scenarios/sumo/loop"],
            agent_specs={agent_id: agent_spec},
        )

        agent = agent_spec.build_agent()

    Refer to the Agent documentation.
    """

    # This is optional because sometimes when building re-useable specs,
    # you don't know the agent interface ahead of time.
    interface: Optional[AgentInterface] = None
    """the adaptor used to wrap agent observation and action flow (default None)"""
    agent_builder: Optional[Callable[..., Agent]] = None
    """A callable to build an `smarts.core.agent.Agent` given `AgentSpec.agent_params` (default None)"""
    agent_params: Optional[Any] = None
    """Parameters to be given to `AgentSpec.agent_builder` (default None)"""
    observation_adapter: Callable = lambda obs: obs
    """Deprecated. Do not use. An adaptor that allows shaping of the observations. Defaults to `lambda obs: obs`."""
    action_adapter: Callable = lambda act: act
    """Deprecated. Do not use. An adaptor that allows shaping of the action. Defaults to `lambda act: act`."""
    reward_adapter: Callable = lambda obs, reward: reward
    """Deprecated. Do not use. An adaptor that allows shaping of the reward. Defaults to `lambda obs, reward: reward`."""
    info_adapter: Callable = lambda obs, reward, info: info
    """Deprecated. Do not use. An adaptor that allows shaping of info. Defaults to `lambda obs, reward, info: info`."""

    def __getattribute__(self, item):
        """Emit a DeprecationWarning whenever a deprecated adapter field is read.

        Note that this hook fires on every read of these attributes, including
        reads made by the dataclass-generated ``__repr__``/``__eq__``.
        """
        if item in (
            "observation_adapter",
            "action_adapter",
            "reward_adapter",
            "info_adapter",
        ):
            warnings.warn(
                "Observation_adapter, action_adapter, reward_adapter, "
                "and info_adapter, are deprecated. Do not use them.",
                DeprecationWarning,
                # Attribute the warning to the caller's line, not this hook.
                stacklevel=2,
            )
        return object.__getattribute__(self, item)

    def __post_init__(self):
        # make sure we can pickle ourselves: fail at construction time if the
        # spec (including any lambdas/closures it holds) cannot be serialized
        # by cloudpickle, rather than at some later point of use.
        cloudpickle.dumps(self)

    def build_agent(self) -> Agent:
        """Construct an Agent from the AgentSpec configuration.

        ``agent_params`` is forwarded to ``agent_builder`` according to its type:

        * ``None``: the builder is called with no arguments.
        * ``list`` or ``tuple``: the items are passed as positional arguments.
        * ``dict``: the items are passed as keyword arguments. Unless the
          builder accepts ``**kwargs``, keys that do not name a parameter which
          can be passed by keyword (positional-or-keyword or keyword-only) are
          silently dropped. If the builder's signature cannot be introspected,
          every item is forwarded unchanged.
        * anything else: the value is passed as the single positional argument.

        Returns:
            The object returned by ``agent_builder``.

        Raises:
            ValueError: If no ``agent_builder`` was supplied or it is not callable.
        """
        if self.agent_builder is None:
            raise ValueError("Can't build agent, no agent builder was supplied")

        if not callable(self.agent_builder):
            raise ValueError(
                f"""agent_builder: {self.agent_builder} is not callable
Use a combination of agent_params and agent_builder to define how to build your agent, ie.

AgentSpec(
  agent_params={{"input_dimensions": 12}},
  agent_builder=MyAgent # we are not instantiating the agent, just passing the class reference
)
"""
            )

        if self.agent_params is None:
            # no args to agent builder
            return self.agent_builder()
        elif isinstance(self.agent_params, (list, tuple)):
            # a list or tuple is treated as positional arguments
            return self.agent_builder(*self.agent_params)
        elif isinstance(self.agent_params, dict):
            # dictionaries, as keyword arguments
            try:
                # follow_wrapped=False: inspect the callable actually invoked,
                # not whatever a decorator's __wrapped__ points at.
                parameters = inspect.signature(
                    self.agent_builder, follow_wrapped=False
                ).parameters.values()
            except (TypeError, ValueError):
                # No introspectable signature (e.g. some builtins): forward
                # everything and let the call itself report bad keywords.
                return self.agent_builder(**self.agent_params)

            if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters):
                # The builder accepts **kwargs, so nothing needs filtering.
                return self.agent_builder(**self.agent_params)

            # Keyword-only parameters must be included; positional-only
            # parameters cannot be supplied by name and are excluded.
            keyword_kinds = (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            )
            accepted = {p.name for p in parameters if p.kind in keyword_kinds}
            return self.agent_builder(
                **{k: v for k, v in self.agent_params.items() if k in accepted}
            )
        else:
            # otherwise, the agent params are sent as is to the builder
            return self.agent_builder(self.agent_params)