# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import ctypes
import logging
import os
import sys
import warnings
from collections import defaultdict
from contextlib import contextmanager
from io import UnsupportedOperation
from time import time
from typing import Sequence
from .file import unpack
@contextmanager
def timeit(name: str, log):
    """Context manager that stopwatches the amount of time between context block start and end.

    The elapsed time is measured with :func:`time.perf_counter` (a monotonic clock,
    unaffected by system clock changes) and reported once, after the block exits
    normally, by calling ``log`` with a single message string. Nothing is reported
    if the block raises.

    .. code-block:: python

        import logging

        with timeit("multiply", logging.info):
            a = a * b

    Args:
        name: Label included in the reported message.
        log: Callable accepting one string argument, e.g. ``logging.info`` or ``print``.
    """
    from time import perf_counter

    start = perf_counter()
    yield
    elapsed_time = (perf_counter() - start) * 1000
    # `.4f` -> four decimal places of milliseconds (the old `4f` was a width, not a precision).
    log(f'"{name}" took: {elapsed_time:.4f}ms')
def isnotebook():
    """Report whether we are running inside an IPython kernel (`Jupyter` Notebook).

    Returns:
        ``True`` when ``get_ipython`` is available and either the active shell is a
        ``ZMQInteractiveShell`` (Jupyter notebook or qtconsole) or ``google.colab``
        has been imported (Google Colab); ``False`` otherwise.
    """
    try:
        shell_name = get_ipython().__class__.__name__  # pytype: disable=name-error
    except NameError:
        # `get_ipython` only exists as a builtin when running under IPython.
        return False
    is_kernel_shell = shell_name == "ZMQInteractiveShell"
    return is_kernel_shell or "google.colab" in sys.modules
# Handle to the C library symbols of the current process; used to `fflush` C-level stdio.
libc = ctypes.CDLL(None)
try:
    # glibc and most Unix C libraries export the standard streams under these names.
    c_stderr = ctypes.c_void_p.in_dll(libc, "stderr")
    c_stdout = ctypes.c_void_p.in_dll(libc, "stdout")
except ValueError:
    # macOS exports them as `__stderrp` / `__stdoutp`. `in_dll` raises ValueError
    # when a symbol is not found; other errors should not be masked.
    c_stderr = ctypes.c_void_p.in_dll(libc, "__stderrp")
    c_stdout = ctypes.c_void_p.in_dll(libc, "__stdoutp")
[docs]def try_fsync(fd):
"""Attempts to see if `fsync` will work. Workaround for error on GitHub Actions."""
try:
os.fsync(fd)
except OSError:
# On GH actions was returning an OSError: [Errno 22] Invalid argument
pass
@contextmanager
def suppress_output(stderr=True, stdout=True):
    """Silence the process's standard streams for the duration of the ``with`` block.

    .. spelling:word-list::

        stderr
        stdout

    Args:
        stderr: Suppress `stderr`.
        stdout: Suppress `stdout`.
    """
    restore_err = restore_out = None
    try:
        # Each helper call redirects one stream and hands back its restore function.
        restore_err = _suppress_fileout("stderr") if stderr else None
        restore_out = _suppress_fileout("stdout") if stdout else None
        yield
    finally:
        # Only restore the streams that were actually redirected.
        if stderr and restore_err:
            restore_err(c_stderr)
        if stdout and restore_out:
            restore_out(c_stdout)
def _suppress_fileout(stdname):
    """Redirect the ``sys`` stream named ``stdname`` to the null device.

    Two strategies are used:

    * If the stream exposes an OS file descriptor, that descriptor is pointed at
      ``os.devnull`` with ``dup2`` (so writes through the descriptor are silenced,
      not only Python-level writes). ``sys.<stdname>`` is also replaced by a Python
      file object on the null device.
    * If ``fileno()`` raises ``UnsupportedOperation`` and we are in a notebook,
      only ``sys.<stdname>`` is swapped for a null-device file object.

    Args:
        stdname: Name of the ``sys`` attribute to redirect (``"stdout"`` or ``"stderr"``).

    Returns:
        A cleanup callable that restores the original stream. It takes the ctypes
        C stream pointer to ``fflush`` (e.g. ``c_stdout``). The notebook variant
        ignores its argument.

    Raises:
        io.UnsupportedOperation: if the stream has no file descriptor and we are not
            running in a notebook.
        OSError: if duplicating or opening descriptors fails.
    """
    import errno

    original = getattr(sys, stdname)
    try:
        original_std_fno = original.fileno()
    except UnsupportedOperation:
        if not isnotebook():
            raise
        # This case is notebook: no OS-level descriptor, so only swap the Python object.
        file = open(os.devnull, "w")
        old_std = getattr(sys, stdname)
        setattr(sys, stdname, file)

        def cleanup_notebook(_):
            new_std = getattr(sys, stdname)
            new_std.flush()
            # Ensure attributes exist because of https://github.com/ipython/ipykernel/issues/867
            if not hasattr(new_std, "watch_fd_thread"):
                setattr(new_std, "watch_fd_thread", None)
            if not hasattr(new_std, "_exc"):
                setattr(new_std, "_exc", None)
            new_std.close()
            if new_std is not file:
                # The stream was replaced while suppressed; still release our own handle.
                file.close()
            setattr(sys, stdname, old_std)

        return cleanup_notebook

    # Keep a duplicate of the real descriptor so it can be restored afterwards.
    dup_std_fno = os.dup(original_std_fno)
    try:
        devnull_fno = os.open(os.devnull, os.O_WRONLY)
    except OSError:
        # Do not leak the saved descriptor if the null device cannot be opened.
        os.close(dup_std_fno)
        raise
    os.dup2(devnull_fno, original_std_fno)
    # The file object owns `devnull_fno`: closing it closes the fd exactly once, so the
    # fd must not also be closed with `os.close` (that would be a racy double close).
    devnull_file = os.fdopen(devnull_fno, "w")
    setattr(sys, stdname, devnull_file)

    def cleanup_local(c_stdobj):
        try:
            # Drain Python-level and C-level buffers into the null device first.
            getattr(sys, stdname).flush()
            libc.fflush(c_stdobj)
            try_fsync(devnull_fno)
        finally:
            # Always put the real descriptor back, even if flushing failed.
            os.dup2(dup_std_fno, original_std_fno)
            os.close(dup_std_fno)
            try:
                devnull_file.close()
            except OSError as e:
                # Tolerate the fd having been closed elsewhere ([Errno 9] Bad file descriptor).
                if e.errno != errno.EBADF:
                    raise
            finally:
                setattr(sys, stdname, original)

    return cleanup_local
