# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
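"""Importing this module "redirects" the import to the "real" sumolib (and traci). This is
available for convenience and to reduce code duplication, as sumolib lives under SUMO_HOME.

The module also provides utilities for launching SUMO (`LocalSumoProcess`,
`RemoteSumoProcess`) and for connecting to it through TraCI (`TraciConn`).
"""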
from __future__ import annotations
import abc
import functools
import inspect
import json
import logging
import multiprocessing
import os
import socket
import subprocess
import sys
import time
from typing import Any, List, Literal, Optional, Tuple
from smarts.core.utils import networking
from smarts.core.utils.core_logging import suppress_output
try:
    import sumo

    SUMO_PATH = sumo.SUMO_HOME
    os.environ["SUMO_HOME"] = sumo.SUMO_HOME
except ImportError:
    # No `sumo` pip package: rely on a SUMO installation pointed to by SUMO_HOME.
    if "SUMO_HOME" not in os.environ:
        raise ImportError("SUMO_HOME not set, can't import sumolib")
    SUMO_PATH = os.environ["SUMO_HOME"]

# SUMO ships `sumolib` and `traci` as packages inside `$SUMO_HOME/tools`.
tools_path = os.path.join(SUMO_PATH, "tools")
if tools_path not in sys.path:
    sys.path.append(tools_path)

try:
    # Preferred: the copies bundled with the `sumo` pip package.
    import sumo.tools.sumolib as sumolib
    import sumo.tools.traci as traci
except ModuleNotFoundError:
    # The `sumo` pip package is unavailable (SUMO_HOME-only setup), so import the
    # packages from the `tools` directory added to `sys.path` above. Without this
    # fallback the SUMO_HOME branch could never succeed.
    try:
        import sumolib
        import traci
    except ModuleNotFoundError as e:
        raise ImportError(
            "Missing dependencies for SUMO. Install them using the command `pip install -e .[sumo]` at the source directory."
        ) from e
def _safe_close(conn, **kwargs):
    """Best-effort ``conn.close(**kwargs)`` that never raises an ``Exception``.

    Args:
        conn: Object exposing a ``close`` method. ``None`` is tolerated: the resulting
            ``AttributeError`` is swallowed.
        **kwargs: Forwarded verbatim to ``conn.close``.
    """
    try:
        conn.close(**kwargs)
    except (subprocess.SubprocessError, multiprocessing.ProcessError):
        # Subprocess or process failed
        pass
    except AttributeError:
        # No usable `close` (e.g. a `None` pipe), or the socket was destroyed
        # internally, likely due to an error.
        pass
    except traci.exceptions.FatalTraCIError:
        # TraCI connection is already dead.
        pass
    except Exception as err:  # pylint: disable=broad-except
        # Still best-effort, but leave a trace instead of hiding the failure entirely.
        logging.debug("Ignored error while closing %r: %r", conn, err)
class DomainWrapper:
    """Wraps `traci.Domain` type for the `TraciConn` utility.

    Builtin functions and bound methods of the wrapped domain are returned wrapped by
    `_wrap_traci_method`, so TraCI failures are logged and the owning connection is
    closed. All other attributes are returned unchanged.
    """

    def __init__(self, traci_conn, domain: traci.domain.Domain, attribute_name) -> None:
        self._domain = domain
        self._traci_conn = traci_conn
        self._attribute_name = attribute_name

    def __getattr__(self, name: str) -> Any:
        # `__getattr__` only runs when normal lookup fails. Read `_domain` straight from
        # the instance dict: on an instance whose `__init__` never ran (e.g. one being
        # rebuilt by `copy`/`pickle`), `self._domain` would re-enter `__getattr__` and
        # recurse until RecursionError.
        try:
            domain = self.__dict__["_domain"]
        except KeyError:
            raise AttributeError(name) from None
        attribute = getattr(domain, name)
        if inspect.isbuiltin(attribute) or inspect.ismethod(attribute):
            attribute = functools.partial(
                _wrap_traci_method,
                method=attribute,
                traci_conn=self._traci_conn,
                attribute_name=self._attribute_name,
            )
        return attribute
class SumoProcess(metaclass=abc.ABCMeta):
    """A simplified utility representing a SUMO process.

    Implementations are consumed by `TraciConn`, which reads `host`/`port` and calls
    `poll` and `terminate`.
    """

    @abc.abstractmethod
    def generate(
        self, base_params: List[str], sumo_binary: Literal["sumo", "sumo-gui"] = "sumo"
    ):
        """Generate the process.

        Args:
            base_params: Command-line arguments for the SUMO binary.
            sumo_binary: Which SUMO executable to use.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def terminate(self, kill: bool):
        """Terminate this process.

        Args:
            kill: Whether the process should be forcibly killed.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def poll(self) -> Optional[int]:
        """Poll the underlying process; ``None`` means it is still running."""
        raise NotImplementedError

    @abc.abstractmethod
    def wait(self, timeout: Optional[float] = None) -> int:
        """Wait on the underlying process and return its exit code."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def port(self) -> int:
        """The port this process is associated with."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def host(self) -> str:
        """The host this process is associated with."""
        raise NotImplementedError
class RemoteSumoProcess(SumoProcess):
    """Connects to a sumo server.

    The server is asked to launch SUMO and replies with the ``host:port`` that TraCI
    should connect to.
    """

    def __init__(self, remote_host, remote_port) -> None:
        self._remote_host = remote_host
        self._remote_port = remote_port
        self._port = None
        self._host = None
        self._client_socket = None

    def generate(
        self, base_params: List[str], sumo_binary: Literal["sumo", "sumo-gui"] = "sumo"
    ):
        """Ask the server to launch ``sumo_binary`` with ``base_params``.

        Makes up to 5 connection attempts, 1 second apart. It then sends the binary
        name and the JSON-encoded parameters, and parses the ``host:port`` reply.

        Raises:
            OSError: If the server cannot be reached.
        """
        # Wait on server to start if it needs to.
        error = None
        for _ in range(5):
            # POSIX leaves a socket in an unspecified state after a failed `connect`,
            # so every attempt uses a fresh socket and failed ones are closed.
            client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                client_socket.connect((self._remote_host, self._remote_port))
            except OSError as err:
                client_socket.close()
                error = err
                time.sleep(1)
                continue
            break
        else:
            raise OSError(
                f"Unable to connect to server {self._remote_host}:{self._remote_port}. Try running again or running the server using `python -m smarts.core.utils.sumo_server`."
            ) from error
        # Keep the socket before talking to the server so `terminate` can close it.
        self._client_socket = client_socket
        # `sendall` keeps writing until the whole request is sent; `send` may not.
        client_socket.sendall(
            f"{sumo_binary}:{json.dumps(base_params)}\n".encode("utf-8")
        )
        response = client_socket.recv(1024)
        self._host, _, port = response.decode("utf-8").partition(":")
        self._port = int(port)

    def terminate(self, kill: bool):
        """Tell the server to end the session and close the client socket.

        ``kill`` is ignored by this implementation. Calling this before `generate`, or
        more than once, is a no-op.
        """
        client_socket, self._client_socket = self._client_socket, None
        if client_socket is None:
            return
        try:
            client_socket.sendall("e:".encode("utf-8"))
        except OSError:
            # The server side is already gone; there is nobody left to notify.
            pass
        finally:
            client_socket.close()

    @property
    def port(self) -> int:
        """The port reported by the server, or ``0`` before `generate`."""
        return self._port or 0

    @property
    def host(self) -> str:
        """The host reported by the server, or ``"-1"`` before `generate`."""
        return self._host or "-1"

    def poll(self) -> Optional[int]:
        """Always ``None``: the remote process cannot be polled from here."""
        return None

    def wait(self, timeout: Optional[float] = None) -> int:
        """Always ``0``: the remote process cannot be waited on from here."""
        return 0
class LocalSumoProcess(SumoProcess):
    """Launches and manages a SUMO process on this machine."""

    def __init__(self, sumo_port) -> None:
        self._sumo_proc = None
        self._sumo_port = sumo_port

    def generate(
        self, base_params: List[str], sumo_binary: Literal["sumo", "sumo-gui"] = "sumo"
    ):
        """Start ``<SUMO_PATH>/bin/<sumo_binary>`` serving TraCI on `port`.

        A free port is picked if none was given. The child's standard streams are
        redirected to ``DEVNULL``.
        """
        if self._sumo_port is None:
            self._sumo_port = networking.find_free_port()
        sumo_cmd = [
            os.path.join(SUMO_PATH, "bin", sumo_binary),
            f"--remote-port={self._sumo_port}",
            *base_params,
        ]
        self._sumo_proc = subprocess.Popen(
            sumo_cmd,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            close_fds=True,
        )

    @property
    def port(self) -> int:
        """The TraCI port (only valid once a port was given or `generate` ran)."""
        assert self._sumo_port is not None
        return self._sumo_port

    @property
    def host(self) -> str:
        """Always ``"localhost"``."""
        return "localhost"

    def terminate(self, kill):
        """Release the process handle; if ``kill`` is set, kill and reap it first."""
        if self._sumo_proc:
            # With DEVNULL redirection these are `None`; `_safe_close` tolerates that.
            _safe_close(self._sumo_proc.stdin)
            _safe_close(self._sumo_proc.stdout)
            _safe_close(self._sumo_proc.stderr)
            if kill:
                self._sumo_proc.kill()
                # Reap the killed child so it does not linger as a zombie. A killed
                # process normally exits at once; the timeout only bounds the wait.
                try:
                    self._sumo_proc.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    pass
            self._sumo_proc = None

    def poll(self) -> Optional[int]:
        """Poll the child process (requires `generate` to have been called)."""
        return self._sumo_proc.poll()

    def wait(self, timeout=None):
        """Wait on the child process (requires `generate` to have been called)."""
        return self._sumo_proc.wait(timeout=timeout)
class TraciConn:
    """A simplified utility for connecting to a SUMO process.

    After `connect`, unknown attribute lookups are forwarded to the underlying TraCI
    connection. Methods are wrapped with `_wrap_traci_method` and TraCI domains with
    `DomainWrapper`, so TraCI failures are logged and close this connection.
    """

    def __init__(
        self,
        sumo_process: SumoProcess,
        host: str = "localhost",
        name: str = "",
    ):
        """
        Args:
            sumo_process: The SUMO process to connect to.
            host: Initial host name; replaced by ``sumo_process.host`` in `connect`.
            name: Label included in log messages.
        """
        self._traci_conn = None
        self._sumo_port = None
        self._sumo_version: Tuple[int, ...] = tuple()
        self._host = host
        self._name = name
        # Log through the module-level (root logger) `logging` functions.
        self._log = logging
        self._connected = False
        self._sumo_process = sumo_process

    def __del__(self) -> None:
        # We should not raise in delete.
        try:
            self.close_traci_and_pipes()
        except Exception:
            pass

    def connect(
        self,
        timeout: float,
        minimum_traci_version: int,
        minimum_sumo_version: Tuple[int, ...],
        debug: bool = False,
    ):
        """Attempt a connection with the SUMO process.

        Args:
            timeout: Seconds the SUMO process has to accept the connection. TraCI
                retries every 0.05 seconds until then.
            minimum_traci_version: Lowest accepted TraCI API version.
            minimum_sumo_version: Lowest accepted SUMO version, e.g. ``(1, 11, 0)``.
            debug: If ``True``, stderr output produced while connecting is not
                suppressed.

        Raises:
            traci.exceptions.FatalTraCIError: TraCI could not connect in time.
            traci.exceptions.TraCIException: The process died, or the connection
                failed while checking versions. The connection is closed first.
            ConnectionRefusedError: The TraCI server refused the connection. The
                connection is closed first.
            ValueError: The TraCI or SUMO version is too old. The connection is
                closed first.
        """
        traci_conn = None
        self._host = self._sumo_process.host
        self._sumo_port = self._sumo_process.port
        try:
            # See if the process is still alive before attempting a connection.
            with suppress_output(stderr=not debug, stdout=True):
                traci_conn = traci.connect(
                    self._sumo_port,
                    host=self._host,
                    numRetries=max(0, int(20 * timeout)),
                    proc=self._sumo_process,
                    waitBetweenRetries=0.05,
                )  # SUMO must be ready within timeout seconds
                # We will retry since this is our first sumo command
        except traci.exceptions.FatalTraCIError as err:
            self._log.error(
                "[%s] TraCI could not connect in time to '%s:%s' [%s]",
                self._name,
                self._host,
                self._sumo_port,
                err,
            )
            # XXX: Actually not fatal...
            raise
        except traci.exceptions.TraCIException as err:
            self._log.error(
                "[%s] SUMO process died while trying to connect to '%s:%s' [%s]",
                self._name,
                self._host,
                self._sumo_port,
                err,
            )
            self.close_traci_and_pipes()
            raise
        except ConnectionRefusedError:
            self._log.error(
                "[%s] Intended TraCI server '%s:%s' refused connection.",
                self._name,
                self._host,
                self._sumo_port,
            )
            self.close_traci_and_pipes()
            raise
        self._connected = True
        self._traci_conn = traci_conn
        try:
            if not self.viable:
                raise traci.exceptions.TraCIException("TraCI server already finished!?")
            vers, vers_str = traci_conn.getVersion()
            if vers < minimum_traci_version:
                raise ValueError(
                    f"TraCI API version must be >= {minimum_traci_version}. Got version ({vers})"
                )
            self._sumo_version = tuple(
                int(v) for v in vers_str.partition(" ")[2].split(".")
            )  # e.g. "SUMO 1.11.0" -> (1, 11, 0)
            if self._sumo_version < minimum_sumo_version:
                raise ValueError(f"SUMO version must be >= SUMO {minimum_sumo_version}")
        except traci.exceptions.FatalTraCIError as err:
            self._log.error(
                "[%s] TraCI disconnected for connection attempt '%s:%s': [%s]",
                self._name,
                self._host,
                self._sumo_port,
                err,
            )
            # XXX: the error type is changed to TraCIException to make it consistent with the
            # process died case of `traci.connect`. Since TraCIException is fatal just in this case...
            self.close_traci_and_pipes()
            raise traci.exceptions.TraCIException(err) from err
        except OSError as err:
            self._log.error(
                "[%s] OS error occurred for TraCI connection attempt '%s:%s': [%s]",
                self._name,
                self._host,
                self._sumo_port,
                err,
            )
            self.close_traci_and_pipes()
            raise traci.exceptions.TraCIException(err) from err
        except (traci.exceptions.TraCIException, ValueError):
            # Covers the "already finished" check above as well: release the
            # connection before propagating, like every other failure path.
            self.close_traci_and_pipes()
            raise

    @property
    def connected(self) -> bool:
        """Check if the connection is still valid."""
        return self._sumo_process is not None and self._connected

    @property
    def viable(self) -> bool:
        """If making a connection to the sumo process is still viable."""
        return self._sumo_process is not None and self._sumo_process.poll() is None

    @property
    def sumo_version(self) -> Tuple[int, ...]:
        """Get the current SUMO version as a tuple (empty before `connect`)."""
        return self._sumo_version

    @property
    def port(self) -> Optional[int]:
        """Get the used TraCI port (``None`` before `connect`)."""
        return self._sumo_port

    @property
    def hostname(self) -> str:
        """Get the host name used for the TraCI connection."""
        return self._host

    def __getattr__(self, name: str) -> Any:
        # `__getattr__` only runs when normal lookup fails. If `__init__` never ran
        # (e.g. `copy`/`pickle` reconstruction), `self.connected` would re-enter
        # `__getattr__` and recurse until RecursionError, so bail out early.
        state = self.__dict__
        if "_connected" not in state or "_sumo_process" not in state:
            raise AttributeError(name)
        if not self.connected:
            raise traci.exceptions.FatalTraCIError("TraCI died.")
        attribute = getattr(self._traci_conn, name)
        if inspect.isbuiltin(attribute) or inspect.ismethod(attribute):
            attribute = functools.partial(
                _wrap_traci_method,
                method=attribute,
                attribute_name=name,
                traci_conn=self,
            )
        elif isinstance(attribute, traci.domain.Domain):
            attribute = DomainWrapper(
                traci_conn=self, domain=attribute, attribute_name=name
            )
        else:
            # Only methods and domains are exposed through the wrapper.
            raise NotImplementedError()
        return attribute

    def must_reset(self):
        """If the version of sumo will have errors if just reloading such that it must be reset."""
        return self._sumo_version > (1, 12, 0)

    def close_traci_and_pipes(self, wait: bool = True, kill: bool = True):
        """Safely closes all connections. We should expect this method to always work without throwing"""
        if self._connected:
            self._log.debug("Closing TraCI connection to %s", self._sumo_port)
            _safe_close(self._traci_conn, wait=wait)

        sumo_process, self._sumo_process = self._sumo_process, None
        if sumo_process:
            try:
                sumo_process.terminate(kill=kill)
            except Exception as err:  # pylint: disable=broad-except
                # Honour the no-throw contract, e.g. when a remote server is gone.
                self._log.warning(
                    "[%s] Failed to terminate TraCI server process '%s:%s': %s",
                    self._name,
                    self._host,
                    self._sumo_port,
                    err,
                )
            else:
                self._log.info(
                    "Killed TraCI server process '%s:%s'", self._host, self._sumo_port
                )
        self._connected = False

    def teardown(self):
        """Clean up all resources."""
        self.close_traci_and_pipes(True)
        self._traci_conn = None
def _wrap_traci_method(
    *args, method, traci_conn: TraciConn, attribute_name: str, **kwargs
):
    """Call ``method(*args, **kwargs)``, translating TraCI and socket failures.

    Args:
        method: The TraCI callable to invoke.
        traci_conn: The owning connection; it is closed when the call fails fatally.
        attribute_name: Name used in log messages to identify the call.

    Raises:
        traci.exceptions.FatalTraCIError: TraCI disconnected; ``traci_conn`` was closed.
        OSError: An OS/socket error occurred; ``traci_conn`` was closed.
        traci.exceptions.TraCIException: A TraCI error after which the connection was
            left open.
        KeyboardInterrupt: Re-raised after closing ``traci_conn`` without waiting.
    """
    # Argument order must be `*args` first so `method`, `traci_conn` and
    # `attribute_name` are keyword-only arguments.
    try:
        return method(*args, **kwargs)
    except traci.exceptions.FatalTraCIError as err:
        logging.error(
            "[%s] TraCI '%s:%s' disconnected for call '%s', process may have died: [%s]",
            traci_conn._name,
            traci_conn.hostname,
            traci_conn.port,
            attribute_name,
            err,
        )
        # TraCI cannot continue
        traci_conn.close_traci_and_pipes()
        raise traci.exceptions.FatalTraCIError("TraCI died.") from err
    except OSError as err:
        logging.error(
            "[%s] OS error occurred for TraCI '%s:%s' call '%s': [%s]",
            traci_conn._name,
            traci_conn.hostname,
            traci_conn.port,
            attribute_name,
            err,
        )
        traci_conn.close_traci_and_pipes()
        raise OSError("Connection dropped.") from err
    except traci.exceptions.TraCIException as err:
        # Case where TraCI/SUMO can theoretically continue
        raise traci.exceptions.TraCIException("TraCI can continue.") from err
    except KeyboardInterrupt:
        # Do not block on the process while the user is interrupting.
        traci_conn.close_traci_and_pipes(wait=False)
        raise