Source code for barrage.engine

import tensorflow as tf

from barrage import api, config, dataset, logger, model, services, solver
from barrage.utils import io_utils, tf_utils


class BarrageModel(object):
    """Class for training the network and scoring records with the best
    performing network.

    Args:
        artifact_dir: str, path to artifact directory.
    """

    def __init__(self, artifact_dir):
        self._artifact_dir = artifact_dir

    def train(
        self,
        cfg: dict,
        records_train: api.InputRecords,
        records_validation: api.InputRecords,
    ) -> tf.keras.Model:
        """Train the network.

        Args:
            cfg: dict, config.
            records_train: InputRecords, training records.
            records_validation: InputRecords, validation records.

        Returns:
            tf.keras.Model, trained network.
        """
        logger.info("Starting training")
        tf_utils.reset()

        cfg = config.prepare_config(cfg)
        logger.info(f"Creating artifact directory: {self.artifact_dir}")
        services.make_artifact_dir(self.artifact_dir)
        io_utils.save_json(cfg, "config.json", self.artifact_dir)
        io_utils.save_pickle(cfg, "config.pkl", self.artifact_dir)

        logger.info("Creating datasets")
        ds_train = dataset.RecordDataset(
            artifact_dir=self.artifact_dir,
            cfg_dataset=cfg["dataset"],
            records=records_train,
            mode=api.RecordMode.TRAIN,
            batch_size=cfg["solver"]["batch_size"],
        )
        ds_validation = dataset.RecordDataset(
            artifact_dir=self.artifact_dir,
            cfg_dataset=cfg["dataset"],
            records=records_validation,
            mode=api.RecordMode.VALIDATION,
            batch_size=cfg["solver"]["batch_size"],
        )
        network_params = ds_train.transformer.network_params
        io_utils.save_json(network_params, "network_params.json", self.artifact_dir)
        io_utils.save_pickle(network_params, "network_params.pkl", self.artifact_dir)

        logger.info("Building network")
        net = model.build_network(cfg["model"], network_params)
        model.check_output_names(cfg["model"], net)

        logger.info("Compiling network")
        opt = solver.build_optimizer(cfg["solver"])
        objective = model.build_objective(cfg["model"])
        net.compile(optimizer=opt, **objective)

        logger.info("Creating services")
        callbacks = services.create_all_services(self.artifact_dir, cfg["services"])

        if "learning_rate_reducer" in cfg["solver"]:
            logger.info("Creating learning rate reducer")
            callbacks.append(solver.create_learning_rate_reducer(cfg["solver"]))

        logger.info("Training network")
        net.summary()
        net.fit(
            ds_train,
            validation_data=ds_validation,
            epochs=cfg["solver"]["epochs"],
            steps_per_epoch=cfg["solver"].get("steps"),
            callbacks=callbacks,
            verbose=1,
        )

        return net
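
    # A minimal training sketch. The config keys referenced above
    # (cfg["solver"]["batch_size"], cfg["solver"]["epochs"], ...) must follow
    # the schema enforced by config.prepare_config; the record values below are
    # hypothetical placeholders:
    #
    #   records_train = [{"x": 1.0, "y": 2.0}, ...]
    #   records_validation = [{"x": 3.0, "y": 4.0}, ...]
    #   net = BarrageModel("artifacts").train(cfg, records_train, records_validation)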

    def predict(self, records_score: api.InputRecords) -> api.BatchRecordScores:
        """Score records.

        Args:
            records_score: InputRecords, scoring records.

        Returns:
            BatchRecordScores, scored data records.
        """
        if not hasattr(self, "net"):
            self.load()

        ds_score = dataset.RecordDataset(
            artifact_dir=self.artifact_dir,
            cfg_dataset=self.cfg["dataset"],
            records=records_score,
            mode=api.RecordMode.SCORE,
            batch_size=self.cfg["solver"]["batch_size"],
        )
        network_output = self.net.predict(ds_score, verbose=1)
        scores = [
            ds_score.transformer.postprocess(score)
            for score in dataset.batchify_network_output(
                network_output, self.net.output_names
            )
        ]
        return scores
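
    # Scoring sketch: predict lazily calls load() when train() was not run in
    # the current process, so a fresh BarrageModel pointed at an existing
    # artifact directory suffices (records_score is a hypothetical placeholder):
    #
    #   scores = BarrageModel("artifacts").predict(records_score)
    #   # scores is a list of per-record outputs from transformer.postprocess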

    def load(self):
        """Load the best performing checkpoint."""
        # Load artifacts needed to recreate the network
        self.cfg = io_utils.load_pickle("config.pkl", self.artifact_dir)
        network_params = io_utils.load_pickle("network_params.pkl", self.artifact_dir)

        # Build network
        self.net = model.build_network(self.cfg["model"], network_params)

        # Load best checkpoint
        path = services.get_best_checkpoint_filepath(self.artifact_dir)
        self.net.load_weights(path).expect_partial()  # not loading optimizer
        return self
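
    # Reload sketch: load() rebuilds the network from the pickled config and
    # network params, then restores only the model weights; expect_partial()
    # silences warnings about the intentionally skipped optimizer state:
    #
    #   bm = BarrageModel("artifacts").load()
    #   net = bm.net  # weights from the best-performing checkpoint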

    @property
    def artifact_dir(self):
        return self._artifact_dir