Source code for gammapy.spectrum.results

# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
from astropy.table import Table, Column
import astropy.units as u
from ..spectrum import CountsSpectrum, models
from ..utils.scripts import read_yaml, make_path
from ..utils.energy import EnergyBounds

__all__ = ["SpectrumFitResult", "SpectrumResult"]


[docs]class SpectrumFitResult(object): """Result of a `~gammapy.spectrum.SpectrumFit`. All fit results should be accessed via this class. Parameters ---------- model : `~gammapy.spectrum.models.SpectralModel` Best-fit model fit_range : `~astropy.units.Quantity` Energy range of the spectral fit statname : str, optional Statistic used for the fit statval : float, optional Final fit statistic stat_per_bin : float, optional Fit statistic value per bin npred : array-like, optional Counts predicted by the fit obs : `~gammapy.spectrum.SpectrumObservation` Input data used for the fit """ __slots__ = [ "model", "fit_range", "statname", "statval", "stat_per_bin", "npred", "obs", ] def __init__( self, model, fit_range=None, statname=None, statval=None, stat_per_bin=None, npred=None, obs=None, ): self.model = model self.fit_range = fit_range self.statname = statname self.statval = statval self.stat_per_bin = stat_per_bin self.npred = npred self.obs = obs
[docs] @classmethod def from_yaml(cls, filename): """Create from YAML file. Parameters ---------- filename : str, Path File to read """ filename = make_path(filename) val = read_yaml(str(filename)) return cls.from_dict(val)
[docs] def to_yaml(self, filename, mode="w"): """Write to YAML file. Parameters ---------- filename : str File to write mode : str Write mode """ import yaml d = self.to_dict() val = yaml.safe_dump(d, default_flow_style=False) with open(str(filename), mode) as outfile: outfile.write(val)
[docs] def to_dict(self): """Convert to dict.""" val = dict() val["model"] = self.model.to_dict() if self.fit_range is not None: val["fit_range"] = dict( min=self.fit_range[0].value, max=self.fit_range[1].value, unit=self.fit_range.unit.to_string("fits"), ) if self.statval is not None: val["statval"] = float(self.statval) if self.statname is not None: val["statname"] = self.statname return val
[docs] @classmethod def from_dict(cls, val): """Create from dict.""" modeldict = val["model"] model = models.SpectralModel.from_dict(modeldict) try: erange = val["fit_range"] energy_range = u.Quantity([erange["min"], erange["max"]], erange["unit"]) except KeyError: energy_range = None return cls(model=model, fit_range=energy_range)
# TODO: rather add this to Parameters?
[docs] def to_table(self, energy_unit="TeV", flux_unit="cm-2 s-1 TeV-1", **kwargs): """Convert to `~astropy.table.Table`. Produce overview table containing the most important parameters """ t = Table() t["model"] = [self.model.__class__.__name__] for par_name, value in self.model.parameters._ufloats.items(): val = value.n err = value.s # Apply correction factor for units # TODO: Refactor current_unit = u.Unit(self.model.parameters[par_name].unit) if current_unit.is_equivalent(energy_unit): factor = current_unit.to(energy_unit) col_unit = energy_unit elif current_unit.is_equivalent(1 / u.Unit(energy_unit)): factor = current_unit.to(1 / u.Unit(energy_unit)) col_unit = 1 / u.Unit(energy_unit) elif current_unit.is_equivalent(flux_unit): factor = current_unit.to(flux_unit) col_unit = flux_unit elif current_unit.is_equivalent(u.dimensionless_unscaled): factor = 1 col_unit = current_unit else: raise ValueError(current_unit) t[par_name] = Column( data=np.atleast_1d(val * factor), unit=col_unit, **kwargs ) t["{}_err".format(par_name)] = Column( data=np.atleast_1d(err * factor), unit=col_unit, **kwargs ) t["fit_range"] = Column( data=[self.fit_range.to(energy_unit)], unit=energy_unit, **kwargs ) return t
def __str__(self): s = "\nFit result info \n" s += "--------------- \n" s += "Model: {} \n".format(self.model) if self.statval is not None: s += "\nStatistic: {0:.3f} ({1})".format(self.statval, self.statname) if self.fit_range is not None: s += "\nFit Range: {}".format(self.fit_range) s += "\n" return s
[docs] def butterfly(self, energy=None, flux_unit="TeV-1 cm-2 s-1"): """Compute butterfly table. Parameters ---------- energy : `~astropy.units.Quantity`, optional Energies at which to evaluate the butterfly. flux_unit : str Flux unit for the butterfly. Returns ------- table : `~astropy.table.Table` Butterfly info in table (cols: 'energy', 'flux', 'flux_lo', 'flux_hi') """ if energy is None: energy = EnergyBounds.equal_log_spacing( self.fit_range[0], self.fit_range[1], 100 ) flux, flux_err = self.model.evaluate_error(energy) table = Table() table["energy"] = energy table["flux"] = flux.to(flux_unit) table["flux_lo"] = flux - flux_err.to(flux_unit) table["flux_hi"] = flux + flux_err.to(flux_unit) return table
@property def expected_source_counts(self): """Predicted source counts (`~gammapy.spectrum.CountsSpectrum`).""" energy = self.obs.on_vector.energy data = self.npred return CountsSpectrum(data=data, energy_lo=energy.lo, energy_hi=energy.hi) # TODO: is this the quantity, and sign, we want for residuals? @property def residuals(self): """Residuals (predicted source - excess). """ resspec = self.expected_source_counts.copy() resspec.data.data -= self.obs.excess_vector.data.data return resspec
[docs] def plot(self, **kwargs): """Plot counts and residuals in two panels. Calls ``plot_counts`` and ``plot_residuals``. """ ax0, ax1 = get_plot_axis(**kwargs) self.plot_counts(ax0) self.plot_residuals(ax1) return ax0, ax1
[docs] def plot_counts(self, ax): """Plot predicted and detected counts.""" self.expected_source_counts.plot(ax=ax, label="mu_src") self.obs.excess_vector.plot(ax=ax, label="excess", fmt=".", energy_unit="TeV") ax.axvline( self.fit_range.to_value("TeV")[0], color="black", linestyle="dashed", label="fit range", ) ax.axvline(self.fit_range.to_value("TeV")[1], color="black", linestyle="dashed") ax.legend(numpoints=1) ax.set_title("")
[docs] def plot_residuals(self, ax): """Plot residuals.""" self.residuals.plot(ax=ax, ecolor="black", fmt="none") ax.axhline(color="black") ymax = 1.4 * max(self.residuals.data.data.value) ax.set_ylim(-ymax, ymax) ax.set_xlabel("Energy [{}]".format("TeV")) ax.set_ylabel("ON (Predicted - Detected)")
[docs]class SpectrumResult(object): """Spectrum analysis results. Contains best fit model and flux points. Parameters ---------- model : `~gammapy.spectrum.models.SpectralModel` Best Fit model points : `~gammapy.spectrum.FluxPoints`, optional Flux points """ def __init__(self, model, points): self.model = model self.points = points @property def flux_point_residuals(self): """Residuals. Defined as ``(points - model) / model`` Returns ------- residuals : `numpy.ndarray` Residuals residuals_err : `numpy.ndarray` Residuals error """ e_ref = self.points.table["e_ref"].quantity points = self.points.table["dnde"].quantity try: points_err = self.points.table["dnde_err"].quantity except KeyError: points_errp = self.points.table["dnde_errp"].quantity points_errn = self.points.table["dnde_errp"].quantity points_err = np.sqrt(points_errp * points_errn) model_val = self.model(e_ref) residuals = ((points - model_val) / model_val).to_value("") residuals_err = (points_err / model_val).to_value("") # Remove residuals for upper_limits residuals[self.points.is_ul] = np.nan residuals_err[self.points.is_ul] = np.nan return residuals, residuals_err
[docs] def plot( self, energy_range, energy_unit="TeV", flux_unit="cm-2 s-1 TeV-1", energy_power=0, fit_kwargs=dict(), butterfly_kwargs=dict(), point_kwargs=dict(), fig_kwargs=dict(), ): """Plot spectrum. Plot best fit model, flux points and residuals. Parameters ---------- energy_range : `~astropy.units.Quantity` Energy range for the plot energy_unit : str, `~astropy.units.Unit`, optional Unit of the energy axis flux_unit : str, `~astropy.units.Unit`, optional Unit of the flux axis energy_power : int Power of energy to multiply flux axis with fit_kwargs : dict, optional forwarded to :func:`gammapy.spectrum.models.SpectralModel.plot` butterfly_kwargs : dict, optional forwarded to :func:`gammapy.spectrum.models.SpectralModel.plot_error` point_kwargs : dict, optional forwarded to :func:`gammapy.spectrum.FluxPoints.plot` fig_kwargs : dict, optional forwarded to :func:`matplotlib.pyplot.figure` Returns ------- ax0 : `~matplotlib.axes.Axes` Spectrum plot axis ax1 : `~matplotlib.axes.Axes` Residuals plot axis """ ax0, ax1 = get_plot_axis(**fig_kwargs) ax0.set_yscale("log") common_kwargs = dict( energy_unit=energy_unit, flux_unit=flux_unit, energy_power=energy_power ) fit_kwargs.update(common_kwargs) point_kwargs.update(common_kwargs) butterfly_kwargs.update(common_kwargs) self.model.plot(energy_range=energy_range, ax=ax0, **fit_kwargs) self.model.plot_error(energy_range=energy_range, ax=ax0, **butterfly_kwargs) self.points.plot(ax=ax0, **point_kwargs) point_kwargs.pop("flux_unit") point_kwargs.pop("energy_power") ax0.set_xlabel("") self._plot_residuals(ax=ax1, **point_kwargs) return ax0, ax1
def _plot_residuals(self, ax=None, energy_unit="TeV", **kwargs): """Plot residuals. Parameters ---------- ax : `~matplotlib.axes.Axes`, optional Axis energy_unit : str, `~astropy.units.Unit`, optional Unit of the energy axis Returns ------- ax : `~matplotlib.axes.Axes`, optional Axis """ import matplotlib.pyplot as plt ax = plt.gca() if ax is None else ax kwargs.setdefault("fmt", ".") y, y_err = self.flux_point_residuals x = self.points.e_ref x = x.to_value(energy_unit) ax.errorbar(x, y, yerr=y_err, **kwargs) ax.axhline(0, color="black") ax.set_xlabel("Energy [{}]".format(energy_unit)) ax.set_ylabel("Residuals") return ax
def get_plot_axis(**kwargs): """Axis setup used for standard plots. kwargs are forwarded to plt.figure() Returns ------- ax0 : `~matplotlib.axes.Axes` Main plot ax1 : `~matplotlib.axes.Axes` Residuals """ from matplotlib import gridspec import matplotlib.pyplot as plt fig = plt.figure(**kwargs) gs = gridspec.GridSpec(5, 1) ax0 = plt.subplot(gs[:-2, :]) ax1 = plt.subplot(gs[3, :], sharex=ax0) gs.update(hspace=0.1) plt.setp(ax0.get_xticklabels(), visible=False) ax0.set_xscale("log") ax1.set_xscale("log") return ax0, ax1