# Source code for barrage.dataset.core

import copy
import os
from typing import List, Union

import cytoolz
import numpy as np
import pandas as pd
import tensorflow as tf

from barrage import api, logger
from barrage.utils import import_utils

# Module paths handed to ``import_utils.import_obj_with_search_modules`` as
# ``search_modules`` when the configured loader and transformer classes are
# imported by name.
SEARCH_MODULES = ["barrage.dataset"]


class RecordDataset(tf.keras.utils.Sequence):
    """A sequence designed to wrap a DataFrame-like object: apply load
    operations, fit & apply transforms, apply data augmentation, and support sampling
    of records.

    At train time: fit transform -> batch -> load -> transform -> augment.

    At validation time: batch -> load -> transform.

    At score time: batch -> load -> transform.

    Args:
        artifact_dir: str, path to artifact directory.
        cfg_dataset: dict, dataset subsection of config.
        records: api.InputRecords, records.
        mode: RecordMode, transform mode.
        batch_size: int, batch size.

    Raises:
        TypeError: if ``mode`` is not a RecordMode, if ``records`` is neither a
            pandas DataFrame nor an iterable of dicts, or if the configured
            loader / transformer is not a RecordLoader / RecordTransformer.
    """

    def __init__(
        self,
        artifact_dir: str,
        cfg_dataset: dict,
        records: api.InputRecords,
        mode: api.RecordMode,
        batch_size: int,
    ):
        # Let the Keras base class set up its own state; newer Keras versions
        # warn when a Sequence subclass skips this call.
        super().__init__()

        if not isinstance(mode, api.RecordMode):
            raise TypeError("mode must be type RecordMode")

        # Standardize InputRecords to Records
        if isinstance(records, pd.DataFrame):
            # Reset the index on a copy: resetting in place would silently
            # modify the DataFrame owned by the caller.
            self.records = records.reset_index(drop=True).to_dict(orient="records")
        elif all(isinstance(record, dict) for record in records):
            self.records = records
        else:
            raise TypeError("record must be a list of dicts or pandas DataFrame")

        self.num_records = len(self.records)
        logger.info(f"Building {mode} dataset with {self.num_records} records")
        self.mode = mode
        self.batch_size = batch_size

        # NOTE: this seeds numpy's *global* random state, which is also what
        # ``shuffle`` draws from.
        self.seed = cfg_dataset.get("seed")
        np.random.seed(self.seed)

        # ``sample_count`` is the name of a record key holding how many times
        # that record should appear per epoch (train mode only).
        sample_count = cfg_dataset.get("sample_count")
        if self.mode == api.RecordMode.TRAIN and sample_count is not None:
            self._sample_inds = self.convert_sample_count_to_inds(
                [record[sample_count] for record in self.records]
            )
        else:
            self._sample_inds = list(range(self.num_records))
        self.shuffle()

        loader_cls = import_utils.import_obj_with_search_modules(
            cfg_dataset["loader"]["import"], search_modules=SEARCH_MODULES
        )
        self.loader = loader_cls(mode=mode, **cfg_dataset["loader"].get("params", {}))
        if not isinstance(self.loader, api.RecordLoader):
            raise TypeError(f"loader {self.loader} is not of type RecordLoader")

        transformer_cls = import_utils.import_obj_with_search_modules(
            cfg_dataset["transformer"]["import"], search_modules=SEARCH_MODULES
        )
        self.transformer = transformer_cls(
            mode=self.mode,
            loader=self.loader,
            **cfg_dataset["transformer"].get("params", {}),
        )
        if not isinstance(self.transformer, api.RecordTransformer):
            raise TypeError(
                f"transformer {self.transformer} is not of type RecordTransformer"
            )

        # Train mode fits the transformer and saves it under the artifact
        # directory; every other mode loads it back from the same location.
        dataset_dir = os.path.join(artifact_dir, "dataset")
        if self.mode == api.RecordMode.TRAIN:
            self.augmentor = RecordAugmentor(cfg_dataset["augmentor"])
            logger.info(f"Fitting transform: {self.transformer.__class__.__name__}")
            # Fit on a deep copy so the transformer cannot alter self.records.
            self.transformer.fit(copy.deepcopy(self.records))
            logger.info(
                f"Transformer network params: {self.transformer.network_params}"
            )
            self.transformer.save(dataset_dir)
        else:
            self.transformer.load(dataset_dir)

    def __len__(self):
        """Number of batches in a sequence."""
        # Read the private list directly; the public property would copy it.
        return int(np.ceil(len(self._sample_inds) / float(self.batch_size)))

    def __getitem__(self, ind) -> api.BatchDataRecords:
        """Get a batch by index.

        Args:
            ind: int, batch index.

        Returns:
            BatchDataRecords, batch data records.
        """
        batch_inds = self.sample_order[
            ind * self.batch_size : (ind + 1) * self.batch_size
        ]
        # Deep copy so load/transform/augment never modify the stored records.
        batch_records = copy.deepcopy([self.records[bi] for bi in batch_inds])

        if self.mode == api.RecordMode.TRAIN:
            # Augmentation is applied in train mode only.
            lst_data_records = [
                self.augmentor(self.transformer.transform(self.loader(record)))
                for record in batch_records
            ]
        else:
            lst_data_records = [
                self.transformer.transform(self.loader(record))
                for record in batch_records
            ]

        return batchify_data_records(lst_data_records)

    def on_epoch_end(self):
        """Shuffle sample_order on epoch end."""
        self.shuffle()

    @property
    def sample_inds(self):
        """Index of records used for batches (returned as a copy). Contains
        repeats if self.mode = RecordMode.TRAIN and the dataset config sets
        ``sample_count``.
        """
        return self._sample_inds.copy()

    @sample_inds.setter
    def sample_inds(self, x):
        raise ValueError("illegal set on 'sample_inds'")

    def shuffle(self):
        """Compute sample_order from sample_inds.

        The order is randomly shuffled (numpy global random state) in TRAIN and
        VALIDATION modes; in any other mode it is left in record order.
        """
        # The property returns a fresh copy, so the in-place shuffle below
        # never touches self._sample_inds.
        self.sample_order = self.sample_inds
        if self.mode in (api.RecordMode.TRAIN, api.RecordMode.VALIDATION):
            np.random.shuffle(self.sample_order)

    @staticmethod
    def convert_sample_count_to_inds(sample_count: List[int]) -> List[int]:
        """Convert a list of sample counts to a list of sample inds.

        Each position ``i`` is repeated ``sample_count[i]`` times; counts are
        cast to int and clamped to a minimum of 1 so every record appears at
        least once.

        Args:
            sample_count: list[int], integer list of sample counts.

        Returns:
            list[int], sample inds.
        """
        sample_count = [max(1, int(sc)) for sc in sample_count]
        return [ind for ind, val in enumerate(sample_count) for _ in range(val)]


