# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Multiprocessing and multithreading setup."""
import importlib
import logging
from enum import Enum
from gammapy.utils.pbar import progress_bar

log = logging.getLogger(__name__)

__all__ = [
    "multiprocessing_manager",
    "run_multiprocessing",
    "BACKEND_DEFAULT",
    "N_JOBS_DEFAULT",
    "POOL_KWARGS_DEFAULT",
    "METHOD_DEFAULT",
    "METHOD_KWARGS_DEFAULT",
]


class ParallelBackendEnum(Enum):
    """Enum for parallel backend."""

    multiprocessing = "multiprocessing"
    ray = "ray"

    @classmethod
    def from_str(cls, value):
        """Get enum from string.

        Falls back to the multiprocessing backend when "ray" is requested
        but the ray package is not importable.
        """
        if value == "ray" and not is_ray_available():
            log.warning("Ray is not installed, falling back to multiprocessing backend")
            value = "multiprocessing"
        return cls(value)


class PoolMethodEnum(Enum):
    """Enum for pool method."""

    starmap = "starmap"
    apply_async = "apply_async"


# Module-level defaults; `multiprocessing_manager` temporarily overrides these.
BACKEND_DEFAULT = ParallelBackendEnum.multiprocessing
N_JOBS_DEFAULT = 1
ALLOW_CHILD_JOBS = False
POOL_KWARGS_DEFAULT = dict(processes=N_JOBS_DEFAULT)
METHOD_DEFAULT = PoolMethodEnum.starmap
METHOD_KWARGS_DEFAULT = {}


def get_multiprocessing():
    """Get multiprocessing module."""
    import multiprocessing

    return multiprocessing


def get_multiprocessing_ray():
    """Get multiprocessing module for ray backend."""
    import ray.util.multiprocessing as multiprocessing

    log.warning(
        "Gammapy support for parallelisation with ray is still a prototype and is not fully functional."
    )
    return multiprocessing


def is_ray_initialized():
    """Check if ray is initialized."""
    try:
        from ray import is_initialized

        return is_initialized()
    except ModuleNotFoundError:
        return False


def is_ray_available():
    """Check if ray is available."""
    try:
        importlib.import_module("ray")
        return True
    except ModuleNotFoundError:
        return False
class multiprocessing_manager:
    """Context manager to update the default configuration for multiprocessing.

    Only the default configuration will be modified, if class arguments like
    `n_jobs` and `parallel_backend` are set they will overwrite the default configuration.

    Parameters
    ----------
    backend : {'multiprocessing', 'ray'}
        Backend to use.
    pool_kwargs : dict
        Keyword arguments passed to the pool. The number of processes is limited
        to the number of physical CPUs.
    method : {'starmap', 'apply_async'}
        Pool method to use.
    method_kwargs : dict
        Keyword arguments passed to the method

    Examples
    --------
    ::

        import gammapy.utils.parallel as parallel
        from gammapy.estimators import FluxPointsEstimator

        fpe = FluxPointsEstimator(energy_edges=[1, 3, 10] * u.TeV)

        with parallel.multiprocessing_manager(
            backend="multiprocessing",
            pool_kwargs=dict(processes=2),
        ):
            fpe.run(datasets)
    """

    def __init__(self, backend=None, pool_kwargs=None, method=None, method_kwargs=None):
        # Snapshot the current defaults so __exit__ can restore them,
        # then overwrite the module-level defaults with the given values.
        global BACKEND_DEFAULT, POOL_KWARGS_DEFAULT, METHOD_DEFAULT, METHOD_KWARGS_DEFAULT, N_JOBS_DEFAULT
        self._backend = BACKEND_DEFAULT
        self._pool_kwargs = POOL_KWARGS_DEFAULT
        self._method = METHOD_DEFAULT
        self._method_kwargs = METHOD_KWARGS_DEFAULT
        self._n_jobs = N_JOBS_DEFAULT

        if backend is not None:
            BACKEND_DEFAULT = ParallelBackendEnum.from_str(backend).value

        if pool_kwargs is not None:
            POOL_KWARGS_DEFAULT = pool_kwargs
            N_JOBS_DEFAULT = pool_kwargs.get("processes", N_JOBS_DEFAULT)

        if method is not None:
            METHOD_DEFAULT = PoolMethodEnum(method).value

        if method_kwargs is not None:
            METHOD_KWARGS_DEFAULT = method_kwargs

    def __enter__(self):
        # The defaults were already swapped in __init__; nothing to do here.
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the defaults captured in __init__.
        global BACKEND_DEFAULT, POOL_KWARGS_DEFAULT, METHOD_DEFAULT, METHOD_KWARGS_DEFAULT, N_JOBS_DEFAULT
        BACKEND_DEFAULT = self._backend
        POOL_KWARGS_DEFAULT = self._pool_kwargs
        METHOD_DEFAULT = self._method
        METHOD_KWARGS_DEFAULT = self._method_kwargs
        N_JOBS_DEFAULT = self._n_jobs
