# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Session class driving the high-level interface API"""
import logging
from astropy.coordinates import SkyCoord
from astropy.table import Table
from regions import CircleSkyRegion
import yaml
from gammapy.analysis.config import AnalysisConfig
from gammapy.cube import MapDataset, MapDatasetMaker, SafeMaskMaker
from gammapy.data import DataStore
from gammapy.maps import Map, MapAxis, WcsGeom
from gammapy.modeling import Datasets, Fit
from gammapy.modeling.models import Models
from gammapy.modeling.serialize import dict_to_models
from gammapy.spectrum import (
FluxPointsDataset,
FluxPointsEstimator,
ReflectedRegionsBackgroundMaker,
SpectrumDataset,
SpectrumDatasetMaker,
)
from gammapy.utils.scripts import make_path
__all__ = ["Analysis"]
log = logging.getLogger(__name__)
class Analysis:
    """Config-driven high-level analysis interface.

    It is initialized by default with a set of configuration parameters and values declared in
    an internal high-level interface model, though the user can also provide configuration
    parameters passed as a nested dictionary at the moment of instantiation. In that case these
    parameters will overwrite the default values of those present in the configuration file.

    For more info see :ref:`analysis`.

    Parameters
    ----------
    config : dict or `AnalysisConfig`
        Configuration options following `AnalysisConfig` schema
    """

    def __init__(self, config):
        self.config = config
        self.config.set_logging()
        # Analysis products, filled progressively by the workflow methods
        # (get_observations -> get_datasets -> set_models -> run_fit -> get_flux_points).
        self.datastore = None
        self.observations = None
        self.datasets = None
        self.models = None
        self.fit = None
        self.fit_result = None
        self.flux_points = None
