# Licensed under a 3-clause BSD style license - see LICENSE.rst
import abc
import collections.abc
import copy
import numpy as np
from gammapy.utils.scripts import make_name, make_path, read_yaml, write_yaml
from gammapy.utils.table import table_from_row_data
from ..maps import WcsNDMap
from .parameter import Parameters
__all__ = ["Dataset", "Datasets"]


class Dataset(abc.ABC):
"""Dataset abstract base class.
TODO: add tutorial how to create your own dataset types.
For now, see existing examples in Gammapy how this works:
- `gammapy.cube.MapDataset`
- `gammapy.spectrum.SpectrumDataset`
- `gammapy.spectrum.FluxPointsDataset`
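
    Examples
    --------
    A minimal, illustrative subclass (hypothetical; real dataset classes also
    define e.g. ``data_shape`` and hold the reduced data, IRFs and models)::

        class MyDataset(Dataset):
            def __init__(self, data, model, name="my-dataset"):
                self.data = data
                self.model = model
                self.mask_fit = None
                self.mask_safe = None
                self._name = name

            @property
            def name(self):
                return self._name

            def stat_array(self):
                # Per-bin contribution to the fit statistic,
                # here simply squared residuals
                return (self.data - self.model) ** 2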
"""
_residuals_labels = {
"diff": "data - model",
"diff/model": "(data - model) / model",
"diff/sqrt(model)": "(data - model) / sqrt(model)",
}

    @property
    def mask(self):
        """Combined fit and safe mask."""
mask_safe = (
self.mask_safe.data
if isinstance(self.mask_safe, WcsNDMap)
else self.mask_safe
)
mask_fit = (
self.mask_fit.data if isinstance(self.mask_fit, WcsNDMap) else self.mask_fit
)
if mask_safe is not None and mask_fit is not None:
mask = mask_safe & mask_fit
elif mask_fit is not None:
mask = mask_fit
elif mask_safe is not None:
mask = mask_safe
else:
mask = None
return mask

    def stat_sum(self):
"""Total statistic given the current model parameters."""
stat = self.stat_array()
if self.mask is not None:
stat = stat[self.mask]
return np.sum(stat, dtype=np.float64)

    @abc.abstractmethod
def stat_array(self):
"""Statistic array, one value per data point."""

    def copy(self, name=None):
"""A deep copy."""
new = copy.deepcopy(self)
new._name = make_name(name)
return new

    @staticmethod
def _compute_residuals(data, model, method="diff"):
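        """Compute residuals between data and model.

        Supported methods: "diff", "diff/model", "diff/sqrt(model)".
        """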
with np.errstate(invalid="ignore"):
if method == "diff":
residuals = data - model
elif method == "diff/model":
residuals = (data - model) / model
elif method == "diff/sqrt(model)":
residuals = (data - model) / np.sqrt(model)
else:
raise AttributeError(
f"Invalid method: {method!r}. Choose between 'diff',"
" 'diff/model' and 'diff/sqrt(model)'"
)
return residuals


class Datasets(collections.abc.MutableSequence):
"""Dataset collection.
Parameters
----------
datasets : `Dataset` or list of `Dataset`
Datasets
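
    Examples
    --------
    Illustrative only; ``dataset_1`` and ``dataset_2`` stand in for concrete
    datasets such as `~gammapy.spectrum.SpectrumDataset`::

        datasets = Datasets([dataset_1, dataset_2])
        print(datasets.names)
        stacked = datasets.stack_reduce(name="stacked")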
"""

    def __init__(self, datasets):
        if isinstance(datasets, Datasets):
            dataset_list = list(datasets)
        elif isinstance(datasets, Dataset):
            dataset_list = [datasets]
        elif isinstance(datasets, list):
            dataset_list = []
            for data in datasets:
                if isinstance(data, Datasets):
                    dataset_list += list(data)
                elif isinstance(data, Dataset):
                    dataset_list.append(data)
                else:
                    raise TypeError(f"Invalid type: {data!r}")
        else:
            raise TypeError(f"Invalid type: {datasets!r}")

        unique_names = []
        for dataset in dataset_list:
            if dataset.name in unique_names:
                raise ValueError("Dataset names must be unique")
            unique_names.append(dataset.name)

        self._datasets = dataset_list

    @property
    def parameters(self):
        """Unique parameters (`~gammapy.modeling.Parameters`).

        Duplicate parameter objects are removed; the order of the unique
        parameters is preserved.
        """
parameters = Parameters.from_stack(_.parameters for _ in self)
return parameters.unique_parameters

    @property
def names(self):
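        """List of dataset names."""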
return [d.name for d in self._datasets]

    @property
    def is_all_same_type(self):
        """Whether all contained datasets are of the same type."""
return len(set(_.__class__ for _ in self)) == 1

    @property
    def is_all_same_shape(self):
        """Whether all contained datasets have the same data shape."""
return len(set(_.data_shape for _ in self)) == 1

    def stat_sum(self):
        """Compute the joint statistic, summed over all datasets."""
stat_sum = 0
# TODO: add parallel evaluation of likelihoods
for dataset in self:
stat_sum += dataset.stat_sum()
return stat_sum

    def __str__(self):
str_ = self.__class__.__name__ + "\n"
str_ += "--------\n"
for idx, dataset in enumerate(self):
str_ += f"idx={idx}, id={hex(id(dataset))!r}, name={dataset.name!r}\n"
return str_

    def copy(self):
"""A deep copy."""
return copy.deepcopy(self)

    @classmethod
def read(cls, filedata, filemodel):
"""De-serialize datasets from YAML and FITS files.
Parameters
----------
filedata : str
filepath to yaml datasets file
filemodel : str
filepath to yaml models file
Returns
-------
dataset : 'gammapy.modeling.Datasets'
Datasets
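
        Examples
        --------
        An illustrative call; the file names are placeholders for files
        produced by `Datasets.write`::

            datasets = Datasets.read("datasets.yaml", "models.yaml")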
"""
from .serialize import dict_to_datasets
components = read_yaml(make_path(filemodel))
data_list = read_yaml(make_path(filedata))
datasets = dict_to_datasets(data_list, components)
return cls(datasets)

    def write(self, path, prefix="", overwrite=False):
"""Serialize datasets to YAML and FITS files.
Parameters
----------
path : `pathlib.Path`
path to write files
prefix : str
common prefix of file names
overwrite : bool
overwrite datasets FITS files
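
        Examples
        --------
        An illustrative call; the output path and prefix are placeholders::

            datasets.write("output/", prefix="analysis", overwrite=True)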
"""
from .serialize import datasets_to_dict
path = make_path(path)
datasets_dict, components_dict = datasets_to_dict(self, path, prefix, overwrite)
write_yaml(datasets_dict, path / f"{prefix}_datasets.yaml", sort_keys=False)
write_yaml(components_dict, path / f"{prefix}_models.yaml", sort_keys=False)

    def stack_reduce(self, name=None):
"""Reduce the Datasets to a unique Dataset by stacking them together.
This works only if all Dataset are of the same type and if a proper
in-place stack method exists for the Dataset type.
Returns
-------
dataset : ~gammapy.utils.Dataset
the stacked dataset
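
        Examples
        --------
        Illustrative only; requires that the contained datasets support
        in-place stacking::

            stacked = datasets.stack_reduce(name="stacked")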
"""
if not self.is_all_same_type:
            raise ValueError(
                "Stacking impossible: the datasets are not all of the same type."
            )
dataset = self[0].copy(name=name)
for ds in self[1:]:
dataset.stack(ds)
return dataset

    def info_table(self, cumulative=False):
"""Get info table for datasets.
Parameters
----------
cumulative : bool
Cumulate info across all observations
Returns
-------
info_table : `~astropy.table.Table`
Info table.
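
        Examples
        --------
        Illustrative only::

            info = datasets.info_table(cumulative=True)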
"""
if not self.is_all_same_type:
raise ValueError("Info table not supported for mixed dataset type.")
stacked = self[0].copy(name="stacked")
rows = [stacked.info_dict()]
for dataset in self[1:]:
if cumulative:
stacked.stack(dataset)
row = stacked.info_dict()
else:
row = dataset.info_dict()
rows.append(row)
return table_from_row_data(rows=rows)

    def __getitem__(self, key):
return self._datasets[self._get_idx(key)]

    def __delitem__(self, key):
del self._datasets[self._get_idx(key)]

    def __setitem__(self, key, dataset):
        if isinstance(dataset, Dataset):
            if dataset.name in self.names:
                raise ValueError("Dataset names must be unique")
self._datasets[self._get_idx(key)] = dataset
else:
raise TypeError(f"Invalid type: {type(dataset)!r}")

    def insert(self, idx, dataset):
        if isinstance(dataset, Dataset):
            if dataset.name in self.names:
                raise ValueError("Dataset names must be unique")
self._datasets.insert(idx, dataset)
else:
raise TypeError(f"Invalid type: {type(dataset)!r}")

    def _get_idx(self, key):
if isinstance(key, (int, slice)):
return key
elif isinstance(key, str):
for idx, dataset in enumerate(self._datasets):
if key == dataset.name:
return idx
raise IndexError(f"No dataset: {key!r}")
else:
raise TypeError(f"Invalid type: {type(key)!r}")

    def __len__(self):
return len(self._datasets)