Source code for gammapy.utils.parallel

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Multiprocessing and multithreading setup."""
import importlib
import logging
from enum import Enum
from gammapy.utils.pbar import progress_bar

log = logging.getLogger(__name__)

__all__ = [
    "multiprocessing_manager",
    "run_multiprocessing",
    "BACKEND_DEFAULT",
    "N_JOBS_DEFAULT",
    "POOL_KWARGS_DEFAULT",
    "METHOD_DEFAULT",
    "METHOD_KWARGS_DEFAULT",
]


class ParallelBackendEnum(Enum):
    """Enum for parallel backend."""

    multiprocessing = "multiprocessing"
    ray = "ray"

    @classmethod
    def from_str(cls, value):
        """Get enum from string."""

        if value == "ray" and not is_ray_available():
            log.warning("Ray is not installed, falling back to multiprocessing backend")
            value = "multiprocessing"

        return cls(value)


class PoolMethodEnum(Enum):
    """Enum for pool method."""

    starmap = "starmap"
    apply_async = "apply_async"


BACKEND_DEFAULT = ParallelBackendEnum.multiprocessing
N_JOBS_DEFAULT = 1
ALLOW_CHILD_JOBS = False
POOL_KWARGS_DEFAULT = dict(processes=N_JOBS_DEFAULT)
METHOD_DEFAULT = PoolMethodEnum.starmap
METHOD_KWARGS_DEFAULT = {}


def get_multiprocessing():
    """Get multiprocessing module."""
    import multiprocessing

    return multiprocessing


def get_multiprocessing_ray():
    """Get multiprocessing module for ray backend."""
    import ray.util.multiprocessing as multiprocessing

    log.warning(
        "Gammapy support for parallelisation with ray is still a prototype and is not fully functional."
    )
    return multiprocessing


def is_ray_initialized():
    """Check if ray is initialized."""
    try:
        from ray import is_initialized

        return is_initialized()
    except ModuleNotFoundError:
        return False


def is_ray_available():
    """Check if ray is available."""
    try:
        importlib.import_module("ray")
        return True
    except ModuleNotFoundError:
        return False


[docs] class multiprocessing_manager: """Context manager to update the default configuration for multiprocessing. Only the default configuration will be modified, if class arguments like `n_jobs` and `parallel_backend` are set they will overwrite the default configuration. Parameters ---------- backend : {'multiprocessing', 'ray'} Backend to use. pool_kwargs : dict Keyword arguments passed to the pool. The number of processes is limited to the number of physical CPUs. method : {'starmap', 'apply_async'} Pool method to use. method_kwargs : dict Keyword arguments passed to the method Examples -------- :: import gammapy.utils.parallel as parallel from gammapy.estimators import FluxPointsEstimator fpe = FluxPointsEstimator(energy_edges=[1, 3, 10] * u.TeV) with parallel.multiprocessing_manager( backend="multiprocessing", pool_kwargs=dict(processes=2), ): fpe.run(datasets) """ def __init__(self, backend=None, pool_kwargs=None, method=None, method_kwargs=None): global BACKEND_DEFAULT, POOL_KWARGS_DEFAULT, METHOD_DEFAULT, METHOD_KWARGS_DEFAULT, N_JOBS_DEFAULT self._backend = BACKEND_DEFAULT self._pool_kwargs = POOL_KWARGS_DEFAULT self._method = METHOD_DEFAULT self._method_kwargs = METHOD_KWARGS_DEFAULT self._n_jobs = N_JOBS_DEFAULT if backend is not None: BACKEND_DEFAULT = ParallelBackendEnum.from_str(backend).value if pool_kwargs is not None: POOL_KWARGS_DEFAULT = pool_kwargs N_JOBS_DEFAULT = pool_kwargs.get("processes", N_JOBS_DEFAULT) if method is not None: METHOD_DEFAULT = PoolMethodEnum(method).value if method_kwargs is not None: METHOD_KWARGS_DEFAULT = method_kwargs def __enter__(self): pass def __exit__(self, type, value, traceback): global BACKEND_DEFAULT, POOL_KWARGS_DEFAULT, METHOD_DEFAULT, METHOD_KWARGS_DEFAULT, N_JOBS_DEFAULT BACKEND_DEFAULT = self._backend POOL_KWARGS_DEFAULT = self._pool_kwargs METHOD_DEFAULT = self._method METHOD_KWARGS_DEFAULT = self._method_kwargs N_JOBS_DEFAULT = self._n_jobs
class ParallelMixin: """Mixin class to handle parallel processing.""" _n_child_jobs = 1 @property def n_jobs(self): """Number of jobs as an integer.""" # TODO: this is somewhat unusual behaviour. It deviates from a normal default value handling if self._n_jobs is None: return N_JOBS_DEFAULT return self._n_jobs @n_jobs.setter def n_jobs(self, value): """Number of jobs setter as an integer.""" if not isinstance(value, (int, type(None))): raise ValueError( f"Invalid type: {value!r}, and integer or None is expected." ) self._n_jobs = value if ALLOW_CHILD_JOBS: self._n_child_jobs = value def _update_child_jobs(self): """needed because we can update only in the main process otherwise global ALLOW_CHILD_JOBS has default value""" if ALLOW_CHILD_JOBS: self._n_child_jobs = self.n_jobs else: self._n_child_jobs = 1 @property def _get_n_child_jobs(self): """Number of allowed child jobs as an integer.""" return self._n_child_jobs @property def parallel_backend(self): """Parallel backend as a string.""" if self._parallel_backend is None: return BACKEND_DEFAULT return self._parallel_backend @parallel_backend.setter def parallel_backend(self, value): """Parallel backend setter (str)""" if value is None: self._parallel_backend = None else: self._parallel_backend = ParallelBackendEnum.from_str(value).value
[docs] def run_multiprocessing( func, inputs, backend=None, pool_kwargs=None, method=None, method_kwargs=None, task_name="", ): """Run function in a loop or in Parallel. Notes ----- The progress bar can be displayed for this function. Parameters ---------- func : function Function to run. inputs : list List of arguments to pass to the function. backend : {'multiprocessing', 'ray'}, optional Backend to use. Default is None. pool_kwargs : dict, optional Keyword arguments passed to the pool. The number of processes is limited to the number of physical CPUs. Default is None. method : {'starmap', 'apply_async'} Pool method to use. Default is "starmap". method_kwargs : dict, optional Keyword arguments passed to the method. Default is None. task_name : str, optional Name of the task to display in the progress bar. Default is "". """ if backend is None: backend = BACKEND_DEFAULT if method is None: method = METHOD_DEFAULT if method_kwargs is None: method_kwargs = METHOD_KWARGS_DEFAULT if pool_kwargs is None: pool_kwargs = POOL_KWARGS_DEFAULT try: method_enum = PoolMethodEnum(method) except ValueError as e: raise ValueError(f"Invalid method: {method}") from e processes = pool_kwargs.get("processes", N_JOBS_DEFAULT) backend = ParallelBackendEnum.from_str(backend) multiprocessing = PARALLEL_BACKEND_MODULES[backend]() if backend == ParallelBackendEnum.multiprocessing: cpu_count = multiprocessing.cpu_count() if processes > cpu_count: log.info(f"Limiting number of processes from {processes} to {cpu_count}") processes = cpu_count if multiprocessing.current_process().name != "MainProcess": # with multiprocessing subprocesses cannot have childs (but possible with ray) processes = 1 if processes == 1: return run_loop( func=func, inputs=inputs, method_kwargs=method_kwargs, task_name=task_name ) if backend == ParallelBackendEnum.ray: address = "auto" if is_ray_initialized() else None pool_kwargs.setdefault("ray_address", address) log.info(f"Using {processes} processes to compute {task_name}") with multiprocessing.Pool(**pool_kwargs) as pool: pool_func = POOL_METHODS[method_enum] results = pool_func( pool=pool, func=func, inputs=inputs, method_kwargs=method_kwargs, task_name=task_name, ) return results
def run_loop(func, inputs, method_kwargs=None, task_name=""): """Loop over inputs and run function.""" results = [] callback = method_kwargs.get("callback", None) for arguments in progress_bar(inputs, desc=task_name): result = func(*arguments) if callback is not None: result = callback(result) results.append(result) return results def run_pool_star_map(pool, func, inputs, method_kwargs=None, task_name=""): """Run function in parallel.""" return pool.starmap(func, progress_bar(inputs, desc=task_name), **method_kwargs) def run_pool_async(pool, func, inputs, method_kwargs=None, task_name=""): """Run function in parallel async.""" results = [] for arguments in progress_bar(inputs, desc=task_name): result = pool.apply_async(func, arguments, **method_kwargs) results.append(result) # wait async run is done [result.wait() for result in results] return results POOL_METHODS = { PoolMethodEnum.starmap: run_pool_star_map, PoolMethodEnum.apply_async: run_pool_async, } PARALLEL_BACKEND_MODULES = { ParallelBackendEnum.multiprocessing: get_multiprocessing, ParallelBackendEnum.ray: get_multiprocessing_ray, }