# Source code for gammapy.analysis.config

# Licensed under a 3-clause BSD style license - see LICENSE.rst
import json
import logging
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import List
from astropy.coordinates import Angle
from astropy.time import Time
from astropy.units import Quantity
import yaml
from pydantic import BaseModel
from pydantic.utils import deep_update
from gammapy.makers import MapDatasetMaker
from gammapy.utils.scripts import make_path, read_yaml

# Only the top-level configuration model is part of the public API.
__all__ = ["AnalysisConfig"]

# Directory named "config" located next to this module.
CONFIG_PATH = Path(__file__).resolve().parent / "config"
# YAML file with the commented documentation; read by
# ``AnalysisConfig._get_doc_sections``.
DOCS_FILE = CONFIG_PATH / "docs.yaml"

# Module-level logger.
log = logging.getLogger(__name__)


class AngleType(Angle):
    """`~astropy.coordinates.Angle` subclass usable as a pydantic field type."""

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables pydantic should run for this type."""
        yield from (cls.validate,)

    @classmethod
    def validate(cls, v):
        """Convert ``v`` to an `~astropy.coordinates.Angle`.

        Note that a plain ``Angle`` is returned, not an ``AngleType``.
        """
        angle = Angle(v)
        return angle


class EnergyType(Quantity):
    """`~astropy.units.Quantity` subclass usable as a pydantic energy field type."""

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables pydantic should run for this type."""
        yield from (cls.validate,)

    @classmethod
    def validate(cls, v):
        """Convert ``v`` to a `~astropy.units.Quantity` and check it is an energy.

        Raises
        ------
        ValueError
            If the physical type of the unit is not ``"energy"``.
        """
        quantity = Quantity(v)
        if quantity.unit.physical_type == "energy":
            return quantity
        raise ValueError(f"Invalid unit for energy: {quantity.unit!r}")