class RecordAugmentor(object):
    """Class for applying a list of data augmentation functions to a data record.

    The configured functions are composed into a single callable that applies
    them in list order (the output of one is the input of the next).

    Args:
        funcs: list[dict], list of augmentation functions
            {"import": "python_path", "params": {...}}.
    """

    def __init__(self, funcs: List[dict]):
        # Each config dict is turned into a callable by
        # import_utils.import_partial_wrap_func, then all are chained together.
        self.augment_func = self.reduce_compose(
            *[import_utils.import_partial_wrap_func(f) for f in funcs]
        )

    def __call__(self, data_record: api.DataRecord) -> api.DataRecord:
        """Shorthand for ``augment``."""
        return self.augment(data_record)

    def augment(self, data_record: api.DataRecord) -> api.DataRecord:
        """Apply augmentation to a train data record.

        Args:
            data_record: DataRecord, data record.

        Returns:
            DataRecord, augmented data record.
        """
        return self.augment_func(data_record)

    @staticmethod
    def reduce_compose(*funcs):
        """Compose a list of functions into a single function.

        Functions are applied left to right: ``reduce_compose(f, g)(x)`` is
        ``g(f(x))``. With no functions the identity function is returned.

        Args:
            *funcs: callables to chain; the first may take any arguments, the
                rest receive the previous function's return value.

        Returns:
            callable, the composed function.
        """
        from functools import reduce

        if not funcs:
            return lambda x: x

        def _compose2(func1, func2):
            return lambda *args, **kwargs: func2(func1(*args, **kwargs))

        return reduce(_compose2, funcs)
def batchify_data_records(data_records: List[api.DataRecord]) -> api.BatchDataRecords:
    """Stack a list of DataRecord into BatchRecord.

    This process converts a list of tuples comprising of dicts
    {str: float/array} into tuples of dict {str: array}. For every key, the
    values from all records are collected in order and passed to ``np.array``,
    so the record axis becomes the first dimension. See Example.

    Args:
        data_records: list[DataRecord], list of individual data records.

    Returns:
        BatchDataRecords, batch data records.

    Example:
        ::

            data_record_1 = ({"input_1": 1, "input_2": 2}, {"output_1": 3})
            data_record_2 = ({"input_1": 2, "input_2": 4}, {"output_1": 6})
            batch_data_records = (
                {"input_1": arr([1, 2]), "input_2": arr([2, 4])},
                {"output_1": arr([3, 6])}
            )
    """
    # zip(*records) regroups the data by tuple position (all first dicts, all
    # second dicts, ...); merge_with then gathers each key's values into a list
    # and converts that list with np.array.
    batch_data_records = tuple(
        cytoolz.merge_with(np.array, ii) for ii in zip(*data_records)
    )
    return batch_data_records  # type: ignore
def batchify_network_output(
    network_output: Union[np.ndarray, List[np.ndarray]], output_names: List[str]
) -> api.BatchRecordScores:
    """Convert network output scores to BatchRecordScores.

    This process converts a single numpy array or list of numpy arrays into a
    list of dictionaries, one per scored record. See example.

    Args:
        network_output: union[np.ndarray, list[np.ndarray]], network output.
        output_names: list[str], output names; the i-th name labels the i-th
            array of a list output (only the first name is used for a single
            array output).

    Returns:
        BatchRecordScores, batch scores.

    Example:
        ::

            network_output == np.array([[1], [2]])
            output_names = ["y"]
            batch_scores = [{"y": np.array([1])}, {"y": np.array([2])}]
    """
    # Handle type inconsistency between outputs of single output/multi networks
    if isinstance(network_output, np.ndarray):
        dict_output = {output_names[0]: network_output}
        num_scores = len(network_output)
    else:
        dict_output = {
            name: network_output[ii] for ii, name in enumerate(output_names)
        }
        # Number of records is read from the first output's leading dimension.
        num_scores = len(network_output[0])

    # Slice every output along its first axis to build one dict per record.
    scores = [
        {k: v[ii, ...] for k, v in dict_output.items()}  # type: ignore
        for ii in range(num_scores)
    ]

    return scores