Source code for smarts.core.utils.file

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import dataclasses
import hashlib
import os
import pickle
import shutil
import struct
from collections import defaultdict
from contextlib import contextmanager
from ctypes import c_int64
from typing import Any, Generator, Sequence

import numpy as np

import smarts


def file_in_folder(filename: str, path: str) -> bool:
    """Report whether an entry named ``filename`` exists inside ``path``.

    Args:
        filename: The name of the file to look for.
        path: The directory that is expected to contain the file.

    Returns:
        True if the joined path exists on disk, otherwise False.
    """
    candidate = os.path.join(path, filename)
    return os.path.exists(candidate)
# https://stackoverflow.com/a/2166841
def isnamedtupleinstance(x):
    """Tell whether ``x`` looks like an instance of a named tuple class.

    The check is structural: the type of ``x`` must derive directly (and
    only) from ``tuple`` and expose a ``_fields`` tuple made up of plain
    ``str`` entries.
    """
    cls = type(x)
    bases = cls.__bases__
    # Exactly one base class is allowed, and it has to be ``tuple``.
    if len(bases) != 1:
        return False
    if bases[0] != tuple:
        return False

    fields = getattr(cls, "_fields", None)
    if not isinstance(fields, tuple):
        return False

    # Every field name must be exactly ``str`` (subclasses do not count).
    for name in fields:
        if type(name) != str:
            return False
    return True
def replace(obj: Any, **kwargs):
    """Return a copy of ``obj`` with the given fields swapped out.

    Offers one call for both named tuples and dataclasses.

    Args:
        obj: A named tuple instance or a dataclass.
        **kwargs: Field names mapped to their replacement values.

    Returns:
        The result of ``obj._replace`` for named tuples, or of
        ``dataclasses.replace`` for dataclasses.

    Raises:
        ValueError: If ``obj`` is neither a named tuple nor a dataclass.
    """
    # Named tuples are tested first, matching the original precedence.
    if isnamedtupleinstance(obj):
        return obj._replace(**kwargs)

    if dataclasses.is_dataclass(obj):
        return dataclasses.replace(obj, **kwargs)

    raise ValueError("Must be a namedtuple or dataclass.")
def isdataclass(x):
    """Tell whether ``x`` is a `dataclass`.

    This is a thin alias for ``dataclasses.is_dataclass`` and returns
    whatever that function returns for ``x``.
    """
    verdict = dataclasses.is_dataclass(x)
    return verdict
# TAI MTA: This is probably the wrong place for this utility: `logging.py`?
def unpack(obj):
    """Recursively turn nested data objects into plain containers.

    Dictionaries keep their keys, lists and numpy arrays become lists, named
    tuples and dataclasses become dictionaries keyed by field name, and plain
    tuples stay tuples. Anything else is returned untouched.

    Mainly useful for printing, for example::

        pprint(unpack(obs), indent=1, width=80, compact=True)
    """
    if isinstance(obj, dict):
        return {name: unpack(item) for name, item in obj.items()}

    if isinstance(obj, (list, np.ndarray)):
        return [unpack(item) for item in obj]

    # Named tuples must be handled before the generic tuple branch below,
    # otherwise their field names would be lost.
    if isnamedtupleinstance(obj):
        members = obj._asdict()
        return {name: unpack(item) for name, item in members.items()}

    if isdataclass(obj):
        members = dataclasses.asdict(obj)
        return {name: unpack(item) for name, item in members.items()}

    if isinstance(obj, tuple):
        return tuple(unpack(item) for item in obj)

    return obj
def copy_tree(from_path, to_path, overwrite=False):
    """Copy a directory tree, files included, to a new location.

    Args:
        from_path: The directory to copy.
        to_path: The destination directory.
        overwrite: When true, an existing destination is removed before
            copying; when false, an existing destination is an error.

    Raises:
        FileExistsError: If ``to_path`` already exists and ``overwrite`` is
            false.
    """
    destination_taken = os.path.exists(to_path)
    if destination_taken and not overwrite:
        raise FileExistsError(
            "The destination path={} already exists.".format(to_path)
        )
    if destination_taken:
        # Clear the old tree so ``copytree`` can create the directory afresh.
        shutil.rmtree(to_path)

    shutil.copytree(from_path, to_path)
def path2hash(file_path: str):
    """Return the MD5 hex digest of the path string itself.

    Only the text of ``file_path`` (encoded as UTF-8) is hashed; the file it
    may point to is never opened.
    """
    encoded_path = bytes(file_path, "utf-8")
    return hashlib.md5(encoded_path).hexdigest()