class TimeType(Time):
    """`~astropy.time.Time` subclass usable as a pydantic field type."""

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables pydantic should run for this type."""
        yield from (cls.validate,)

    @classmethod
    def validate(cls, v):
        """Convert ``v`` to an `~astropy.time.Time`.

        Note that a plain ``Time`` is returned, not a ``TimeType``.
        """
        time = Time(v)
        return time


class ReductionTypeEnum(str, Enum):
    """Allowed values for the data reduction type: ``"1d"`` or ``"3d"``."""

    spectrum = "1d"
    cube = "3d"


class FrameEnum(str, Enum):
    """Allowed coordinate frame names: ``"icrs"`` or ``"galactic"``."""

    icrs = "icrs"
    galactic = "galactic"


class RequiredIRFEnum(str, Enum):
    """Allowed names for entries of ``ObservationsConfig.required_irf``."""

    aeff = "aeff"
    bkg = "bkg"
    edisp = "edisp"
    psf = "psf"


class BackgroundMethodEnum(str, Enum):
    """Allowed values for ``BackgroundConfig.method``."""

    reflected = "reflected"
    # Note: the member name (``fov``) differs from its string value.
    fov = "fov_background"
    ring = "ring"


class SafeMaskMethodsEnum(str, Enum):
    """Allowed entries of ``SafeMaskConfig.methods``.

    Member names use underscores while the string values use hyphens.
    """

    aeff_default = "aeff-default"
    aeff_max = "aeff-max"
    edisp_bias = "edisp-bias"
    offset_max = "offset-max"
    bkg_peak = "bkg-peak"


class MapSelectionEnum(str, Enum):
    """Allowed entries of ``DatasetsConfig.map_selection``."""

    counts = "counts"
    exposure = "exposure"
    background = "background"
    psf = "psf"
    edisp = "edisp"


class GammapyBaseConfig(BaseModel):
    """Base class of all configuration models defined in this module.

    The nested ``Config`` class holds the pydantic settings that the
    subclasses inherit.
    """

    class Config:
        # Run validators on field defaults as well, so that string defaults
        # such as "0.1 deg" are converted like user-supplied values.
        validate_all = True
        # Validate again whenever an attribute is assigned after creation.
        validate_assignment = True
        # Reject fields that are not declared on the model.
        extra = "forbid"
        # Encoders used when serialising to JSON: angles and quantities are
        # written as "<value> <unit>", times as their plain value.
        json_encoders = {
            Angle: lambda v: f"{v.value} {v.unit}",
            Quantity: lambda v: f"{v.value} {v.unit}",
            Time: lambda v: f"{v.value}",
        }


class SkyCoordConfig(GammapyBaseConfig):
    """Sky position given as a frame plus longitude and latitude angles.

    All fields default to ``None``.
    """

    frame: FrameEnum = None
    lon: AngleType = None
    lat: AngleType = None


class EnergyAxisConfig(GammapyBaseConfig):
    """Energy axis described by minimum, maximum and number of bins.

    All fields default to ``None``.
    """

    min: EnergyType = None
    max: EnergyType = None
    nbins: int = None


class SpatialCircleConfig(GammapyBaseConfig):
    """Circle on the sky: frame, centre longitude/latitude and radius.

    All fields default to ``None``.
    """

    frame: FrameEnum = None
    lon: AngleType = None
    lat: AngleType = None
    radius: AngleType = None


class EnergyRangeConfig(GammapyBaseConfig):
    """Energy interval given by ``min`` and ``max``; both default to ``None``."""

    min: EnergyType = None
    max: EnergyType = None


class TimeRangeConfig(GammapyBaseConfig):
    """Time interval given by ``start`` and ``stop``; both default to ``None``."""

    start: TimeType = None
    stop: TimeType = None


class FluxPointsConfig(GammapyBaseConfig):
    """Settings of the ``flux_points`` section of `AnalysisConfig`."""

    energy: EnergyAxisConfig = EnergyAxisConfig()
    # Name of the source; defaults to the literal string "source".
    source: str = "source"
    # Free-form options; presumably forwarded to the flux points estimator
    # -- confirm against the caller.
    parameters: dict = {"selection_optional": "all"}


class LightCurveConfig(GammapyBaseConfig):
    """Settings of the ``light_curve`` section of `AnalysisConfig`."""

    time_intervals: TimeRangeConfig = TimeRangeConfig()
    energy_edges: EnergyAxisConfig = EnergyAxisConfig()
    # Name of the source; defaults to the literal string "source".
    source: str = "source"
    # Free-form options; presumably forwarded to the light curve estimator
    # -- confirm against the caller.
    parameters: dict = {"selection_optional": "all"}


class FitConfig(GammapyBaseConfig):
    """Settings of the ``fit`` section: currently only the energy fit range."""

    fit_range: EnergyRangeConfig = EnergyRangeConfig()


class ExcessMapConfig(GammapyBaseConfig):
    """Settings of the ``excess_map`` section of `AnalysisConfig`."""

    # String default; converted to an Angle because ``validate_all`` is set
    # on the base config.
    correlation_radius: AngleType = "0.1 deg"
    # Free-form options; empty by default.
    parameters: dict = {}
    energy_edges: EnergyAxisConfig = EnergyAxisConfig()


class BackgroundConfig(GammapyBaseConfig):
    """Background settings: method, exclusion file path and extra parameters."""

    # One of the `BackgroundMethodEnum` values, or ``None`` (default).
    method: BackgroundMethodEnum = None
    # Path of an exclusion file; ``None`` by default.
    exclusion: Path = None
    # Free-form options; empty by default.
    parameters: dict = {}


class SafeMaskConfig(GammapyBaseConfig):
    """Safe mask settings: list of methods and extra parameters."""

    # Defaults to the single method "aeff-default".
    methods: List[SafeMaskMethodsEnum] = [SafeMaskMethodsEnum.aeff_default]
    # Free-form options; empty by default.
    parameters: dict = {}


class EnergyAxesConfig(GammapyBaseConfig):
    """Pair of energy axes (``energy`` and ``energy_true``) with defaults."""

    # Default: 1 TeV to 10 TeV in 5 bins.
    energy: EnergyAxisConfig = EnergyAxisConfig(min="1 TeV", max="10 TeV", nbins=5)
    # Default: 0.5 TeV to 20 TeV in 16 bins.
    energy_true: EnergyAxisConfig = EnergyAxisConfig(
        min="0.5 TeV", max="20 TeV", nbins=16
    )


class SelectionConfig(GammapyBaseConfig):
    """Selection settings: maximum offset angle (default ``"2.5 deg"``)."""

    offset_max: AngleType = "2.5 deg"


class WidthConfig(GammapyBaseConfig):
    """Angular width and height; both default to ``"5 deg"``."""

    width: AngleType = "5 deg"
    height: AngleType = "5 deg"


class WcsConfig(GammapyBaseConfig):
    """WCS settings: sky direction, bin sizes and width."""

    skydir: SkyCoordConfig = SkyCoordConfig()
    binsize: AngleType = "0.02 deg"
    width: WidthConfig = WidthConfig()
    # Separate bin size for IRFs; ten times larger than ``binsize`` by default.
    binsize_irf: AngleType = "0.2 deg"


class GeomConfig(GammapyBaseConfig):
    """Geometry settings grouping the WCS, selection and energy axes configs."""

    wcs: WcsConfig = WcsConfig()
    selection: SelectionConfig = SelectionConfig()
    axes: EnergyAxesConfig = EnergyAxesConfig()


class DatasetsConfig(GammapyBaseConfig):
    """Settings of the ``datasets`` section of `AnalysisConfig`."""

    # Reduction type; defaults to "1d" (spectrum).
    type: ReductionTypeEnum = ReductionTypeEnum.spectrum
    stack: bool = True
    geom: GeomConfig = GeomConfig()
    # Default taken from ``MapDatasetMaker.available_selection``.
    map_selection: List[MapSelectionEnum] = MapDatasetMaker.available_selection
    background: BackgroundConfig = BackgroundConfig()
    safe_mask: SafeMaskConfig = SafeMaskConfig()
    on_region: SpatialCircleConfig = SpatialCircleConfig()
    containment_correction: bool = True


class ObservationsConfig(GammapyBaseConfig):
    """Settings of the ``observations`` section of `AnalysisConfig`."""

    # The default path contains the unexpanded ``$GAMMAPY_DATA`` variable;
    # presumably expanded by the consumer -- confirm against the caller.
    datastore: Path = Path("$GAMMAPY_DATA/hess-dl3-dr1/")
    # Observation IDs; empty list by default.
    obs_ids: List[int] = []
    obs_file: Path = None
    obs_cone: SpatialCircleConfig = SpatialCircleConfig()
    obs_time: TimeRangeConfig = TimeRangeConfig()
    # Given as plain strings; validated against `RequiredIRFEnum`.
    required_irf: List[RequiredIRFEnum] = ["aeff", "edisp", "psf", "bkg"]


class LogConfig(GammapyBaseConfig):
    """Logging settings.

    The field names match keyword arguments of ``logging.basicConfig``;
    ``AnalysisConfig.set_logging`` passes this model's dict to that function.
    """

    # Upper-cased by ``AnalysisConfig.set_logging`` before use.
    level: str = "info"
    filename: Path = None
    filemode: str = None
    format: str = None
    datefmt: str = None


class GeneralConfig(GammapyBaseConfig):
    """Settings of the ``general`` section: logging and output directory."""

    log: LogConfig = LogConfig()
    # Output directory; defaults to the current directory.
    outdir: str = "."


class AnalysisConfig(GammapyBaseConfig):
    """Gammapy analysis configuration.

    Top-level model grouping the ``general``, ``observations``, ``datasets``,
    ``fit``, ``flux_points``, ``excess_map`` and ``light_curve`` sections.
    Every section has a default, so ``AnalysisConfig()`` is a valid config.
    """

    general: GeneralConfig = GeneralConfig()
    observations: ObservationsConfig = ObservationsConfig()
    datasets: DatasetsConfig = DatasetsConfig()
    fit: FitConfig = FitConfig()
    flux_points: FluxPointsConfig = FluxPointsConfig()
    excess_map: ExcessMapConfig = ExcessMapConfig()
    light_curve: LightCurveConfig = LightCurveConfig()

    def __str__(self):
        """Display settings in pretty YAML format."""
        info = self.__class__.__name__ + "\n\n\t"
        data = self.to_yaml()
        # Indent every YAML line by one tab, expanded to 4 spaces below.
        data = data.replace("\n", "\n\t")
        info += data
        return info.expandtabs(tabsize=4)

    @classmethod
    def read(cls, path):
        """Reads from YAML file.

        Parameters
        ----------
        path : str or `~pathlib.Path`
            Location of the YAML file, passed to ``read_yaml``.

        Returns
        -------
        config : `AnalysisConfig`
            Instance of ``cls``; an empty file yields the default config.
        """
        config = read_yaml(path)
        # An empty YAML document loads as None, which cannot be unpacked.
        if config is None:
            config = {}
        return cls(**config)

    @classmethod
    def from_yaml(cls, config_str):
        """Create from YAML string.

        Parameters
        ----------
        config_str : str
            Configuration settings in YAML syntax.

        Returns
        -------
        config : `AnalysisConfig`
            Instance of ``cls``; an empty string yields the default config.
        """
        settings = yaml.safe_load(config_str)
        # An empty YAML document loads as None, which cannot be unpacked.
        if settings is None:
            settings = {}
        return cls(**settings)

    def write(self, path, overwrite=False):
        """Write to YAML file.

        Parameters
        ----------
        path : str or `~pathlib.Path`
            Output location, converted with ``make_path``.
        overwrite : bool
            Whether an existing file may be replaced.

        Raises
        ------
        IOError
            If the file exists and ``overwrite`` is false.
        """
        path = make_path(path)
        if path.exists() and not overwrite:
            raise IOError(f"File exists already: {path}")
        path.write_text(self.to_yaml())

    def to_yaml(self):
        """Convert to YAML string."""
        # Here using `dict()` instead of `json()` would be more natural.
        # We should change this once pydantic adds support for custom encoders
        # to `dict()`. See https://github.com/samuelcolvin/pydantic/issues/1043
        config = json.loads(self.json())
        return yaml.dump(
            config, sort_keys=False, indent=4, width=80, default_flow_style=None
        )

    def set_logging(self):
        """Set logging config.

        Calls ``logging.basicConfig``, i.e. adjusts global logging state.
        Also upper-cases ``general.log.level`` in place.
        """
        self.general.log.level = self.general.log.level.upper()
        log_settings = self.general.log.dict()
        logging.basicConfig(**log_settings)
        log.info("Setting logging config: %r", log_settings)

    def update(self, config=None):
        """Update config with provided settings.

        Only non-default values of ``self`` and ``config`` are merged, with
        ``config`` taking precedence. ``self`` is left unchanged.

        Parameters
        ----------
        config : str or `AnalysisConfig`
            Configuration settings, either as a YAML string or as a config
            object. Any other type, including the default ``None``, is
            rejected.

        Returns
        -------
        config : `AnalysisConfig`
            New config object of the same class as ``self``.

        Raises
        ------
        TypeError
            If ``config`` is neither a string nor an `AnalysisConfig`.
        """
        if isinstance(config, str):
            other = type(self).from_yaml(config)
        elif isinstance(config, AnalysisConfig):
            other = config
        else:
            raise TypeError(f"Invalid type: {config}")

        config_new = deep_update(
            self.dict(exclude_defaults=True), other.dict(exclude_defaults=True)
        )
        return type(self)(**config_new)

    @staticmethod
    def _get_doc_sections():
        """Returns dict with commented docs from docs file.

        The file ``DOCS_FILE`` is split at lines starting with
        ``"# Section: "``. Each section is stored under the text following
        that prefix and includes its header line. Lines starting with
        ``"---"`` are dropped, as are lines before the first section header.
        """
        doc = defaultdict(str)
        keyword = None
        with open(DOCS_FILE) as f:
            for line in f:
                if line.startswith("---"):
                    continue
                line = line.strip("\n")
                if line.startswith("# Section: "):
                    keyword = line.replace("# Section: ", "")
                if keyword is None:
                    # No section started yet: nothing to attach the line to.
                    continue
                doc[keyword] += line + "\n"
        return doc