import os
from cobaya.model import get_model, Model
def _get_dill():
    """
    Import the optional ``dill`` dependency lazily and return the module.

    Returns
    -------
    module
        The imported ``dill`` module.

    Raises
    ------
    ImportError
        If ``dill`` is not installed. The original import error is chained
        as the cause.
    """
    try:
        import dill as dill_module
    except ImportError as import_error:
        message = ("Could not find the 'dill' package. This is not a strict "
                   "requirement for gpry, but without it the checkpoint "
                   "functionality does not work.")
        raise ImportError(message) from import_error
    else:
        return dill_module
# File name used for each component of a checkpoint. The insertion order is
# significant: ``check_checkpoint`` reports file existence in this order.
_checkpoint_filenames = {
    "model": "mod.pkl",
    "gpr": "gpr.pkl",
    "acquisition": "acq.pkl",
    "convergence": "con.pkl",
    "options": "opt.pkl",
    "progress": "pro.pkl",
}
[docs]
def create_path(path, verbose=True):
    """
    Creates a path if it doesn't exist already and prints a message if creating a new
    directory. If the path already exists it does nothing.

    Parameters
    ----------
    path : string or path
        The path which shall be created.
    verbose : bool, optional
        If ``True`` (default), print a message when a new directory was created.
    """
    # Attempt the creation directly instead of checking for existence first:
    # a separate check would leave a window in which another process could
    # create the directory, making ``os.makedirs`` fail.
    try:
        os.makedirs(path)
    except FileExistsError:
        # The path is already there (possibly created concurrently).
        return
    if verbose:
        print("Successfully created the directory %s" % path)
[docs]
def check_checkpoint(path=None):
    """
    Reports, for every checkpoint file, whether it is present in a given location.

    Parameters
    ----------
    path : string, optional
        The path where the files are located. If ``None``, reports files as non-found.

    Returns
    -------
    A list of bools stating whether each file exists in the specified
    location, in the following order:
    [model, gpr, acquisition, convergence, options, progress]
    """
    if path is not None:
        found = []
        for filename in _checkpoint_filenames.values():
            found.append(os.path.exists(os.path.join(path, filename)))
        return found
    return [False] * len(_checkpoint_filenames)
[docs]
def read_checkpoint(path, model=None):
    """
    Loads checkpoint files to be able to resume a run or save the results for
    further processing.

    Parameters
    ----------
    path : string
        The path where the files are located. If ``None``, no file is read.
    model : cobaya.model.Model, optional
        If passed, it will be used instead of the loaded one.

    Returns
    -------
    (model, gpr, acquisition, convergence, options, progress)
        For every file that does not exist, ``None`` is returned in its place.
        Errors raised while reading an existing file are propagated.

    Raises
    ------
    ImportError
        If the ``dill`` package is not available.
    ValueError
        If ``model`` is neither ``None`` nor a ``cobaya.model.Model`` instance.
    """
    pickle = _get_dill()
    # Look the existence flags up by name rather than by position, so that
    # they cannot get out of step with the order of the file name table.
    found = dict(zip(_checkpoint_filenames, check_checkpoint(path)))
    if model is not None and not isinstance(model, Model):
        raise ValueError(
            "If 'model' is not None, it must be a cobaya.model.Model instance."
        )

    def load(key):
        """Return the unpickled content of the file for *key*, or None if absent."""
        if not found[key]:
            return None
        # NOTE: unpickling can execute arbitrary code; only read checkpoints
        # that come from a trusted location.
        with open(os.path.join(path, _checkpoint_filenames[key]), 'rb') as i:
            return pickle.load(i)

    if model is None:
        model_info = load("model")
        if model_info is not None:
            # Convert model from dict to model object
            model = get_model(model_info)
    gpr = load("gpr")
    acquisition = load("acquisition")
    convergence = load("convergence")
    # The prior is stripped before saving, so restore it from the model
    # whenever both pieces are available.
    if convergence is not None and model is not None:
        convergence.prior = model.prior
    options = load("options")
    progress = load("progress")
    return model, gpr, acquisition, convergence, options, progress
[docs]
def save_checkpoint(path, model, gpr, acquisition, convergence, options, progress):
    """
    This function is used to save all relevant parts of the GP loop for reuse
    as checkpoint in case the procedure crashes.
    This function creates ``.pkl`` files which contain the instances
    of the different modules.
    The files can be loaded with the ``read_checkpoint`` function.

    Parameters
    ----------
    path : The path where the files shall be saved
        The files will be saved as *path* + (mod, gpr, acq, con, opt, pro).pkl.
        If ``None``, nothing is saved.
    model : Cobaya `model object <https://cobaya.readthedocs.io/en/latest/cosmo_model.html>`_
    gpr : GaussianProcessRegressor
    acquisition : GPAcquisition
    convergence : Convergence_criterion
        May be ``None``. Otherwise a copy without its prior is stored; the
        passed object is left untouched.
    options : dict
    progress : Progress instance

    Raises
    ------
    ImportError
        If the ``dill`` package is not available.
    RuntimeError
        If preparing or writing any of the files fails. The original error
        is chained as the cause.
    """
    if path is None:
        return
    from copy import deepcopy
    pickle = _get_dill()
    create_path(path, verbose=False)
    try:
        # Prepare everything before opening any file for writing, so that a
        # failure here does not truncate files of a previous checkpoint.
        if convergence is not None:
            # TODO: maybe convergence should just not keep the prior!
            # Need to delete the prior object in convergence so it doesn't
            # do weird stuff while pickling. Work on a copy so that the
            # caller's object keeps its prior.
            convergence = deepcopy(convergence)
            convergence.prior = None
        payloads = {
            "model": model.info(),  # Save model as dict
            "gpr": gpr,
            "acquisition": acquisition,
            "convergence": convergence,
            "options": options,
            "progress": progress,
        }
        for key, payload in payloads.items():
            with open(os.path.join(path, _checkpoint_filenames[key]), 'wb') as f:
                pickle.dump(payload, f, pickle.HIGHEST_PROTOCOL)
    except Exception as excpt:
        raise RuntimeError("Could not save the checkpoint. Check if the path "
                           "is correct and exists. Error message: " + str(excpt)
                           ) from excpt