Source code for barrage.utils.import_utils

from functools import partial
import importlib
import inspect
from typing import Callable, List, Union

from tensorflow.python.keras import losses, metrics

from barrage import logger


[docs]def import_obj_with_search_modules( python_path: str, search_modules: List[str] = None, search_both_cases=False ) -> Callable: """Import a object (variable, class, function, etc...) from a python path. search_modules are used to potentially shorten the python path by eliminating the module part of the string. Optionally check both uncapitalized and capitalized version of the object name. Args: python_path: str, full python path or object name (in conjunction with search_modules). search_modules: list[str], python modules against short path. search_both_cases: bool, try capitalized and uncapitalized object name. Returns: callable, imported object. Raises: ImportError, object not found. """ python_path = _maybe_fix_tensorflow_python_path(python_path) try: module_path, obj_name = python_path.rsplit(".", 1) except ValueError: module_path = "" obj_name = python_path if module_path: module = importlib.import_module(module_path) if not hasattr(module, obj_name): raise ImportError( f"object: {obj_name} was not found in module {module_path}" ) return getattr(module, obj_name) else: if search_modules is None: raise ImportError(f"object: {obj_name} not found, no module provided.") obj_names = _make_both_cases(obj_name) if search_both_cases else [obj_name] for search_module_path in search_modules: module = importlib.import_module(search_module_path) for name in obj_names: if hasattr(module, name): return getattr(module, name) raise ImportError( f"object: {obj_name} was not found in the searched " f"modules {search_modules}" )
[docs]def import_or_alias(python_path: str) -> Union[Callable, str]: """Import an object from a full python path or assume it is an alias for some object in TensorFlow (e.g. "categorical_crossentropy"). Args: python_path: str, full python path or alias. Returns: callable/str, import object/alias. Raises: ImportError, object not found. """ python_path = _maybe_fix_tensorflow_python_path(python_path) try: module_path, obj_name = python_path.rsplit(".", 1) except ValueError: return python_path module = importlib.import_module(module_path) if not hasattr(module, obj_name): raise ImportError(f"object: {obj_name} was not found in module {module_path}") return getattr(module, obj_name)
[docs]def import_partial_wrap_func(import_block: dict) -> Callable: """Import a function from an import block and wrap with partial. Args: import_block: dict, {"import": str, "params": dict}. Returns: function. """ func = import_obj_with_search_modules(import_block["import"]) params = import_block.get("params", {}) return partial(func, **params)
[docs]def import_loss(import_block: dict) -> Union[losses.Loss, str]: """Import a loss from an import block. Args: import_block: dict, {"import": str, "params": dict}. Returns: Loss/str, non-alias loss or alias. Raises: TypeError, non-alias loss is not a Loss. """ loss = import_or_alias(import_block["import"]) if isinstance(loss, str): return loss else: params = import_block.get("params", {}) if not inspect.isclass(loss): raise TypeError( f"import loss: {loss} must be a class for v2 tensorflow.keras API" ) loss = loss(**params) if not isinstance(loss, losses.Loss): raise TypeError( f"import loss: {loss} is not a tensorflow.python.keras.losses.Loss" ) return loss
[docs]def import_metric(import_block: dict) -> Union[metrics.Metric, losses.Loss, str]: """Import a metric from an import block. Note: A loss can be a metric, but a metric cannot always be a loss. Args: import_block: dict, {"import": str, "params": dict}. Returns: Metric/Loss/str, non-alias metric or alias. Raises: TypeError, non-alias metric is not a Metric/loss. """ metric = import_or_alias(import_block["import"]) if isinstance(metric, str): return metric else: params = import_block.get("params", {}) if not inspect.isclass(metric): raise TypeError( f"import metric: {metric} must be a class " "for v2 tensorflow.keras API" ) metric = metric(**params) if not isinstance(metric, (metrics.Metric, losses.Loss)): raise TypeError( f"import metric: {metric} is not a " "tensorflow.python.keras.metrics.Metric or " "tensorflow.python.keras.losses.Loss" ) return metric
def _maybe_fix_tensorflow_python_path(python_path: str) -> str: """Fix a common mistake python path mistake tf.keras vs tensorflow.keras. Args: python_path: str, python path. Returns: str, python path. """ if python_path.startswith("tf.keras"): python_path = python_path.replace("tf.keras", "tensorflow.keras") logger.warning( f"fixing import {python_path} - replacing tf.keras " "with tensorflow.python.keras" ) return python_path def _make_both_cases(s: str) -> List[str]: """Capitalize and uncapitalize first letter of string. Original case first. Args: s, str. Return: list[str], capitalized and uncapitalized version of string. """ if s: upper = s[:1].upper() + s[1:] lower = s[:1].lower() + s[1:] if upper == s: return [upper, lower] else: return [lower, upper] else: return [""]