Source code for mpi

# Defining some helpers for parallelisation.
import dill
import numpy as np
from warnings import warn
from numpy.random import SeedSequence, default_rng, Generator

# Optional MPI support: when mpi4py can be imported, expose the world
# communicator together with this process' rank and the total process count;
# otherwise define single-process stand-ins so the helpers below still work.
try:
    from mpi4py import MPI

    # Use dill pickler (can serialize more stuff, e.g. lambdas)
    MPI.pickle.__init__(dill.dumps, dill.loads)
    # Define some interfaces
    comm = MPI.COMM_WORLD
    SIZE = comm.Get_size()
    RANK = comm.Get_rank()
    # Rank 0 is treated as the main process throughout this module.
    is_main_process = not bool(RANK)
    multiple_processes = SIZE > 1
except ImportError:
    warn(
        "mpi4py could not be imported. "
        "It is optional but recommended for faster running in parallel."
    )
    # Define dummy interfaces
    # NOTE: ``comm`` is None here, so any direct ``comm.<method>`` call must be
    # guarded by ``multiple_processes``.
    comm = None
    SIZE = 1
    RANK = 0
    is_main_process = True
    multiple_processes = False


def get_random_generator(seed=None):
    """
    Returns a random number generator for the calling process.

    The main process builds a seed sequence from ``seed`` and spawns one child
    sequence per process; each process then initialises its generator from its
    own child sequence.

    Parameters
    ----------
    seed : int or numpy seed, or numpy.random.Generator, optional (default=None)
        A random seed to initialise a Generator, or a Generator to be used
        directly (in which case it is returned untouched). If none is provided
        a random one will be drawn.

    Returns
    -------
    numpy.random.Generator
        The generator to be used by this process.
    """
    if isinstance(seed, Generator):
        return seed
    # Only the main process creates the child seeds; the rest receive theirs.
    spawned = SeedSequence(seed).spawn(SIZE) if is_main_process else None
    if multiple_processes:
        return default_rng(comm.scatter(spawned))
    return default_rng(spawned[0])
def bcast(args, root=0):
    """
    Broadcasts ``args`` from process ``root`` via ``comm.bcast``.

    When running with a single process (or without MPI), ``args`` is returned
    unchanged.
    """
    return comm.bcast(args, root=root) if multiple_processes else args
def gather(args, root=0):
    """
    Gathers ``args`` from all processes into process ``root`` via
    ``comm.gather``.

    When running with a single process (or without MPI), a one-element list
    containing ``args`` is returned.
    """
    return comm.gather(args, root=root) if multiple_processes else [args]
def allgather(args):
    """
    Gathers ``args`` from all processes into every process via
    ``comm.allgather``.

    When running with a single process (or without MPI), a one-element list
    containing ``args`` is returned.
    """
    return comm.allgather(args) if multiple_processes else [args]
def split_number_for_parallel_processes(n, n_proc=SIZE):
    """
    Splits a number of atomic tasks `n` between the parallel processes.

    If `n` is not divisible by the number of processes, processes with lower rank are
    preferred, e.g. 5 tasks for 3 processes are assigned as [2, 2, 1].

    Parameters
    ----------
    n : int
        The number of atomic tasks (must be non-negative).

    n_proc : int, optional (default=number of MPI comm's)
        The number of processes to divide the tasks between

    Returns
    -------
    An integer array of length `n_proc` with the number of tasks corresponding
    each process.

    Raises
    ------
    ValueError
        If `n` is negative.
    """
    if n < 0:
        raise ValueError(f"The number of tasks must be non-negative, got {n}.")
    # Every process gets the integer share; the first `remainder` ranks get one
    # extra task. Integer arithmetic avoids building a length-n scratch array.
    base, remainder = divmod(n, n_proc)
    counts = np.full(n_proc, base, dtype=int)
    counts[:remainder] += 1
    return counts
def step_split(values):
    """
    Broadcasts ``values`` from rank 0 and returns the slice belonging to this
    process, taking every ``SIZE``-th element starting at index ``RANK``.

    With a single process, ``values`` is returned unchanged.

    If the input is sorted, the interleaved split gives every process a
    partial array with similar content.
    """
    if multiple_processes:
        shared = comm.bcast(values)
        return shared[RANK::SIZE]
    return values
