Source code for gpry.io

"""
Module containing i/o utilities.
"""

import os

import dill as pickle

from gpry.gpr import GaussianProcessRegressor
from gpry.truth import get_truth, Truth

# Maps each checkpoint component to the name of the pickle file that stores it
# inside the checkpoint directory.
# NOTE: insertion order matters -- ``check_checkpoint`` reports file existence
# in the order of this dict's values.
_checkpoint_filenames = {
    "truth": "tru.pkl",
    "gpr": "gpr.pkl",
    "acquisition": "acq.pkl",
    "convergence": "con.pkl",
    "options": "opt.pkl",
    "progress": "pro.pkl",
}

# For backwards compatibility (TODO: deprecate)
# File name of the legacy pickled model; ``read_checkpoint`` refers to it as the
# "Cobaya model" and wraps its content as the ``loglike`` of a truth.
_checkpoint_filename_model = "mod.pkl"


def create_path(path, verbose=True):
    """
    Creates a path if it doesn't exist already and prints a message if creating a new
    directory. If the path already exists it does nothing.

    Parameters
    ----------
    path : string or path
        The path which shall be created (intermediate directories included).

    verbose : bool, optional (default=True)
        If ``True``, prints a message when a new directory has been created.

    Raises
    ------
    OSError
        If the directory could not be created and the path does not exist.
    """
    # Attempt the creation first instead of checking for existence beforehand:
    # with several processes sharing a checkpoint folder, another process may
    # create the directory between the check and the creation.
    try:
        os.makedirs(path)
    except OSError:
        # Creation failed: that is fine only if the path is already there.
        if not os.path.exists(path):
            raise
        return
    if verbose:
        print("Successfully created the directory %s" % path)
def check_checkpoint(path=None):
    """
    Reports, for every known checkpoint file, whether it is present in a folder.

    Parameters
    ----------
    path : string, optional
        The folder in which the checkpoint files are looked for. If ``None``, all
        files are reported as not found.

    Returns
    -------
    A list of bools, one per entry of ``_checkpoint_filenames`` and in its order:
    [truth, gpr, acquisition, convergence, options, progress]
    """
    filenames = list(_checkpoint_filenames.values())
    if path is None:
        # Without a folder there is nothing to look for.
        return [False for _ in filenames]
    found = []
    for filename in filenames:
        found.append(os.path.exists(os.path.join(path, filename)))
    return found
def _load_checkpoint_file(path, filename):
    """
    Unpickles the file ``filename`` inside the folder ``path``.

    Returns ``None`` if the file does not exist. Errors raised while opening or
    unpickling an existing file are propagated.
    """
    full_path = os.path.join(path, filename)
    if not os.path.exists(full_path):
        return None
    # NOTE: unpickling can execute arbitrary code -- only load checkpoints that
    # come from a trusted source.
    with open(full_path, "rb") as i:
        return pickle.load(i)


def read_checkpoint(path, truth=None):
    """
    Loads checkpoint files to be able to resume a run or save the results for
    further processing.

    Parameters
    ----------
    path : string
        The path where the files are located. If ``None``, nothing is loaded.

    truth : gpry.truth.Truth, optional
        If passed, it will be used instead of the loaded one.

    Returns
    -------
    (truth, gpr, acquisition, convergence, options, progress)
    Each element whose file does not exist is returned as ``None`` (``truth`` is
    returned as passed if it was given).

    Raises
    ------
    ValueError
        If ``truth`` is neither ``None`` nor a ``gpry.truth.Truth`` instance.
    """
    if truth is not None and not isinstance(truth, Truth):
        raise ValueError(
            "If 'truth' is not None, it must be a gpry.truth.Truth instance."
        )
    if path is None:
        # Same convention as ``check_checkpoint``: no path means no files found.
        return truth, None, None, None, None, None
    if truth is None:
        truth_kwargs = _load_checkpoint_file(path, _checkpoint_filenames["truth"])
        if truth_kwargs is None:
            # Backwards compatibility: load Cobaya model (TODO: deprecate soon)
            model = _load_checkpoint_file(path, _checkpoint_filename_model)
            if model is not None:
                truth_kwargs = {"loglike": model}
        # Only build the truth if something was actually loaded; otherwise it
        # stays None, as for any other missing checkpoint file.
        if truth_kwargs is not None:
            truth = get_truth(**truth_kwargs)
    # Files are looked up by name, so the result does not depend on the
    # positional order of any list of flags.
    gpr = _load_checkpoint_file(path, _checkpoint_filenames["gpr"])
    acquisition = _load_checkpoint_file(path, _checkpoint_filenames["acquisition"])
    convergence = _load_checkpoint_file(path, _checkpoint_filenames["convergence"])
    options = _load_checkpoint_file(path, _checkpoint_filenames["options"])
    progress = _load_checkpoint_file(path, _checkpoint_filenames["progress"])
    return truth, gpr, acquisition, convergence, options, progress
