# Source code for smarts.env.gymnasium.wrappers.parallel_env

# MIT License
#
# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import multiprocessing as mp
import sys
import traceback
import warnings
from enum import Enum
from typing import Any, Callable, Dict, Sequence, Tuple

import cloudpickle
import gymnasium as gym

__all__ = ["ParallelEnv"]


# Factory signature for building a single environment: takes an integer seed
# and returns a constructed `gym.Env`.
EnvConstructor = Callable[[int], gym.Env]


class _Message(Enum):
    """Message types exchanged over the parent/worker pipes.

    SEED/ACCESS/RESET/STEP/CLOSE flow parent -> worker; RESULT and
    EXCEPTION flow worker -> parent.
    """

    SEED = 1
    ACCESS = 2
    RESET = 3
    STEP = 4
    RESULT = 5
    CLOSE = 6
    EXCEPTION = 7


class ParallelEnv(object):
    """Batch together multiple environments and step them in parallel. Each
    environment is simulated in an external process for lock-free parallelism
    using `multiprocessing` processes, and pipes for communication.

    Note:
        Simulation might slow down when number of parallel environments
        requested exceed number of available CPUs.
    """

    def __init__(
        self,
        env_constructors: Sequence[EnvConstructor],
        auto_reset: bool,
        seed: int = 42,
    ):
        """The environments can be different but must use the same action and
        observation spaces.

        Args:
            env_constructors (Sequence[EnvConstructor]): List of callables that create environments.
            auto_reset (bool): Automatically resets an environment when episode ends.
            seed (int, optional): Seed for the first environment. Defaults to 42.

        Raises:
            TypeError: If any environment constructor is not callable.
            ValueError: If the action or observation spaces do not match.
        """
        if len(env_constructors) > mp.cpu_count():
            warnings.warn(
                f"Simulation might slow down, as the requested number of parallel "
                f"environments ({len(env_constructors)}) exceed the number of available "
                f"CPUs ({mp.cpu_count()}).",
                ResourceWarning,
            )

        if any(not callable(ctor) for ctor in env_constructors):
            raise TypeError(
                f"Found non-callable `env_constructors`. Expected `env_constructors` of type "
                f"`Sequence[Callable[[int], gym.Env]]`, but got {env_constructors})."
            )

        self._num_envs = len(env_constructors)
        self._polling_period = 0.1
        self._closed = False

        # Fork is not a thread safe method. Prefer forkserver when available,
        # otherwise fall back to spawn.
        forkserver_available = "forkserver" in mp.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
        mp_ctx = mp.get_context(start_method)

        self._parent_pipes = []
        self._processes = []
        for idx, env_constructor in enumerate(env_constructors):
            # Each environment gets a distinct seed derived from the base seed.
            cur_seed = seed + idx
            parent_pipe, child_pipe = mp_ctx.Pipe()
            process = mp_ctx.Process(
                target=_worker,
                name=f"Worker-<{type(self).__name__}>-<{idx}>",
                args=(
                    cloudpickle.dumps(env_constructor),
                    cur_seed,
                    auto_reset,
                    child_pipe,
                    self._polling_period,
                ),
            )
            self._parent_pipes.append(parent_pipe)
            self._processes.append(process)

            # Daemonic subprocesses quit when parent process quits. However, daemonic
            # processes cannot spawn children. Hence, `process.daemon` is set to False.
            process.daemon = False
            process.start()
            # The child end is inherited by the worker; close our copy so EOF
            # propagates correctly if the worker dies.
            child_pipe.close()

        # Wait for all environments to successfully startup, then cache the
        # (shared) observation and action spaces.
        self._wait_start()
        self._single_observation_space, self._single_action_space = self._get_spaces()

    @property
    def batch_size(self) -> int:
        """The number of environments."""
        return self._num_envs

    @property
    def observation_space(self) -> gym.Space:
        """The environment's observation space in gym representation."""
        return self._single_observation_space

    @property
    def action_space(self) -> gym.Space:
        """The environment's action space in gym representation."""
        return self._single_action_space

    def _call(self, msg: _Message, payloads: Sequence[Any]) -> Sequence[Any]:
        """Send one `(msg, payload)` pair to every worker and gather replies."""
        assert len(payloads) == self._num_envs
        for pipe, payload in zip(self._parent_pipes, payloads):
            pipe.send((msg, payload))
        return self._recv()

    def _recv(self) -> Sequence[Any]:
        """Receive one message from every worker, raising on any EXCEPTION reply.

        Raises:
            Exception: Re-raised worker-side error, with the worker's name and
                stacktrace in the message. All environments are closed first.
        """
        messages = []
        payloads = []
        for pipe in self._parent_pipes:
            message, payload = pipe.recv()
            messages.append(message)
            payloads.append(payload)
        # NOTE: Previously this loop was `for messages, payload in zip(...)`,
        # which shadowed the `messages` list and tested the stale `message`
        # from the receive loop above — so only the last worker's message type
        # was ever checked. Fixed to inspect every worker's reply.
        for message, payload in zip(messages, payloads):
            if message == _Message.EXCEPTION:
                worker_name, stacktrace = payload
                self.close()
                raise Exception(f"\n{worker_name}\n{stacktrace}")
        return payloads

    def _wait_start(self):
        """Block until every worker reports that its environment was built."""
        self._recv()

    def _get_spaces(self) -> Tuple[gym.Space, gym.Space]:
        """Fetch spaces from all workers and verify that they all match.

        Returns:
            Tuple[gym.Space, gym.Space]: The shared (observation, action) spaces.

        Raises:
            ValueError: If the workers disagree on either space.
        """
        observation_spaces = self._call(
            _Message.ACCESS, ["observation_space"] * self._num_envs
        )
        observation_space = observation_spaces[0]
        if any(space != observation_space for space in observation_spaces):
            raise ValueError(
                f"Expected all environments to have the same observation space, "
                f"but got {observation_spaces}."
            )

        action_spaces = self._call(_Message.ACCESS, ["action_space"] * self._num_envs)
        action_space = action_spaces[0]
        if any(space != action_space for space in action_spaces):
            raise ValueError(
                f"Expected all environments to have the same action space, "
                f"but got {action_spaces}."
            )

        return observation_space, action_space

    def seed(self) -> Sequence[int]:
        """Retrieves the seed used in each environment.

        Returns:
            Sequence[int]: Seed of each environment.
        """
        seeds = self._call(_Message.SEED, [None] * self._num_envs)
        return seeds

    def reset(self) -> Tuple[Sequence[Dict[str, Any]], Sequence[Dict[str, Any]]]:
        """Reset all environments.

        Returns:
            Tuple[Sequence[Dict[str, Any]], Sequence[Dict[str, Any]]]: A batch of
                observations and infos from the vectorized environment.
        """
        # since the return is [(obs0, infos0), ...] they need to be zipped to form
        # [(obs0, ...), (infos0, ...)]
        observations, infos = zip(*self._call(_Message.RESET, [None] * self._num_envs))
        return observations, infos

    def step(
        self, actions: Sequence[Dict[str, Any]]
    ) -> Tuple[
        Sequence[Dict[str, Any]],
        Sequence[Dict[str, float]],
        Sequence[Dict[str, bool]],
        Sequence[Dict[str, bool]],
        Sequence[Dict[str, Any]],
    ]:
        """Steps all environments.

        Args:
            actions (Sequence[Dict[str,Any]]): Actions for each environment.

        Returns:
            Tuple[ Sequence[Dict[str, Any]], Sequence[Dict[str, float]], Sequence[Dict[str, bool]], Sequence[Dict[str, bool]], Sequence[Dict[str, Any]] ]:
                A batch of (observations, rewards, terminateds, truncateds, infos)
                from the vectorized environment.
        """
        result = self._call(_Message.STEP, actions)
        observations, rewards, terminateds, truncateds, infos = zip(*result)
        return (observations, rewards, terminateds, truncateds, infos)

    def close(self, terminate=False):
        """Sends a close message to all external processes.

        Args:
            terminate (bool, optional): If `True`, then the `close` operation is
                forced and all processes are terminated. Defaults to False.
        """
        if terminate:
            for process in self._processes:
                if process.is_alive():
                    process.terminate()
        else:
            for pipe in self._parent_pipes:
                try:
                    pipe.send((_Message.CLOSE, None))
                    pipe.close()
                except IOError:
                    # The connection was already closed.
                    pass

        for process in self._processes:
            if process.is_alive():
                process.join()

        self._closed = True

    def __del__(self):
        # Best-effort cleanup: force-terminate workers if the user never
        # called `close()`.
        if not self._closed:
            self.close(terminate=True)
