Source code for gammapy.utils.nddata

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Utility functions and classes for n-dimensional data and axes."""
import numpy as np
from astropy.units import Quantity
from .array import array_stats_str
from .interpolation import ScaledRegularGridInterpolator

__all__ = ["NDDataArray"]

class NDDataArray:
    """ND Data Array Base class

    Container pairing an n-dimensional data array with one axis object per
    dimension, plus lazy construction of a regular-grid interpolator.

    Each axis object is expected to expose ``name``, ``nbin`` and ``interp``
    (all read by this class), plus the node coordinates used for
    interpolation.

    Parameters
    ----------
    axes : list
        List of `~gammapy.utils.nddata.DataAxis`
    data : `~astropy.units.Quantity`
        Data
    meta : dict
        Meta info
    interp_kwargs : dict
        Keyword arguments forwarded to the interpolator. If not given (or
        empty), a copy of `default_interp_kwargs` is used.
    """

    default_interp_kwargs = dict(bounds_error=False, values_scale="lin")
    """Default interpolation kwargs used to initialize the
    `scipy.interpolate.RegularGridInterpolator`. The interpolation behaviour
    of an individual axis ('log', 'linear') can be passed to the axis on
    initialization."""

    def __init__(self, axes, data=None, meta=None, interp_kwargs=None):
        self._axes = axes
        if data is not None:
            # Goes through the property setter so the shape is validated
            # against ``axes``.
            self.data = data
        self.meta = meta or {}
        # Copy the class-level defaults so that mutating one instance's
        # kwargs can never leak into other instances.
        self.interp_kwargs = interp_kwargs or dict(self.default_interp_kwargs)
        self._regular_grid_interp = None

    def __str__(self):
        """Return a summary: every axis followed by data statistics."""
        ss = "NDDataArray summary info\n"
        for axis in self.axes:
            ss += str(axis)
        ss += array_stats_str(self.data, "Data")
        return ss

    @property
    def axes(self):
        """Array holding the axes in correct order"""
        return self._axes

    def axis(self, name):
        """Return axis by name

        Parameters
        ----------
        name : str
            Name of the axis to look up

        Returns
        -------
        axis : object
            First axis whose ``name`` equals ``name``

        Raises
        ------
        ValueError
            If no axis with that name exists
        """
        for axis in self.axes:
            if axis.name == name:
                return axis
        raise ValueError(f"Axis {name} not found")

    @property
    def data(self):
        """Array holding the n-dimensional data."""
        return self._data

    @data.setter
    def data(self, data):
        """Set data.

        Some sanity checks are performed to avoid an invalid array.
        Also, the interpolator is set to None to avoid unwanted behaviour.

        Parameters
        ----------
        data : `~astropy.units.Quantity`, array-like
            Data array

        Raises
        ------
        ValueError
            If the number of dimensions, or the length along any dimension,
            does not match the axes
        """
        data = Quantity(data)
        dimension = len(data.shape)
        if dimension != self.dim:
            raise ValueError(
                "Overall dimensions to not match. "
                "Data: {}, Hist: {}".format(dimension, self.dim)
            )

        for dim, axis in enumerate(self.axes):
            if axis.nbin != data.shape[dim]:
                msg = "Data shape does not match in dimension {d}\n"
                msg += "Axis {n} : {sa}, Data {sd}"
                raise ValueError(
                    msg.format(d=dim, n=axis.name, sa=axis.nbin, sd=data.shape[dim])
                )

        # A cached interpolator would still point at the old data.
        self._regular_grid_interp = None
        self._data = data

    @property
    def dim(self):
        """Dimension (number of axes)"""
        return len(self.axes)

    def evaluate(self, method=None, **kwargs):
        """Evaluate NDData Array

        This function provides a uniform interface to several interpolators.
        The evaluation nodes are given as ``kwargs``.

        Currently available:
        `~scipy.interpolate.RegularGridInterpolator`, methods: linear, nearest

        Parameters
        ----------
        method : str {'linear', 'nearest'}, optional
            Interpolation method
        kwargs : dict
            Keys are the axis names, Values the evaluation points

        Returns
        -------
        array : `~astropy.units.Quantity`
            Interpolated values, axis order is the same as for the NDData array

        Raises
        ------
        ValueError
            If a keyword does not correspond to any axis name
        """
        values = []
        for idx, axis in enumerate(self.axes):
            # Extract values for each axis, default: nodes
            # The default is shaped (1, ..., -1, ..., 1) so that the per-axis
            # defaults broadcast against each other.
            shape = [1] * len(self.axes)
            shape[idx] = -1
            # NOTE(review): this expression was lost in extraction; restored
            # as the axis node coordinates (``axis.center``) - confirm against
            # the axis class.
            default = axis.center.reshape(tuple(shape))
            temp = Quantity(kwargs.pop(axis.name, default))
            values.append(np.atleast_1d(temp))

        # This is to catch e.g. typos in axis names
        if kwargs:
            raise ValueError(f"Input given for unknown axis: {kwargs}")

        if self._regular_grid_interp is None:
            self._add_regular_grid_interp()

        # ``kwargs`` is guaranteed empty here, so nothing else is forwarded.
        return self._regular_grid_interp(values, method=method)

    def _add_regular_grid_interp(self, interp_kwargs=None):
        """Add `~scipy.interpolate.RegularGridInterpolator`

        Parameters
        ----------
        interp_kwargs : dict, optional
            Interpolation kwargs; defaults to ``self.interp_kwargs``
        """
        if interp_kwargs is None:
            interp_kwargs = self.interp_kwargs

        # NOTE(review): the element expression was lost in extraction;
        # restored as ``a.center`` to match the defaults used in `evaluate`
        # - confirm against the axis class.
        points = [a.center for a in self.axes]
        points_scale = [a.interp for a in self.axes]
        self._regular_grid_interp = ScaledRegularGridInterpolator(
            points, self.data, points_scale=points_scale, **interp_kwargs
        )