def save_checkpoint(path, truth, gpr, acquisition, convergence, options, progress):
    """
    Stores the current state of the GP loop as a set of ``.pkl`` files, so that a
    run can be resumed from them (e.g. after a crash). The files can be loaded
    back with the read_checkpoint function.

    Parameters
    ----------
    path : The folder where the files shall be saved
        It is created if needed. One file per object is written into it, named as
        given in ``_checkpoint_filenames``. If ``None``, nothing is saved.

    truth : Truth
        Stored as the output of its ``as_dict`` method; skipped if ``None``.

    gpr : GaussianProcessRegressor

    acquisition : GPAcquisition

    convergence : Convergence_criterion

    options : dict

    progress : Progress instance

    Raises
    ------
    RuntimeError
        If anything fails while writing the files.
    """
    if path is None:
        return
    create_path(path, verbose=False)
    # Objects pickled as they are, in the order in which they are written.
    pickled_as_is = (
        ("gpr", gpr),
        ("acquisition", acquisition),
        ("convergence", convergence),
        ("options", options),
        ("progress", progress),
    )
    try:
        if truth is not None:
            truth_file = os.path.join(path, _checkpoint_filenames["truth"])
            with open(truth_file, "wb") as out:
                pickle.dump(truth.as_dict(), out, pickle.HIGHEST_PROTOCOL)
        for key, obj in pickled_as_is:
            with open(os.path.join(path, _checkpoint_filenames[key]), "wb") as out:
                pickle.dump(obj, out, pickle.HIGHEST_PROTOCOL)
    except Exception as excpt:
        raise RuntimeError(
            "Could not save the checkpoint. Check if the path "
            "is correct and exists. Error message: " + str(excpt)
        ) from excpt
def ensure_gpr(
    gpr, truth=None, acquisition=None, convergence=None, options=None, progress=None
):
    """
    Returns (if instance passed) or loads (if string) the given gpr and associated
    objects.

    If loading, any object passed as a keyword will be preferred to the loaded one.

    Parameters
    ----------
    gpr : GaussianProcessRegressor or string
        Either a regressor instance, or the path to a checkpoint to load it from.

    truth : Truth, optional

    acquisition : GPAcquisition, optional

    convergence : Convergence_criterion, optional

    options : dict, optional

    progress : Progress instance, optional

    Returns
    -------
    (truth, gpr, acquisition, convergence, options, progress)
    When loading, any object that was neither passed nor found in the checkpoint
    is returned as ``None``. When an instance is passed, the remaining objects are
    returned as given.

    Raises
    ------
    TypeError
        If ``gpr`` is neither a string nor a ``GaussianProcessRegressor``.
    """
    if not isinstance(gpr, (str, GaussianProcessRegressor)):
        raise TypeError(
            "`gpr` needs to be a gpry GP Regressor or a string "
            "with a path to a checkpoint file."
        )
    if isinstance(gpr, str):
        loaded_truth, gpr, loaded_acq, loaded_conv, loaded_opt, loaded_prog = (
            read_checkpoint(gpr, truth=truth)
        )
        # Compare against None explicitly (not truthiness), so that objects
        # passed by the caller which happen to be falsy (e.g. an empty options
        # dict) are still preferred over the loaded ones.
        if truth is None:
            truth = loaded_truth
        if acquisition is None:
            acquisition = loaded_acq
        if convergence is None:
            convergence = loaded_conv
        if options is None:
            options = loaded_opt
        if progress is None:
            progress = loaded_prog
    return (truth, gpr, acquisition, convergence, options, progress)