def _worker(
    env_constructor: bytes,
    seed: int,
    auto_reset: bool,
    pipe: mp.connection.Connection,
    polling_period: float = 0.1,
):
    """Process to build and run an environment. Using a pipe to
    communicate with parent, the process receives action, steps
    the environment, and returns the observations.

    Args:
        env_constructor (bytes): Cloudpickled callable which constructs the environment.
        seed (int): Seed for the environment.
        auto_reset (bool): If True, auto resets environment when episode ends.
        pipe (mp.connection.Connection): Child's end of the pipe.
        polling_period (float, optional): Time to wait for keyboard interrupts. Defaults to 0.1.

    Raises:
        KeyError: If unknown message type is received.
    """
    # Build the environment, then signal the parent that startup succeeded.
    env = cloudpickle.loads(env_constructor)(seed=seed)
    pipe.send((_Message.RESULT, None))

    try:
        while True:
            # Short poll so a KeyboardInterrupt can be serviced between messages.
            if not pipe.poll(polling_period):
                continue
            message, payload = pipe.recv()

            if message == _Message.SEED:
                # NOTE(review): `env.seed` is read as an attribute, not called —
                # presumably the wrapped env exposes a seed property; verify.
                env_seed = env.seed
                pipe.send((_Message.RESULT, env_seed))
            elif message == _Message.ACCESS:
                # Reflectively read an attribute (e.g. a space); None if absent.
                result = getattr(env, payload, None)
                pipe.send((_Message.RESULT, result))
            elif message == _Message.RESET:
                observation, info = env.reset()
                pipe.send((_Message.RESULT, (observation, info)))
            elif message == _Message.STEP:
                observation, reward, terminated, truncated, info = env.step(payload)
                if terminated["__all__"] and auto_reset:
                    # Final observation can be obtained from `info` as follows:
                    # `final_obs = info[agent_id]["env_obs"]`
                    observation, _ = env.reset()
                pipe.send(
                    (
                        _Message.RESULT,
                        (observation, reward, terminated, truncated, info),
                    )
                )
            elif message == _Message.CLOSE:
                break
            else:
                raise KeyError(
                    f"Expected message from {_Message.__members__}, but got unknown message `{message}`."
                )
    except (Exception, KeyboardInterrupt):
        etype, evalue, tb = sys.exc_info()
        # Identity check (`is`, not `==`) — `etype` is the exception class itself.
        if etype is KeyboardInterrupt:
            # Suppress the traceback for user-initiated interrupts.
            stacktrace = "".join(traceback.format_exception(etype, evalue, None))
        else:
            stacktrace = "".join(traceback.format_exception(etype, evalue, tb))
        payload = (mp.current_process().name, stacktrace)
        pipe.send((_Message.EXCEPTION, payload))
    finally:
        # Always release the environment and the pipe, even on error paths.
        env.close()
        pipe.close()