def file_md5_hash(file_path: str) -> str:
    """Converts file contents to a hash value. Useful for doing a file diff.

    The file is read in binary mode, so the digest covers the exact bytes on
    disk. It therefore does not depend on the platform's default text
    encoding or newline translation, and files that are not valid text can be
    hashed as well.

    Args:
        file_path: The path of the file to hash.

    Returns:
        The MD5 hex digest of the file's bytes.
    """
    hasher = hashlib.md5()
    with open(file_path, "rb") as f:
        # Feed the hash in fixed-size chunks so large files are never held
        # in memory all at once.
        for chunk in iter(lambda: f.read(65536), b""):
            hasher.update(chunk)

    return hasher.hexdigest()
def pickle_hash(obj, include_version=False) -> str:
    """Hash a Python object through its pickled representation.

    NOTE: NOT stable across different Python versions.

    Args:
        obj: The object to hash; it is serialized with pickle protocol 4.
        include_version: When true, ``smarts.VERSION`` is mixed into the
            hash as well.

    Returns:
        The MD5 hex digest of the pickled bytes (plus the version if asked).
    """
    digest = hashlib.md5(pickle.dumps(obj, protocol=4))
    if include_version:
        digest.update(smarts.VERSION.encode())
    return digest.hexdigest()
def pickle_hash_int(obj) -> int:
    """Hash a Python object to an integer.

    The hex digest from ``pickle_hash`` is parsed as a base-16 number and
    then passed through ``ctypes.c_int64`` to obtain the returned value.

    NOTE: NOT stable across different Python versions.
    """
    wide_value = int(pickle_hash(obj), 16)
    return c_int64(wide_value).value
def smarts_local_user_dir() -> str:
    """Retrieves the smarts logging directory, creating it when missing.

    Returns:
        str: The path of the ``.smarts`` directory under the user's home.
    """
    home = os.path.expanduser("~")
    ## Following should work for linux and macos
    local_dir = os.path.join(home, ".smarts")
    os.makedirs(local_dir, exist_ok=True)
    return local_dir
def smarts_global_user_dir() -> str:
    """Retrieves the smarts global user directory.

    The path is only computed; nothing is created on disk.

    Returns:
        str: The smarts global user directory path.
    """
    return os.path.join("/etc", "smarts")
def make_dir_in_smarts_log_dir(dir):
    """Build the path of ``dir`` inside the smarts logging directory.

    Only the path is composed here; ``dir`` itself is not created by this
    function.
    """
    log_root = smarts_local_user_dir()
    return os.path.join(log_root, dir)
@contextmanager
def suppress_pkg_resources():
    """A context manager that injects an `ImportError` into the `pkg_resources`
    module to force package fallbacks in imports that can use alternatives to
    `pkg_resources`.

    While the context is active, the `pkg_resources` entry in `sys.modules`
    is replaced by a `property(raise_import_error)` placeholder. The original
    module is put back on exit, including when the body raises.
    """
    import sys

    # Imported for its side effect: guarantees the module is registered in
    # `sys.modules` so that there is an original entry to save and restore.
    import pkg_resources

    from smarts.core.utils.invalid import raise_import_error

    pkg_res = sys.modules["pkg_resources"]
    sys.modules["pkg_resources"] = property(raise_import_error)
    try:
        yield
    finally:
        # Always restore the real module; without this an exception inside
        # the `with` body would leave the placeholder installed for good.
        sys.modules["pkg_resources"] = pkg_res
def read_tfrecord_file(path: str) -> Generator[bytes, None, None]:
    """Iterate over the records in a TFRecord file and return the bytes of each record.

    Each record is read as an 8-byte length, a 4-byte checksum of the length,
    the payload, and a 4-byte checksum of the payload. The checksums are
    skipped without being verified.

    Args:
        path: The path to the TFRecord file.

    Yields:
        The raw payload bytes of each record, in file order.
    """
    with open(path, "rb") as f:
        while True:
            length_bytes = f.read(8)
            if len(length_bytes) != 8:
                # End of file (or a truncated header): stop iterating.
                return
            # The length header is a little-endian uint64; the explicit "<"
            # keeps decoding correct regardless of the host's byte order.
            (record_len,) = struct.unpack("<Q", length_bytes)
            f.read(4)  # masked_crc32_of_length (ignore)
            record_data = f.read(record_len)
            f.read(4)  # masked_crc32_of_data (ignore)
            yield record_data