Source code for gpry.cobaya

"""
Interface for Cobaya: wrapper using the Cobaya.sampler.Sampler class.

To use with Cobaya, once GPry has been installed with pip, add it to the sampler block
as ``gpry.CobayaSampler``.

For input arguments and options, see ``CobayaSampler.yaml`` in this folder, or run
``cobaya-doc gpry.CobayaSampler`` in a shell.
"""

import os
import re
import logging
from copy import deepcopy
from tempfile import gettempdir
from inspect import cleandoc
from typing import Union

from cobaya.sampler import Sampler
from cobaya.component import get_component_class
from cobaya.log import LoggedError
from cobaya.output import get_output, split_prefix, OutputReadOnly
from cobaya.tools import get_external_function
from cobaya.collection import SampleCollection

from gpry import mpi
from gpry.run import Runner

# TODO: resuming may not work in cases where internal gp_acq or gpr options aren't changed


# pylint: disable=no-member,access-member-before-definition
# pylint: disable=attribute-defined-outside-init
class CobayaWrapper(Sampler):
    """GPry: a package for Bayesian inference of expensive likelihoods using GPs.

    Cobaya ``Sampler`` wrapper around :class:`gpry.run.Runner`: it first runs the
    GPry learning loop and then draws an MC sample from the surrogate model.
    """

    # Resume: options listed here are declared to Cobaya via its
    # ``_at_resume_prefer_new`` hook (presumably: the newly passed value wins over
    # the stored one when resuming -- confirm against the Cobaya Sampler docs).
    _at_resume_prefer_new = ["plots", "callback", "callback_is_MPI_aware", "verbose"]

    # Other options
    # Suffix appended to the Cobaya output prefix for the GPry checkpoint folder
    _gpry_output_dir = "gpry_output"
    # Suffix appended to the Cobaya output prefix for the surrogate MC sample
    _surrogate_suffix = "gpr"

    def initialize(self):
        """
        Initializes GPry.

        Derives the verbosity from the Cobaya logger (if not given), prepares the
        checkpoint/output paths, normalises the ``gpr`` and ``gp_acquisition``
        options and creates the :class:`gpry.run.Runner` instance.

        Raises
        ------
        LoggedError
            If the ``Runner`` raises ``ValueError`` or ``TypeError`` at initialisation.
        """
        # Set some args for the Runner that are derived from Cobaya
        if self.verbose is None:
            log_level = self.log.getEffectiveLevel()
            if log_level == logging.NOTSET:
                self.verbose = 3
            elif log_level <= logging.DEBUG:
                self.verbose = 4
            elif log_level <= logging.INFO:
                self.verbose = 3
            else:
                self.verbose = 2
        # Prepare output
        self.path_checkpoint, self.surrogate_prefix = \
            self.get_checkpoint_dir_and_surr_prefix(self.output)
        self.mc_sampler_upd_info = None
        self.mc_sampler_instance = None
        self.output_strategy = "resume" if self.output.is_resuming() else "overwrite"
        # Default to Runner & Acq defaults: remove (top-level) keys with None value.
        # The dicts are modified in place, iterating over a copy of their items.
        if self.gpr is None:
            self.gpr = {}
        for k, v in list(self.gpr.items()):
            if v is None:
                self.gpr.pop(k)
        if self.gp_acquisition is None:
            self.gp_acquisition = {}
        for k, v in list(self.gp_acquisition.items()):
            if v is None:
                self.gp_acquisition.pop(k)
        gp_acq_input = deepcopy(self.gp_acquisition)
        gp_acq_engine = gp_acq_input.pop("engine", "BatchOptimizer")
        # Grab the relevant acq options, merge them, and kick out the unused ones:
        # every ``options_*`` key is removed, and only the one whose name ends with
        # the selected engine (case-insensitive) is merged into the engine options.
        for k in list(gp_acq_input):
            if k.startswith("options_"):
                gp_acq_engine_options = gp_acq_input.pop(k)
                if k.lower().endswith(gp_acq_engine.lower()):
                    gp_acq_input.update(gp_acq_engine_options or {})
        # The Runner is passed a single-key dict ``{engine_name: engine_options}``
        gp_acq_input = {gp_acq_engine: gp_acq_input}
        # Initialize the runner
        try:
            self.gpry_runner = Runner(
                model=self.model,
                gpr=self.gpr,
                gp_acquisition=gp_acq_input,
                initial_proposer=self.initial_proposer,
                convergence_criterion=self.convergence_criterion,
                options=self.options,
                callback=get_external_function(self.callback) if self.callback else None,
                callback_is_MPI_aware=self.callback_is_MPI_aware,
                checkpoint=self.path_checkpoint,
                load_checkpoint=self.output_strategy,
                seed=self._rng,
                plots=self.plots,
                verbose=self.verbose,
            )
        except (ValueError, TypeError) as excpt:
            raise LoggedError(
                self.log, f"Error when initializing GPry: \n{str(excpt)}"
            ) from excpt

    def run(self):
        """
        Gets the initial training points and starts the acquisition loop.

        After the learning stage, an MC sample of the surrogate model is drawn
        (even if the learning stage did not converge) and, if requested, some plots
        are produced.

        Returns
        -------
        tuple
            ``(mc_sampler_upd_info, mc_sampler_instance)``, as set by
            :meth:`do_surrogate_sample`.

        Raises
        ------
        LoggedError
            If either the learning stage or the MC sampling stage fails.
        """
        if mpi.is_main_process:
            self.log.info("Starting learning stage...")
        try:
            self.gpry_runner.run()
        except Exception as excpt:
            raise LoggedError(
                self.log, "GPry failed during learning: %s", str(excpt)
            ) from excpt
        # NOTE(review): block nesting below was reconstructed from a flattened
        # listing; logging and plotting are assumed to happen on the main process
        # only -- confirm against the upstream file.
        if mpi.is_main_process:
            if self.gpry_runner.has_converged:
                self.log.info("Learning stage finished successfully!")
            else:
                self.log.info("Learning stage failed to converge! Will MC sample anyway.")
            self.log.info("Starting MC-sampling stage...")
        try:
            self.do_surrogate_sample(resume=self.output.is_resuming())
        except Exception as excpt:
            raise LoggedError(
                self.log,
                "GPry failed during MC sampling of the surrogate model: %s",
                str(excpt),
            ) from excpt
        if mpi.is_main_process:
            if self.gpry_runner.has_converged:
                self.log.info("MC-sampling finished successfully!")
            else:
                self.log.info("MC-sampling finished, but model *DID NOT CONVERGE*!")
            if self.plots:
                self.log.info("Doing some plots...")
                self.do_plots()
        return self.mc_sampler_upd_info, self.mc_sampler_instance

    def do_surrogate_sample(
        self,
        sampler=None,
        add_options=None,
        resume=False,
        prefix=None,
    ):
        """
        Perform an MC sample of the surrogate model.

        This function is called automatically at the end of a run, but it can be
        called by hand too e.g. if the initial one did not converge.

        Parameters
        ----------
        sampler: str
            An alternative sampler, if different from the one specified at
            initialisation (``self.mc_sampler``).

        add_options: dict
            Configuration to be passed to the sampler.

        resume: bool (default: False)
            Whether to try to resume a previous run

        prefix: str, optional
            An alternative path where to save the sample. If not given, the sample
            will use the default one with suffix ``(_)gpr``.

        Returns
        -------
        surr_info : dict
            The dictionary that was used to run (or initialized) the sampler,
            corresponding to the surrogate model, and populated with the sampler
            input specification.

        sampler : Sampler instance
            The sampler instance that has been run (or just initialised). The
            sampler products can be retrieved with the `Sampler.products()` method.
        """
        if prefix is None:
            prefix = self.surrogate_prefix
        self.gpry_runner.generate_mc_sample(
            sampler=self.mc_sampler if sampler is None else sampler,
            add_options=add_options,
            output=prefix,
            resume=resume,
        )
        # Cache the results, so that `samples`/`products` can be served later
        self.mc_sampler_upd_info = self.gpry_runner.last_mc_surr_info
        self.mc_sampler_instance = self.gpry_runner.last_mc_sampler
        # Return what the docstring promises (this used to implicitly return None)
        return self.mc_sampler_upd_info, self.mc_sampler_instance

    @property
    def is_mc_sampled(self):
        """
        Returns True if the runner holds a (truthy) ``_last_mc_samples`` attribute,
        i.e. presumably if the MC sampling of the surrogate process has run and
        converged.
        """
        return bool(getattr(self.gpry_runner, "_last_mc_samples", False))

    def do_plots(self, ext="png"):
        """
        Produces some results and diagnosis plots.

        Parameters
        ----------
        ext: str (default: "png")
            File extension (image format) passed on to the Runner plotting methods.
        """
        # This call used to pass ``format=format``, i.e. the *builtin* ``format``
        # function, since no local of that name exists.
        # NOTE(review): assumes `Runner.plot_distance_distribution` takes ``ext``
        # like the sibling plotting methods below -- confirm.
        self.gpry_runner.plot_distance_distribution(ext=ext)
        self.gpry_runner.plot_progress(ext=ext)
        # The MC plot is only produced when an MC sample is available
        if self.is_mc_sampled:
            self.gpry_runner.plot_mc(ext=ext)

    def samples(
        self,
        combined: bool = False,
        skip_samples: float = 0,
        to_getdist: bool = False,
    ) -> Union[SampleCollection, "MCSamples"]:
        """
        Returns the last sample from the surrogate model.

        The arguments are passed unchanged to the ``samples`` method of the MC
        sampler instance used for the surrogate model.

        Raises
        ------
        LoggedError
            If no MC sample of the surrogate model has been produced yet.
        """
        if getattr(self, "mc_sampler_instance", None) is None:
            raise LoggedError(
                self.log,
                "No MC sample of the surrogate model is available yet: run the "
                "sampler or call `do_surrogate_sample` first."
            )
        return self.mc_sampler_instance.samples(
            combined=combined,
            skip_samples=skip_samples,
            to_getdist=to_getdist,
        )

    def products(
        self,
        combined: bool = False,
        skip_samples: float = 0,
        to_getdist: bool = False,
    ) -> dict:
        """
        Returns the products of the run: an MC sample of the surrogate posterior
        under ``sample``, and the GPRy ``Runner`` object under ``runner``.

        The arguments are passed unchanged to the ``products`` method of the MC
        sampler instance used for the surrogate model.

        Raises
        ------
        LoggedError
            If no MC sample of the surrogate model has been produced yet.
        """
        if getattr(self, "mc_sampler_instance", None) is None:
            raise LoggedError(
                self.log,
                "No MC sample of the surrogate model is available yet: run the "
                "sampler or call `do_surrogate_sample` first."
            )
        products = {"runner": self.gpry_runner}
        products.update(
            self.mc_sampler_instance.products(
                combined=combined, skip_samples=skip_samples, to_getdist=to_getdist
            )
        )
        return products

    @classmethod
    def get_checkpoint_dir_and_surr_prefix(cls, output=None):
        """
        Folder where the checkpoint output of GPry is going to be saved, and prefix
        for the output object of the MC sample of the surrogate model, given a
        Cobaya ``OutputReadOnly`` instance.

        Both are obtained by appending a suffix (``cls._gpry_output_dir`` and
        ``cls._surrogate_suffix`` respectively) to the output prefix, using ``_`` as
        separator.

        Parameters
        ----------
        output: cobaya.output.Output, cobaya.output.DummyOutput, optional
            Cobaya output instance. If it evaluates to False (e.g. ``None``;
            presumably also Cobaya's dummy output), nothing is written to disk and
            ``(None, None)`` is returned.

        Returns
        -------
        checkpoint_dir: str or None
            Relative folder where the GPry checkpoint will be saved.

        surrogate_prefix: str or None
            Prefix for surrogate MC chains for a Cobaya ``Output`` with relative
            path.

        Examples
        --------
        Assuming that ``cls._gpry_output_dir = "gpry_output"`` and
        ``cls._surrogate_suffix = "gpr"``:

        >>> from cobaya.output import get_output
        >>> cls.get_checkpoint_dir_and_surr_prefix(get_output("folder/"))
        'folder/gpry_output', 'folder/gpr'
        >>> cls.get_checkpoint_dir_and_surr_prefix(get_output("folder/prefix"))
        'folder/prefix_gpry_output', 'folder/prefix_gpr'
        >>> cls.get_checkpoint_dir_and_surr_prefix(None)  # no output
        None, None
        """
        if output:
            return (output.add_suffix(cls._gpry_output_dir, separator="_"),
                    output.add_suffix(cls._surrogate_suffix, separator="_"))
        return None, None

    @classmethod
    def output_files_regexps(cls, output, info=None, minimal=False):
        """
        Returns a list of tuples `(regexp, root)` of output files potentially
        produced. If `root` in the tuple is `None`, `output.folder` is used.

        If `minimal=True`, returns regexp's for the files that should really not be
        there when we are not resuming: GPry checkpoint products and the MC sample
        from the surrogate.

        If there is no actual output (see
        :meth:`get_checkpoint_dir_and_surr_prefix`), an empty list is returned.
        """
        path_checkpoint, surrogate_prefix = \
            cls.get_checkpoint_dir_and_surr_prefix(output)
        # Nothing is written without a real output; avoids joining paths with None
        if path_checkpoint is None or surrogate_prefix is None:
            return []
        # GPry checkpoint files
        regexps_tuples = [
            (re.compile(re.escape(name + ".pkl")), path_checkpoint)
            for name in ["acq", "con", "gpr", "mod", "opt", "pro"]
        ] + [(None, os.path.join(path_checkpoint, "images"))]
        # MC sample from surrogate -- more precise if we know the sampler
        # Using OutputReadOnly and not Output here bc it can be called just from
        # rank 0, and it's never used except as an aux object to get correct
        # surrogate MC prefixes
        surr_mc_output = OutputReadOnly(prefix=surrogate_prefix)
        surr_mc_folder, _ = split_prefix(surrogate_prefix)
        surr_mc_sampler = (info or {}).get("mc_sampler")
        if surr_mc_sampler:
            sampler = get_component_class(surr_mc_sampler, kind="sampler")
            # Re-root the MC sampler's own regexps at the surrogate MC folder
            regexps_tuples += [
                (regexp[0], os.path.join(surr_mc_folder, regexp[1] or ""))
                for regexp in sampler.output_files_regexps(
                    output=surr_mc_output, minimal=minimal)
            ]
        else:
            regexps_tuples += \
                [(surr_mc_output.collection_regexp(name=None), surr_mc_folder)]
        return regexps_tuples

    @staticmethod
    def is_nora(info):
        """
        Returns True if NORA is being used.

        The acquisition engine is read from the ``engine`` key of the
        ``gp_acquisition`` block (the key that ``initialize`` uses to select the
        engine). If that key is not a string, the name of the first key of the
        block is checked instead (legacy behaviour). The comparison is
        case-insensitive.
        """
        gp_acq = (info or {}).get("gp_acquisition", {}) or {}
        if isinstance(gp_acq, dict):
            engine = gp_acq.get("engine")
            if isinstance(engine, str):
                return engine.lower() == "nora"
        # Legacy specification: the engine name is the first key of the block
        acq_method = list(gp_acq)
        return (
            len(acq_method) > 0 and
            isinstance(acq_method[0], str) and
            acq_method[0].lower() == "nora"
        )

    @classmethod
    def get_desc(cls, info=None):
        """
        Returns a short description of GPry with LaTeX ``\\cite`` commands,
        mentioning NORA only if in use (or conditionally, if ``info`` is None).
        """
        nora_string = (
            r"using the NORA parallelised acquisition approach \cite{Torrado:2023cbj}"
        )
        if info is None:
            # Unknown case (no info passed)
            nora_string = f" [(if gp_acquisition: NORA) {nora_string}]"
        else:
            nora_string = " " + nora_string if cls.is_nora(info) else ""
        return (
            "GPry: a package for Bayesian inference of expensive likelihoods "
            r"with Gaussian Processes \cite{Gammal:2022eob}" + nora_string + "."
        )

    @classmethod
    def get_bibtex(cls):
        """Returns the BibTeX entries of the references cited in `get_desc`."""
        return cleandoc(r"""
            @article{Gammal:2022eob,
                author = {{El Gammal}, Jonas and Sch\"oneberg, Nils and Torrado, Jes\'us and Fidler, Christian},
                title = "{Fast and robust Bayesian Inference using Gaussian Processes with GPry}",
                eprint = "2211.02045",
                archivePrefix = "arXiv",
                primaryClass = "astro-ph.CO",
                month = "11",
                year = "2022"
            }
            # Cite only if using NORA as gp_acquisition (see desc.)
            @article{Torrado:2023cbj,
                author = {Torrado, Jes\'us and Sch\"oneberg, Nils and Gammal, Jonas El},
                title = "{Parallelized Acquisition for Active Learning using Monte Carlo Sampling}",
                eprint = "2305.19267",
                archivePrefix = "arXiv",
                primaryClass = "stat.ML",
                month = "5",
                year = "2023"
            }""")