Source code for barrage.model

from typing import List

import tensorflow as tf

from barrage import config
from barrage.utils import import_utils


[docs]def build_network(cfg_model: dict, transform_params: dict) -> tf.keras.Model: """Build the network. Args: cfg_model: dict, model subsection of config. transform_params: dict, params from transformer. Returns: tf.keras.Model, network. Raises: TypeError, network not a tf.keras.Model. """ path = cfg_model["network"]["import"] network_params = cfg_model["network"].get("params", {}) net_func = import_utils.import_obj_with_search_modules(path) net = net_func(**network_params, **transform_params) if not isinstance(net, tf.keras.Model): raise TypeError(f"import network: {net} is not a tf.keras.Model") return net
[docs]def build_objective(cfg_model: dict) -> dict: """Build objective (loss, loss_weights, metrics, and sample_weight_mode) for each model output. Args: cfg_model: dict, model subsection of config. Returns: dict, objective """ loss = {} loss_weights = {} sample_weight_mode = {} metrics = {} for output in cfg_model["outputs"]: name = output["name"] loss[name] = import_utils.import_loss(output["loss"]) loss_weights[name] = output.get("loss_weight", 1.0) sample_weight_mode[name] = output.get("sample_weight_mode") metrics[name] = [ import_utils.import_metric(m) for m in output.get("metrics", []) ] return { "loss": loss, "loss_weights": loss_weights, "sample_weight_mode": sample_weight_mode, "metrics": metrics, }
[docs]def check_output_names(cfg_model: dict, net: tf.keras.Model): """Check the net outputs in the config match the actual net. Args: cfg_model: dict, model subsection of config. net: tf.keras.Model, net. Raises: ValueError, mismatch between config and net. """ config_net_outputs = {o["name"] for o in cfg_model["outputs"]} actual_net_outputs = set(net.output_names) if config_net_outputs != actual_net_outputs: raise ValueError( f"'config.model.outputs.names': {config_net_outputs} " f"mismatch actual model outputs: {actual_net_outputs} - " "order and names must exactly match" )
[docs]def sequential_from_config(layers: List[dict], **kwargs) -> tf.keras.Model: """Build a sequential model from a list of layer specifications. Supports references to network_params computed inside Transformers by specifying {{variable name}}. Args: layers: list[dict], layer imports. Returns: tf.keras.Model, network. """ layers = config._render_params(layers, kwargs) network = tf.keras.models.Sequential() for layer in layers: if "import" not in layer: raise KeyError(f"layer {layer} missing 'import' key") if not layer.keys() <= {"import", "params"}: unexpected_keys = set(layer.keys()).difference({"import", "params"}) raise KeyError(f"layer {layer} unexpected key(s): {unexpected_keys}") layer_cls = import_utils.import_obj_with_search_modules( layer["import"], search_modules=["tensorflow.keras.layers"] ) layer_params = layer.get("params", {}) network.add(layer_cls(**layer_params)) return network