Source code for mltk.utils.python

"""Common Python utilities

See the source code on Github: `mltk/utils/python.py <https://github.com/siliconlabs/mltk/blob/master/mltk/utils/python.py>`_
"""

import collections
import sys
import os
import re
import json
import logging
import importlib
import inspect
import time
import copy
from enum import Enum
import warnings
from typing import Iterable, Any, Union

from mltk import MLTK_ROOT_DIR
from .logger import DummyLogger, make_filelike
from .shell_cmd import run_shell_cmd
from .path import fullpath


SHORT_VERSION = '.'.join([str(x) for x in sys.version_info[:2]])
"""The Python version as <major>.<minor>
e.g.: 3.9
"""

def _defaultdict_not_found():
    return None


[docs]def DefaultDict(d: dict = None, **kwargs) -> collections.defaultdict: """Creates a directory that returns None if a key does not exist NOTE: Nested dictionaries are also updated to a defaultdict """ def _convert_to_default_dict(obj): if isinstance(obj, dict): obj = DefaultDict(obj) elif isinstance(obj, list): for i, x in enumerate(obj): obj[i] = _convert_to_default_dict(x) return obj if d is not None: kwargs.update(d) for key, value in kwargs.items(): kwargs[key] = _convert_to_default_dict(value) return collections.defaultdict(_defaultdict_not_found, kwargs)
[docs]class DictObject(dict): """Standard Python dictionary that allows for accessing entries as object properties, e.g.: my_dict_obj = DictObject({'foo': 1, 'bar': False}) # Both lines do the same thing foo = my_dict_obj.foo foo = my_dict_obj['foo'] my_dict_obj.bar = True my_dict_obj['bar'] = True """
[docs] def __getattr__(self, name): try: return self[name] except KeyError as e: raise AttributeError(e)
[docs] def __setattr__(self, name, value): self[name] = value
[docs]def merge_dict(destination: dict, source: dict, copy_destination=False) -> dict: """Recursively merge the source dictionary into the destination and return the destination """ if copy_destination: destination = copy.deepcopy(destination) for key, value in source.items(): if isinstance(value, dict): # get node or create one node = destination.setdefault(key, {}) merge_dict(node, value) else: destination[key] = value return destination
[docs]def as_list(obj: Any, split: str=None) -> list: """Convert the given object to a list - If obj is None, then return empty list - If obj is a string, If the `split` argument is given then return obj.split(split) else just wrap the string in a list """ if obj is None: return [] elif isinstance(obj, list): return obj elif isinstance(obj, str): return [obj] if not split else [x.strip() for x in obj.split(split)] elif isinstance(obj, collections.abc.Iterable): return [x for x in obj] else: return [obj]
[docs]def flatten_list(l : Iterable) -> list: """Flatten the given iterable object to a list""" retval = [] for x in l: try: iter(x) except TypeError: is_iterable = False else: is_iterable = True if is_iterable: retval.extend(flatten_list(x)) else: retval.append(x) return retval
[docs]def list_rindex(lst: Iterable, value: Any) -> int: """Reverse find element index that is given value""" for i, v in enumerate(reversed(lst)): if v == value: return len(lst) - i - 1 # return the index in the original list return -1
[docs]def contains_class_type(l: Iterable, cls: Any) -> bool: """Return if the given list contains a class with the given type""" for e in l: if isinstance(e, cls): return True return False
[docs]def get_case_insensitive(value: str, l: Iterable) -> str: """Get the given string with case-insensitive comparsion""" if value is None: return None value = value.lower() for v in l: if v.lower() == value: return v return None
[docs]def is_true(arg) -> bool: """Return if the given argument is a True value""" if isinstance(arg, str): return arg.lower() in ('yes', 'true', 'on', '1') if isinstance(arg, bool): return arg if isinstance(arg, int): return arg != 0 raise ValueError(f'Invalid boolean arg: {arg}')
[docs]def is_false(arg) -> bool: """Return if the given argument is a False value""" if isinstance(arg, str): return arg.lower() in ('no', 'false', 'off', '0') if isinstance(arg, bool): return arg if isinstance(arg, int): return arg == 0 raise ValueError(f'Invalid boolean arg: {arg}')
[docs]def forward_method_kwargs(**kwargs) -> dict: """Return all the keyword-arguments of a method, excluding the 'self' argument""" retval = {} for key, value in kwargs.items(): if key == 'self' or key.startswith('_'): continue elif key == 'kwargs': retval.update(value) else: retval[key] = value return retval
[docs]def prepend_exception_msg(e:Exception, msg:str) -> Exception: """Prepend a message to the given exception""" e.args = (msg, *e.args) all_str = True for x in e.args: try: str(x) except: all_str = False break # If every entry in the exception msg is a string # then make it look pretty by combining into a coma-separated string if all_str: s = ', '.join(str(x) for x in e.args) e.args = (s, ) return e
[docs]def append_exception_msg(e:Exception, msg:str) -> Exception: """Append a message to the given exception""" e.args = (*e.args, msg) all_str = True for x in e.args: try: str(x) except: all_str = False break # If every entry in the exception msg is a string # then make it look pretty by combining into a coma-separated string if all_str: s = ', '.join(str(x) for x in e.args) e.args = (s, ) return e
[docs]def debugger_is_active() -> bool: """Return if the debugger is currently active""" gettrace = getattr(sys, 'gettrace', lambda : None) return gettrace() is not None
[docs]def notebook_is_active() -> bool: """Return if script is executing in a IPython notebook (e.g. Jupyter notebook)""" for x in sys.modules: if x.lower() == 'ipykernel': return True return False
[docs]def install_pip_package( package:str, module_name:str=None, logger: logging.Logger=None, install_dir:str=None, upgrade=False, no_deps=False ): """Install the given pip package is necessary""" logger = logger or DummyLogger() if install_dir: install_dir = fullpath(install_dir) if install_dir not in sys.path: os.makedirs(install_dir, exist_ok=True) logger.info(f'Adding {install_dir} to sys.path') sys.path.append(install_dir) version_match = re.match(r'([\w\_\-]+)([=<>]).*', package) if not module_name: if version_match: module_name = package[:version_match.start(2)] else: module_name = package # Only try to import the module without running pip if no version is specified and upgrade=False if not upgrade and not version_match: try: importlib.import_module(module_name) return except: pass make_filelike(logger) cmd = [sys.executable, "-m", "pip", "install"] if upgrade: cmd.append('-U') if no_deps: cmd.append('--no-deps') cmd.append(package) if install_dir: cmd.extend(['--target', install_dir]) logger.warning(f'Running cmd: {" ".join(cmd)}\n(This may take awhile, please be patient ...)') retcode, retval = run_shell_cmd(cmd, outfile=logger) if retcode != 0: raise RuntimeError(f'Failed to install pip package: {package}, err:\n{retval}')
[docs]def import_module_at_path(path:str, reload=False): """Import the Python module at the given path and return the imported module """ module_package = None path = fullpath(path) mltk_root_path = fullpath(MLTK_ROOT_DIR) # If the path is within the mltk directory # Then generate the module path relative to the mltk package if path.startswith(f'{mltk_root_path}/'): mltk_rel_path = os.path.relpath(path, mltk_root_path).replace('\\', '/') module_package = None module_name = mltk_rel_path.replace('.py', '').replace('/', '.') # Else if the path to an external .py script was provided elif path.endswith('.py'): path_dir = os.path.dirname(path).replace('\\', '/') module_name = os.path.basename(path).replace('.py', '') if os.path.exists(f'{path_dir}/__init__.py'): # Do a relative import if the module is in a parent package module_name = '.' + module_name parent_dir = os.path.dirname(path_dir) module_package = os.path.basename(path_dir) if parent_dir not in sys.path: sys.path.insert(0, parent_dir) else: # Otherwise, do an absolute import if path_dir not in sys.path: sys.path.insert(0, path_dir) # Else the path to external Python directory was provided else: if not os.path.exists(f'{path}/__init__.py'): raise RuntimeError(f'Given path to directory: {path} does not contain a __init__.py file') parent_dir = os.path.dirname(path).replace('\\', '/') module_name = os.path.basename(path) if os.path.exists(f'{parent_dir}/__init__.py'): # Do a relative import if the module is in a parent package module_name = '.' + module_name module_package = os.path.basename(parent_dir) parent_dir = os.path.dirname(parent_dir).replace('\\', '/') if parent_dir not in sys.path: sys.path.insert(0, parent_dir) else: # Otherwise, do an absolute import if parent_dir not in sys.path: sys.path.insert(0, parent_dir) # If the module has already been imported # then reload it if necessary if module_name in sys.modules: if reload: return importlib.reload(sys.modules[module_name]) else: return sys.modules[module_name] # Otherwise import the module return importlib.import_module(module_name, package=module_package)
[docs]def load_json_safe(path:str, *args, **kwargs) -> object: """Load a JSON file and ignoring any single-line, multi-line comments and trailing commas Args: path: Path to JSON file args, kwargs: Arguments to pass into json.loads Return: Loaded JSON object """ RE_SINGLE_LINE_COMMENT = re.compile(r'("(?:(?=(\\?))\2.)*?")|(?:\/{2,}.*)') RE_MULTI_LINE_COMMENT = re.compile(r'("(?:(?=(\\?))\2.)*?")|(?:\/\*(?:(?!\*\/).)+\*\/)', flags=re.M|re.DOTALL) RE_TRAILING_COMMA = re.compile(r',(?=\s*?[\}\]])') with open(path, 'r') as fp: unfiltered_json_string = fp.read() filtered_json_string = RE_SINGLE_LINE_COMMENT.sub(r'\1', unfiltered_json_string) filtered_json_string = RE_MULTI_LINE_COMMENT.sub(r'\1', filtered_json_string) filtered_json_string = RE_TRAILING_COMMA.sub('', filtered_json_string) return json.loads(filtered_json_string, *args, **kwargs)
[docs]def find_object_key_with_value( obj:object, needle:object, throw_exception=False ) -> str: """Given an class or class instance, search the attribute values of the object for the given "needle" and return its corresponding key. Note: If a class if given then it must be instantiable using a default constructor. Args: obj: Class or class instance needle: Class attribute value to find in class instance throw_exception: If true, throw an exception if the needle is not found, return 'none' otherwise Return: Lowercase key of found attribute value or "none" if value is not found """ if inspect.isclass(obj) and not issubclass(obj, Enum): obj = obj() for key in dir(obj): if getattr(obj, key) == needle: return key.lower() if throw_exception: raise ValueError(f'{needle} not found in {obj}') return 'none'
[docs]def find_object_value_with_key( obj:object, needle:str, ignore_case=False, throw_exception=False ): """Given a class or class instance, search the attribute keys of the object for the given "needle" and return its corresponding value. NOTE: If a class if given then it must be instantiable using a default constructor (except of Enum classes). Args: obj: Class or class instance needle: Class attribute key to find in class instance ignore_case: Ignore the key's case if True throw_exception: If true, throw an exception if the needle is not found, return None otherwise Return: Value of found attribute key or None if key is not found """ if needle is None: return None if inspect.isclass(obj) and not issubclass(obj, Enum): obj = obj() if ignore_case: needle = needle.lower() for key in dir(obj): if ignore_case: if key.lower() == needle: return getattr(obj, key) else: if key == needle: return getattr(obj, key) if throw_exception: raise ValueError(f'{needle} not found in {obj}') return None
[docs]def find_object_value_with_key_or_value( obj:object, needle:Union[str,object], ignore_case=False, throw_exception=False ): """Given a class or class instance, search the attribute keys and values of the object for the given "needle" and return its corresponding value. NOTE: If a class if given then it must be instantiable using a default constructor (except of Enum classes). Args: obj: Class or class instance needle: Class attribute key or value to find in class instance ignore_case: Ignore the key's case if True (needle must be a string) throw_exception: If true, throw an exception if the needle is not found, return None otherwise Return: Value of found attribute key or None if key/value is not found """ if needle is None: return None if inspect.isclass(obj) and not issubclass(obj, Enum): obj = obj() needle_lower = None if ignore_case and isinstance(needle, str): needle_lower = needle.lower() for key in dir(obj): value = getattr(obj, key) if (needle_lower is None and key == needle) or \ (needle_lower is not None and key.lower() == needle_lower) or \ (value == needle): return value if throw_exception: raise ValueError(f'{needle} not found in {obj}') return None
[docs]def timeit(method): """Decorator to measure time it takes for method or function to execute""" def timed(*args, **kw): ts = time.time() try: return method(*args, **kw) finally: te = time.time() diff = (te - ts) * 1000 print(f'{method.__name__} {diff:4f}ms') return timed
[docs]def set_absl_log_level(level): """Set the absl.logging library log level and return previous level The absl.logging library, a google Python package, can be very verbose. This sets the log level and returns the previous log level Args: level: The new log level Returns: Previous log level """ prev_level = None try: import absl.logging prev_level = absl.logging.get_verbosity() absl.logging.set_verbosity(getattr(absl.logging, level)) except: pass return prev_level
[docs]def disable_warnings(func, level='ERROR'): """Function decorator hide warnings from the Python warnings library and absl.logging library while the decorated function executes""" def _wrapper(*args, **kwargs): prev_level = set_absl_log_level(level) try: with warnings.catch_warnings(): warnings.simplefilter('ignore') return func(*args, **kwargs) finally: set_absl_log_level(prev_level) return _wrapper