def merge_step_split(values):
    """
    Gathers step-split (with ``::mpi.SIZE``) arrays and re-interleaves them.

    Returns the merged array for the rank=0 process and ``None`` for the rest.
    With a single process, ``values`` is returned unchanged.
    """
    if not multiple_processes:
        return values
    chunks = comm.gather(values)
    if not is_main_process:
        return None
    merged = np.zeros(sum(map(len, chunks)))
    # The chunk from rank r fills positions r, r + SIZE, r + 2 * SIZE, ...
    for rank, chunk in enumerate(chunks):
        merged[rank::SIZE] = chunk
    return merged
def multi_gather_array(arrs):
    """
    Gathers (possibly a list of) arrays from all processes into the main process.

    NB: mpi-gather guarantees rank order is preserved.

    Parameters
    ----------
    arrs : array-like
        The arrays to gather. Anything that is not a list or tuple is treated
        as a single array.

    Returns
    -------
    The gathered array(s) from all processes: in the main process, a list with
    one concatenated array per input array; in the other processes, a list of
    ``None`` of the same length. With a single process, the (wrapped) input is
    returned as is.
    """
    if not isinstance(arrs, (list, tuple)):
        arrs = [arrs]
    n_arrays = len(arrs)
    if not multiple_processes:
        return arrs
    per_rank = comm.gather(arrs)
    if not is_main_process:
        return [None] * n_arrays
    # For each input array, stitch together the pieces of all ranks in order.
    return [
        np.concatenate([rank_arrs[i] for rank_arrs in per_rank])
        for i in range(n_arrays)
    ]
def sync_processes():
    """
    Blocks until every process has reached this call (``comm.barrier``).

    Does nothing when running with a single process.
    """
    if multiple_processes:
        comm.barrier()
def share_attr(instance, attr_name, root=0):
    """
    Broadcasts the attribute ``attr_name`` of ``instance`` from the process of
    rank ``root`` and sets it on ``instance`` in every process.

    Processes where the attribute is missing contribute ``None`` to the
    broadcast. Does nothing when running with a single process.
    """
    if multiple_processes:
        local_value = getattr(instance, attr_name, None)
        setattr(instance, attr_name, comm.bcast(local_value, root=root))
def compute_y_parallel(gpr, X, y, sigma_y, ensure_sigma_y=False):
    """
    Computes the GPR mean (and std if `ensure_sigma_y=True`) in parallel.

    Parameters
    ----------
    gpr : object
        Regressor providing ``predict(X, return_std=..., validate=...)`` and
        ``predict_std(X, validate=...)``.

    X : array-like
        Locations at which to evaluate; split between processes with
        ``step_split``.

    y : array-like or None
        Precomputed mean values. If ``None``, they are computed here (and
        ``sigma_y`` is assumed to be ``None`` too).

    sigma_y : array-like or None
        Precomputed standard deviations. If ``None`` and `ensure_sigma_y` is
        set, they are computed here.

    ensure_sigma_y : bool, optional (default=False)
        Whether standard deviations must be returned.

    Returns
    -------
    The resulting `(y, sigma_y)` arrays (computed or given) for rank 0, and
    ``(None, None)`` otherwise.
    """
    if multiple_processes:
        y = comm.bcast(y)
    if y is None:  # assume sigma_y is also None
        this_X = step_split(X)
        if len(this_X) > 0:
            if ensure_sigma_y:
                this_y, this_sigma_y = gpr.predict(
                    this_X, return_std=True, validate=False
                )
            else:
                this_y = gpr.predict(this_X, return_std=False, validate=False)
        else:
            # Nothing assigned to this process: contribute empty arrays.
            this_y = np.array([], dtype=float)
            this_sigma_y = np.array([], dtype=float) if ensure_sigma_y else None
        return (
            merge_step_split(this_y),
            merge_step_split(this_sigma_y) if ensure_sigma_y else None,
        )
    # Guard the broadcast like the one for `y`: without MPI, `comm` is None.
    if multiple_processes:
        sigma_y = comm.bcast(sigma_y)
    if sigma_y is None and ensure_sigma_y:
        this_X = step_split(X)
        if len(this_X) > 0:
            this_sigma_y = gpr.predict_std(this_X, validate=False)
        else:
            this_sigma_y = np.array([], dtype=float)
        return (
            y if is_main_process else None,
            merge_step_split(this_sigma_y),
        )
    return (y, sigma_y) if is_main_process else (None, None)