@contextmanager
def suppress_websocket():
    """Attempts to filter out irritating `websocket` library messages.

    While the block runs:

    * ``ResourceWarning`` is ignored, because the websocket-client library appears
      to leak resources on connection retry.
    * Records whose message contains ``"goodbye"`` are dropped from the
      ``"websocket"`` logger.

    Both the warning filters and the logging filter are restored when the block
    exits, including when it raises.
    """

    def websocket_filter(record):
        # `record.msg` is not guaranteed to be a string; coerce before searching.
        return "goodbye" not in str(record.msg)

    with warnings.catch_warnings():
        # XXX: websocket-client library seems to have leaks on connection
        # retry that cause annoying warnings within Python 3.8+
        warnings.filterwarnings("ignore", category=ResourceWarning)
        # Filter out the websocket "goodbye" messages.
        _logger = logging.getLogger("websocket")
        _logger.addFilter(websocket_filter)
        try:
            yield
        finally:
            _logger.removeFilter(websocket_filter)
def diff_unpackable(obj, other_obj):
    """Compare two objects that are able to be unpacked and describe the first difference.

    Both objects are first passed through ``unpack``, which is intended to handle nested
    collections: dictionaries, named-tuples, tuples, lists, numpy arrays, and dataclasses.
    The unpacked results are then walked together:

    * dictionaries are compared by their key-sorted keys and values;
    * non-string sequences are compared element by element, after sorting by first
      element when that is possible;
    * sequences of different length are reported as a mismatch;
    * anything else is compared with ``!=``.

    Returns:
        ``""`` if no difference is found, otherwise a message of the form
        ``"<a>!=<b> in <context>"`` describing the first mismatch encountered.

    Raises:
        AssertionError: if a dictionary is compared against a value that does not
            sort into a dictionary.
    """
    obj_unpacked = unpack(obj)
    other_obj_unpacked = unpack(other_obj)

    def sort(value):
        # Dictionaries (including defaultdict) are normalised to key order.
        if isinstance(value, (dict, defaultdict)):
            return dict(sorted(value.items(), key=lambda item: item[0]))
        try:
            # Sequences of tuples/records are ordered by their first element.
            return sorted(value, key=lambda item: item[0])
        except IndexError:
            pass  # Some element is empty; fall back to natural ordering below.
        except (KeyError, TypeError):
            return value  # Not sortable by first element; keep original order.
        try:
            return sorted(value)
        except TypeError:
            return value

    def process(a, b, current_comparison):
        """Compare one pair, queueing nested pairs; return a mismatch message or ``""``."""
        if isinstance(a, (dict, defaultdict)):
            t_o = sort(a)
            assert isinstance(t_o, dict)
            t_oo = sort(b)
            assert isinstance(t_oo, dict)
            comps.append((t_o.keys(), t_oo.keys()))
            comps.append((t_o.values(), t_oo.values()))
        elif isinstance(a, Sequence) and not isinstance(a, str):
            comps.append((sort(a), sort(b)))
        elif a != b:
            return f"{a}!={b} in {current_comparison}"
        return ""

    comps = []
    result = process(obj_unpacked, other_obj_unpacked, None)
    while not result and comps:
        o_oo = comps.pop()
        left, right = o_oo
        try:
            same_length = len(left) == len(right)
        except TypeError:
            same_length = False  # e.g. a sequence compared against a scalar/None
        if not same_length:
            return f"{left}!={right} (length mismatch)"
        for o, oo in zip(left, right):
            result = process(o, oo, o_oo)
            if result:
                break
    # `result` now holds the first mismatch, or "" if everything matched.
    return result