# Licensed under a 3-clause BSD style license - see LICENSE.rst
import logging
from astropy.coordinates import Angle
from astropy.nddata import NoOverlapError
import gammapy.utils.parallel as parallel
from gammapy.datasets import Datasets, MapDataset, MapDatasetOnOff, SpectrumDataset
from .core import Maker
from .safe import SafeMaskMaker
# Module-level logger used by the maker methods below.
log = logging.getLogger(__name__)
# Explicit public API of this module.
__all__ = [
"DatasetsMaker",
]
[docs]
class DatasetsMaker(Maker, parallel.ParallelMixin):
    """Run makers in a chain.

    Parameters
    ----------
    makers : list of `Maker` objects
        Makers.
    stack_datasets : bool, optional
        If True, stack into the reference dataset (see `run` method arguments).
        Default is True.
    n_jobs : int, optional
        Number of processes to run in parallel.
        Default is one, unless `~gammapy.utils.parallel.N_JOBS_DEFAULT` was modified.
    cutout_mode : {'trim', 'partial', 'strict'}
        Used only to cutout the reference `MapDataset` around each processed observation.
        Mode is an option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.
        Default is "trim".
    cutout_width : tuple of `~astropy.coordinates.Angle`, optional
        Angular sizes of the region in (lon, lat) in that specific order.
        If only one value is passed, a square region is extracted.
        If None, an error is raised, except if the list of makers includes a `SafeMaskMaker`
        with the offset-max method defined. In that case it is set to two times `offset_max`.
        Default is None.
    parallel_backend : {'multiprocessing', 'ray'}, optional
        Which backend to use for multiprocessing.
        Default is None.
    """

    # Tag used to identify this maker class.
    tag = "DatasetsMaker"
[docs]
def __init__(
    self,
    makers,
    stack_datasets=True,
    n_jobs=None,
    cutout_mode="trim",
    cutout_width=None,
    parallel_backend=None,
):
    """Initialize the maker chain and derive the cutout configuration."""
    self.log = logging.getLogger(__name__)
    self.makers = makers
    self.cutout_mode = cutout_mode
    self.cutout_width = None if cutout_width is None else Angle(cutout_width)

    # A per-observation cutout is only possible when a width is known:
    # either given explicitly, or derived from a SafeMaskMaker offset-max.
    self._apply_cutout = True
    if self.cutout_width is None:
        offset_max = self.offset_max
        if offset_max is None:
            self._apply_cutout = False
        else:
            self.cutout_width = 2 * offset_max

    self.n_jobs = n_jobs
    self.parallel_backend = parallel_backend
    self.stack_datasets = stack_datasets

    # Per-run state: collected datasets and sub-process failure flag.
    self._datasets = []
    self._error = False
@property
def offset_max(self):
    """Maximum offset of the `SafeMaskMaker` in the chain, or None."""
    safe_mask_maker = self.safe_mask_maker
    if safe_mask_maker is None:
        return None
    # Mirrors the hasattr check: missing attribute yields None.
    return getattr(safe_mask_maker, "offset_max", None)
@property
def safe_mask_maker(self):
    """First `SafeMaskMaker` found among the makers, or None."""
    return next(
        (maker for maker in self.makers if isinstance(maker, SafeMaskMaker)),
        None,
    )
[docs]
def make_dataset(self, dataset, observation):
    """Make single dataset.

    Parameters
    ----------
    dataset : `~gammapy.datasets.MapDataset`
        Reference dataset.
    observation : `Observation`
        Observation.

    Returns
    -------
    dataset : `~gammapy.datasets.MapDataset` or None
        Reduced dataset, or None if the observation falls outside
        the reference geometry.
    """
    pointing = observation.get_pointing_icrs(observation.tmid)

    if self._apply_cutout:
        try:
            dataset_obs = dataset.cutout(
                position=pointing.galactic,
                width=self.cutout_width,
                mode=self.cutout_mode,
            )
        except NoOverlapError:
            # Observation does not intersect the reference geometry.
            log.warning(
                f"{observation.obs_id} discarded because"
                " it is outside the FOV of the reference geom."
            )
            return None
    else:
        dataset_obs = dataset.copy()

    if dataset.models is not None:
        # Models must point at the new (cutout/copy) dataset name.
        models = dataset.models.copy()
        models.reassign(dataset.name, dataset_obs.name)
        dataset_obs.models = models

    log.info(f"Computing dataset for observation {observation.obs_id}")
    for maker in self.makers:
        log.info(f"Running {maker.tag}")
        dataset_obs = maker.run(dataset=dataset_obs, observation=observation)
    return dataset_obs
[docs]
def callback(self, dataset):
    """Collect one finished dataset: stack it or store it for later."""
    if not (self.stack_datasets and dataset is not None):
        self._datasets.append(dataset)
        return
    # Exact type checks are intentional: an ON-OFF dataset must be
    # converted before stacking onto a plain MapDataset.
    if type(self._dataset) is MapDataset and type(dataset) is MapDatasetOnOff:
        dataset = dataset.to_map_dataset(name=dataset.name)
    self._dataset.stack(dataset)
[docs]
def error_callback(self, dataset):
    # Only flag the failure here; a parallel run may die with a
    # non-explicit memory error, so `run` raises a clear error later.
    self._error = True
[docs]
def run(self, dataset, observations, datasets=None):
    """Run data reduction.

    Parameters
    ----------
    dataset : `~gammapy.datasets.MapDataset`
        Reference dataset (used only for stacking if datasets are provided).
    observations : `Observations`
        Observations.
    datasets : `~gammapy.datasets.Datasets`, optional
        Base datasets, if provided its length must be the same as the observations.
        Default is None.

    Returns
    -------
    datasets : `~gammapy.datasets.Datasets`
        Datasets.

    Raises
    ------
    TypeError
        If the reference dataset is not a `MapDataset`.
    RuntimeError
        If a sub-process failed during the parallel reduction.
    """
    if isinstance(dataset, MapDataset):
        # also valid for Spectrum as it inherits from MapDataset
        self._dataset = dataset
    else:
        raise TypeError("Invalid reference dataset.")

    # Reset per-run state so a second `run` call does not accumulate
    # datasets from, or raise on a stale error flag of, a previous run.
    self._datasets = []
    self._error = False

    if isinstance(dataset, SpectrumDataset):
        self._apply_cutout = False

    if datasets is not None:
        # Base datasets are already per-observation; no cutout needed.
        self._apply_cutout = False
    else:
        datasets = len(observations) * [dataset]

    # Never spawn more workers than there are observations.
    n_jobs = min(self.n_jobs, len(observations))

    parallel.run_multiprocessing(
        self.make_dataset,
        zip(datasets, observations),
        backend=self.parallel_backend,
        pool_kwargs=dict(processes=n_jobs),
        method="apply_async",
        method_kwargs=dict(
            callback=self.callback,
            error_callback=self.error_callback,
        ),
        task_name="Data reduction",
    )

    if self._error:
        raise RuntimeError("Execution of a sub-process failed")

    if self.stack_datasets:
        return Datasets([self._dataset])

    # `apply_async` callbacks fire in completion order, so restore the
    # input observation order via an OBS_ID lookup.
    valid_datasets = [d for d in self._datasets if d is not None]
    lookup = {
        d.meta_table["OBS_ID"][0]: idx for idx, d in enumerate(valid_datasets)
    }
    return Datasets(
        [
            valid_datasets[lookup[obs.obs_id]]
            for obs in observations
            if obs.obs_id in lookup
        ]
    )