@property
def config(self):
"""Analysis configuration (`AnalysisConfig`)"""
return self._config
@config.setter
def config(self, value):
if isinstance(value, dict):
self._config = AnalysisConfig(**value)
elif isinstance(value, AnalysisConfig):
self._config = value
else:
raise TypeError("config must be dict or AnalysisConfig.")
[docs] def get_observations(self):
"""Fetch observations from the data store according to criteria defined in the configuration."""
path = make_path(self.config.observations.datastore)
if path.is_file():
self.datastore = DataStore.from_file(path)
elif path.is_dir():
self.datastore = DataStore.from_dir(path)
else:
raise FileNotFoundError(f"Datastore not found: {path}")
log.info("Fetching observations.")
observations_settings = self.config.observations
if (
len(observations_settings.obs_ids)
and observations_settings.obs_file is not None
):
raise ValueError(
"Values for both parameters obs_ids and obs_file are not accepted."
)
elif (
not len(observations_settings.obs_ids)
and observations_settings.obs_file is None
):
obs_list = self.datastore.get_observations()
ids = [obs.obs_id for obs in obs_list]
elif len(observations_settings.obs_ids):
obs_list = self.datastore.get_observations(observations_settings.obs_ids)
ids = [obs.obs_id for obs in obs_list]
else:
path = make_path(self.config.observations.obs_file)
ids = list(Table.read(path, format="ascii", data_start=0).columns[0])
if observations_settings.obs_cone.lon is not None:
cone = dict(
type="sky_circle",
frame=observations_settings.obs_cone.frame,
lon=observations_settings.obs_cone.lon,
lat=observations_settings.obs_cone.lat,
radius=observations_settings.obs_cone.radius,
border="0 deg",
)
selected_cone = self.datastore.obs_table.select_observations(cone)
ids = list(set(ids) & set(selected_cone["OBS_ID"].tolist()))
self.observations = self.datastore.get_observations(ids, skip_missing=True)
if self.config.observations.obs_time.start is not None:
start = self.config.observations.obs_time.start
stop = self.config.observations.obs_time.stop
self.observations = self.observations.select_time([(start, stop)])
log.info(f"Number of selected observations: {len(self.observations)}")
for obs in self.observations:
log.debug(obs)
[docs] def get_datasets(self):
"""Produce reduced datasets."""
if not self.observations or len(self.observations) == 0:
raise RuntimeError("No observations have been selected.")
if self.config.datasets.type == "1d":
self._spectrum_extraction()
elif self.config.datasets.type == "3d":
self._map_making()
else:
ValueError(f"Invalid dataset type: {self.config.datasets.type}")
[docs] def set_models(self, models):
"""Set models on datasets.
Parameters
----------
models : `~gammapy.modeling.models.Models` or str
Models object or YAML models string
"""
if not self.datasets or len(self.datasets) == 0:
raise RuntimeError("Missing datasets")
log.info(f"Reading model.")
if isinstance(models, str):
# FIXME: Models should offer a method to create from YAML str
models = yaml.safe_load(models)
self.models = Models(dict_to_models(models))
elif isinstance(models, Models):
self.models = models
else:
raise TypeError(f"Invalid type: {models!r}")
for dataset in self.datasets:
dataset.models = self.models
log.info(self.models)
[docs] def read_models(self, path):
"""Read models from YAML file."""
path = make_path(path)
models = Models.read(path)
self.set_models(models)
[docs] def run_fit(self, optimize_opts=None):
"""Fitting reduced datasets to model."""
if not self.models:
raise RuntimeError("Missing models")
fit_settings = self.config.fit
for dataset in self.datasets:
if fit_settings.fit_range:
e_min = fit_settings.fit_range.min
e_max = fit_settings.fit_range.max
if isinstance(dataset, MapDataset):
dataset.mask_fit = dataset.counts.geom.energy_mask(e_min, e_max)
else:
dataset.mask_fit = dataset.counts.energy_mask(e_min, e_max)
log.info("Fitting datasets.")
self.fit = Fit(self.datasets)
self.fit_result = self.fit.run(optimize_opts=optimize_opts)
log.info(self.fit_result)
[docs] def get_flux_points(self):
"""Calculate flux points for a specific model component."""
if not self.fit:
raise RuntimeError("No results available from Fit.")
fp_settings = self.config.flux_points
log.info("Calculating flux points.")
e_edges = self._make_energy_axis(fp_settings.energy).edges
flux_point_estimator = FluxPointsEstimator(
e_edges=e_edges,
datasets=self.datasets,
source=fp_settings.source,
**fp_settings.params,
)
fp = flux_point_estimator.run()
fp.table["is_ul"] = fp.table["ts"] < 4
self.flux_points = FluxPointsDataset(
data=fp, models=self.models[fp_settings.source]
)
cols = ["e_ref", "ref_flux", "dnde", "dnde_ul", "dnde_err", "is_ul"]
log.info("\n{}".format(self.flux_points.data.table[cols]))
[docs] def update_config(self, config):
self.config = self.config.update(config=config)
def _create_geometry(self):
"""Create the geometry."""
geom_params = {}
geom_settings = self.config.datasets.geom
skydir_settings = geom_settings.wcs.skydir
if skydir_settings.lon is not None:
skydir = SkyCoord(
skydir_settings.lon, skydir_settings.lat, frame=skydir_settings.frame
)
geom_params["skydir"] = skydir
if skydir_settings.frame == "icrs":
geom_params["frame"] = "icrs"
if skydir_settings.frame == "galactic":
geom_params["frame"] = "galactic"
axes = [self._make_energy_axis(geom_settings.axes.energy)]
geom_params["axes"] = axes
geom_params["binsz"] = geom_settings.wcs.binsize
width = geom_settings.wcs.fov.width.to("deg").value
height = geom_settings.wcs.fov.height.to("deg").value
geom_params["width"] = (width, height)
return WcsGeom.create(**geom_params)
def _map_making(self):
"""Make maps and datasets for 3d analysis."""
log.info("Creating geometry.")
geom = self._create_geometry()
geom_settings = self.config.datasets.geom
geom_irf = dict(energy_axis_true=None, binsz_irf=None)
if geom_settings.axes.energy_true.min is not None:
geom_irf["energy_axis_true"] = self._make_energy_axis(
geom_settings.axes.energy_true
)
geom_irf["binsz_irf"] = geom_settings.wcs.binsize_irf.to("deg").value
offset_max = geom_settings.selection.offset_max
log.info("Creating datasets.")
maker = MapDatasetMaker(selection=self.config.datasets.map_selection)
safe_mask_selection = self.config.datasets.safe_mask.methods
safe_mask_settings = self.config.datasets.safe_mask.settings
maker_safe_mask = SafeMaskMaker(
methods=safe_mask_selection, **safe_mask_settings
)
stacked = MapDataset.create(geom=geom, name="stacked", **geom_irf)
if self.config.datasets.stack:
for obs in self.observations:
log.info(f"Processing observation {obs.obs_id}")
cutout = stacked.cutout(obs.pointing_radec, width=2 * offset_max)
dataset = maker.run(cutout, obs)
dataset = maker_safe_mask.run(dataset, obs)
log.debug(dataset)
stacked.stack(dataset)
datasets = [stacked]
else:
datasets = []
for obs in self.observations:
log.info(f"Processing observation {obs.obs_id}")
cutout = stacked.cutout(obs.pointing_radec, width=2 * offset_max)
dataset = maker.run(cutout, obs)
dataset = maker_safe_mask.run(dataset, obs)
log.debug(dataset)
datasets.append(dataset)
self.datasets = Datasets(datasets)
def _spectrum_extraction(self):
"""Run all steps for the spectrum extraction."""
log.info("Reducing spectrum datasets.")
datasets_settings = self.config.datasets
on_lon = datasets_settings.on_region.lon
on_lat = datasets_settings.on_region.lat
on_center = SkyCoord(on_lon, on_lat, frame=datasets_settings.on_region.frame)
on_region = CircleSkyRegion(on_center, datasets_settings.on_region.radius)
maker_config = {}
if datasets_settings.containment_correction:
maker_config[
"containment_correction"
] = datasets_settings.containment_correction
e_reco = self._make_energy_axis(datasets_settings.geom.axes.energy).edges
maker_config["selection"] = ["counts", "aeff", "edisp"]
dataset_maker = SpectrumDatasetMaker(**maker_config)
bkg_maker_config = {}
if datasets_settings.background.exclusion:
exclusion_region = Map.read(datasets_settings.background.exclusion)
bkg_maker_config["exclusion_mask"] = exclusion_region
bkg_maker = ReflectedRegionsBackgroundMaker(**bkg_maker_config)
safe_mask_selection = self.config.datasets.safe_mask.methods
safe_mask_settings = self.config.datasets.safe_mask.settings
safe_mask_maker = SafeMaskMaker(
methods=safe_mask_selection, **safe_mask_settings
)
e_true = self._make_energy_axis(datasets_settings.geom.axes.energy_true).edges
reference = SpectrumDataset.create(
e_reco=e_reco, e_true=e_true, region=on_region
)
datasets = []
for obs in self.observations:
log.info(f"Processing observation {obs.obs_id}")
dataset = dataset_maker.run(reference.copy(), obs)
dataset = bkg_maker.run(dataset, obs)
if dataset.counts_off is None:
log.info(
f"No OFF region found for observation {obs.obs_id}. Discarding."
)
continue
dataset = safe_mask_maker.run(dataset, obs)
log.debug(dataset)
datasets.append(dataset)
self.datasets = Datasets(datasets)
if self.config.datasets.stack:
stacked = self.datasets.stack_reduce(name="stacked")
self.datasets = Datasets([stacked])
@staticmethod
def _make_energy_axis(axis):
return MapAxis.from_bounds(
name="energy",
lo_bnd=axis.min.value,
hi_bnd=axis.max.to_value(axis.min.unit),
nbin=axis.nbins,
unit=axis.min.unit,
interp="log",
node_type="edges",
)