class ParallelMixin:
    """Mixin class to handle parallel processing."""

    # Default number of child jobs; only raised when ALLOW_CHILD_JOBS is set.
    _n_child_jobs = 1

    @property
    def n_jobs(self):
        """Number of jobs as an integer."""
        # TODO: this is somewhat unusual behaviour. It deviates from a normal default value handling
        if self._n_jobs is None:
            return N_JOBS_DEFAULT
        return self._n_jobs

    @n_jobs.setter
    def n_jobs(self, value):
        """Number of jobs setter as an integer."""
        if not isinstance(value, (int, type(None))):
            raise ValueError(f"Invalid type: {value!r}, an integer or None is expected.")
        self._n_jobs = value
        if ALLOW_CHILD_JOBS:
            self._n_child_jobs = value

    def _update_child_jobs(self):
        """Update child jobs.

        Needed because we can update only in the main process,
        otherwise the global ALLOW_CHILD_JOBS has its default value.
        """
        if ALLOW_CHILD_JOBS:
            self._n_child_jobs = self.n_jobs
        else:
            self._n_child_jobs = 1

    @property
    def _get_n_child_jobs(self):
        """Number of allowed child jobs as an integer."""
        return self._n_child_jobs

    @property
    def parallel_backend(self):
        """Parallel backend as a string."""
        if self._parallel_backend is None:
            return BACKEND_DEFAULT
        return self._parallel_backend

    @parallel_backend.setter
    def parallel_backend(self, value):
        """Parallel backend setter (str)"""
        if value is None:
            self._parallel_backend = None
        else:
            self._parallel_backend = ParallelBackendEnum.from_str(value).value
def run_multiprocessing(
    func,
    inputs,
    backend=None,
    pool_kwargs=None,
    method=None,
    method_kwargs=None,
    task_name="",
):
    """Run function in a loop or in Parallel.

    Notes
    -----
    The progress bar can be displayed for this function.

    Parameters
    ----------
    func : function
        Function to run.
    inputs : list
        List of arguments to pass to the function.
    backend : {'multiprocessing', 'ray'}, optional
        Backend to use. Default is None.
    pool_kwargs : dict, optional
        Keyword arguments passed to the pool. The number of processes is limited
        to the number of physical CPUs. Default is None.
    method : {'starmap', 'apply_async'}
        Pool method to use. Default is "starmap".
    method_kwargs : dict, optional
        Keyword arguments passed to the method. Default is None.
    task_name : str, optional
        Name of the task to display in the progress bar. Default is "".
    """
    if backend is None:
        backend = BACKEND_DEFAULT

    if method is None:
        method = METHOD_DEFAULT

    if method_kwargs is None:
        method_kwargs = METHOD_KWARGS_DEFAULT

    if pool_kwargs is None:
        pool_kwargs = POOL_KWARGS_DEFAULT

    processes = pool_kwargs.get("processes", N_JOBS_DEFAULT)

    backend = ParallelBackendEnum.from_str(backend)
    multiprocessing = PARALLEL_BACKEND_MODULES[backend]()

    if backend == ParallelBackendEnum.multiprocessing:
        cpu_count = multiprocessing.cpu_count()
        if processes > cpu_count:
            log.info(f"Limiting number of processes from {processes} to {cpu_count}")
            processes = cpu_count
            # NOTE(review): the capped value is not written back into
            # pool_kwargs, so the pool below may still be created with the
            # original "processes" value — confirm whether this is intended.

        if multiprocessing.current_process().name != "MainProcess":
            # with multiprocessing subprocesses cannot have childs (but possible with ray)
            processes = 1

    if processes == 1:
        # Single process: avoid pool overhead and run sequentially.
        return run_loop(
            func=func, inputs=inputs, method_kwargs=method_kwargs, task_name=task_name
        )

    if backend == ParallelBackendEnum.ray:
        # Reuse an already-initialized ray cluster when available.
        address = "auto" if is_ray_initialized() else None
        pool_kwargs.setdefault("ray_address", address)

    log.info(f"Using {processes} processes to compute {task_name}")

    with multiprocessing.Pool(**pool_kwargs) as pool:
        pool_func = POOL_METHODS[PoolMethodEnum(method)]
        results = pool_func(
            pool=pool,
            func=func,
            inputs=inputs,
            method_kwargs=method_kwargs,
            task_name=task_name,
        )

    return results
def run_loop(func, inputs, method_kwargs=None, task_name=""):
    """Loop over inputs and run function sequentially.

    Supports an optional "callback" entry in ``method_kwargs`` which is
    applied to each result before it is collected.
    """
    results = []

    # Guard against method_kwargs=None (the declared default) — calling
    # .get on None would raise AttributeError.
    callback = method_kwargs.get("callback") if method_kwargs else None

    for arguments in progress_bar(inputs, desc=task_name):
        result = func(*arguments)

        if callback is not None:
            result = callback(result)

        results.append(result)

    return results


def run_pool_star_map(pool, func, inputs, method_kwargs=None, task_name=""):
    """Run function in parallel."""
    return pool.starmap(func, progress_bar(inputs, desc=task_name), **method_kwargs)


def run_pool_async(pool, func, inputs, method_kwargs=None, task_name=""):
    """Run function in parallel async."""
    results = []

    for arguments in progress_bar(inputs, desc=task_name):
        result = pool.apply_async(func, arguments, **method_kwargs)
        results.append(result)

    # wait until all async runs are done
    for result in results:
        result.wait()

    return results


POOL_METHODS = {
    PoolMethodEnum.starmap: run_pool_star_map,
    PoolMethodEnum.apply_async: run_pool_async,
}

PARALLEL_BACKEND_MODULES = {
    ParallelBackendEnum.multiprocessing: get_multiprocessing,
    ParallelBackendEnum.ray: get_multiprocessing_ray,
}