Source code for gammapy.utils.nddata

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Utility functions and classes for n-dimensional data and axes.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import itertools
import numpy as np
from astropy.units import Quantity
from ..extern.bunch import Bunch
from .array import array_stats_str

__all__ = [
    'NDDataArray',
    'DataAxis',
    'BinnedDataAxis',
    'sqrt_space',
]


[docs]class NDDataArray(object): """ND Data Array Base class for usage examples see :gp-extra-notebook:`nddata_demo` Parameters ---------- axes : list List of `~gammapy.utils.nddata.DataAxis` data : `~astropy.units.Quantity` Data meta : dict Meta info interp_kwargs : dict TODO """ default_interp_kwargs = dict(bounds_error=False) """Default interpolation kwargs used to initialize the `scipy.interpolate.RegularGridInterpolator`. The interpolation behaviour of an individual axis ('log', 'linear') can be passed to the axis on initialization.""" def __init__(self, axes, data=None, meta=None, interp_kwargs=None): self._axes = axes if data is not None: self.data = data if meta is not None: self.meta = Bunch(meta) self.interp_kwargs = interp_kwargs or self.default_interp_kwargs self._regular_grid_interp = None @property def axes(self): """Array holding the axes in correct order""" return self._axes
[docs] def axis(self, name): """Return axis by name""" try: idx = [_.name for _ in self.axes].index(name) except ValueError: raise ValueError('Axis {} not found'.format(name)) return self.axes[idx]
@property def data(self): """Array holding the n-dimensional data.""" return self._data @data.setter def data(self, data): """Set data. Some sanity checks are performed to avoid an invalid array. Also, the interpolator is set to None to avoid unwanted behaviour. Parameters ---------- data : `~astropy.units.Quantity`, array-like Data array """ data = Quantity(data) dimension = len(data.shape) if dimension != self.dim: raise ValueError('Overall dimensions to not match. ' 'Data: {}, Hist: {}'.format(dimension, self.dim)) for dim in np.arange(self.dim): axis = self.axes[dim] if axis.nbins != data.shape[dim]: msg = 'Data shape does not match in dimension {d}\n' msg += 'Axis {n} : {sa}, Data {sd}' raise ValueError(msg.format(d=dim, n=axis.name, sa=axis.nbins, sd=data.shape[dim])) self._regular_grid_interp = None self._data = data @property def dim(self): """Dimension (number of axes)""" return len(self.axes) def __str__(self): """String representation""" ss = 'NDDataArray summary info\n' for axis in self.axes: ss += array_stats_str(axis.nodes, axis.name) ss += array_stats_str(self.data, 'Data') return ss
[docs] def find_node(self, **kwargs): """Find next node Parameters ---------- kwargs : dict Keys are the axis names, Values the evaluation points """ node = list() for axis in self.axes: lookup_val = Quantity(kwargs.pop(axis.name)) temp = axis.find_node(lookup_val) node.append(temp) return node
[docs] def evaluate(self, method=None, **kwargs): """Evaluate NDData Array This function provides a uniform interface to several interpolators. The evaluation nodes are given as ``kwargs``. Currently available: `~scipy.interpolate.RegularGridInterpolator`, methods: linear, nearest Parameters ---------- method : str {'linear', 'nearest'}, optional Interpolation method kwargs : dict Keys are the axis names, Values the evaluation points Returns ------- array : `~astropy.units.Quantity` Interpolated values, axis order is the same as for the NDData array """ values = list() for axis in self.axes: # Extract values for each axis, default: nodes temp = Quantity(kwargs.pop(axis.name, axis.nodes)) # Transform to correct unit temp = temp.to(axis.unit).value # Transform to match interpolation behaviour of axis values.append(np.atleast_1d(axis._interp_values(temp))) # This is to catch e.g. typos in axis names if kwargs != {}: raise ValueError("Input given for unknown axis: {}".format(kwargs)) if method is None: return self._eval_regular_grid_interp( values) * self.data.unit elif method == 'linear': return self._eval_regular_grid_interp( values, method='linear') * self.data.unit elif method == 'nearest': return self._eval_regular_grid_interp( values, method='nearest') * self.data.unit else: raise ValueError('Interpolator {} not available'.format(method))
def _eval_regular_grid_interp(self, values, **kwargs): """Evaluate linear interpolator Input: list of values to evaluate, in correct units and correct order. """ if self._regular_grid_interp is None: self._add_regular_grid_interp() # This is necessary since np.append does not support the 1D case if self.dim > 1: shapes = np.concatenate([np.shape(_) for _ in values]) else: shapes = values[0].shape # Flatten in order to support 2D array input values = [_.flatten() for _ in values] points = list(itertools.product(*values)) res = self._regular_grid_interp(points, **kwargs) res = np.reshape(res, shapes).squeeze() return res def _add_regular_grid_interp(self, interp_kwargs=None): """Add `~scipy.interpolate.RegularGridInterpolator` http://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.interpolate.RegularGridInterpolator.html Parameters ---------- interp_kwargs : dict, optional Interpolation kwargs """ from scipy.interpolate import RegularGridInterpolator if interp_kwargs is None: interp_kwargs = self.interp_kwargs points = [a._interp_nodes() for a in self.axes] values = self.data.value # If values contains nan, only setup interpolator in valid range if np.isnan(values).any(): if self.dim > 1: raise NotImplementedError('Data grid contains nan. This is not' 'supported for arrays dimension > 1') else: mask = np.isfinite(values) points = [points[0][mask]] values = values[mask] self._regular_grid_interp = RegularGridInterpolator(points, values, **interp_kwargs)
[docs]class DataAxis(object): """Data axis to be used with NDDataArray Axis values are interpreted as nodes. For binned data see `~gammapy.utils.nddata.BinnedDataAxis`. Parameters ---------- nodes : `~astropy.units.Quantity` Interpolation nodes name : str, optional Axis name, default: 'Default' interpolation_mode : str {'linear', 'log'} Interpolation behaviour, default: 'linear' """ def __init__(self, nodes, name='Default', interpolation_mode='linear'): # Need this for subclassing (see BinnedDataAxis) if nodes is not None: self._data = Quantity(nodes) self.name = name self._interpolation_mode = interpolation_mode def __str__(self): ss = self.__class__.__name__ ss += '\nName: {}'.format(self.name) ss += '\nUnit: {}'.format(self.unit) ss += '\nNodes: {}'.format(self.nbins) ss += '\nInterpolation mode: {}'.format(self.interpolation_mode) return ss @property def unit(self): """Axis unit""" return self.nodes.unit @classmethod
[docs] def logspace(cls, vmin, vmax, nbins, unit=None, **kwargs): """Create axis with equally log-spaced nodes if no unit is given, it will be taken from vmax, log interpolation is enable by default. Parameters ---------- vmin : `~astropy.units.Quantity`, float Lowest value vmax : `~astropy.units.Quantity`, float Highest value bins : int Number of bins unit : `~astropy.units.UnitBase`, str Unit """ kwargs.setdefault('interpolation_mode', 'log') if unit is not None: vmin = Quantity(vmin, unit) vmax = Quantity(vmax, unit) else: vmin = Quantity(vmin) vmax = Quantity(vmax) unit = vmax.unit vmin = vmin.to(unit) x_min, x_max = np.log10([vmin.value, vmax.value]) vals = np.logspace(x_min, x_max, nbins) return cls(vals * unit, **kwargs)
[docs] def find_node(self, val): """Find next node Parameters ---------- val : `~astropy.units.Quantity` Lookup value """ val = Quantity(val) if not val.unit.is_equivalent(self.unit): raise ValueError('Units {} and {} do not match'.format( val.unit, self.unit)) val = val.to(self.nodes.unit) val = np.atleast_1d(val) x1 = np.array([val] * self.nbins).transpose() x2 = np.array([self.nodes] * len(val)) temp = np.abs(x1 - x2) idx = np.argmin(temp, axis=1) return idx
@property def nbins(self): """Number of bins""" return len(self.nodes) @property def nodes(self): """Evaluation nodes""" return self._data @property def interpolation_mode(self): """Interpolation mode """ return self._interpolation_mode def _interp_nodes(self): """Nodes to be used for interpolation""" if self.interpolation_mode == 'log': return np.log10(self.nodes.value) else: return self.nodes.value def _interp_values(self, values): """Transform values correctly for interpolation""" if self.interpolation_mode == 'log': return np.log10(values) else: return values
[docs]class BinnedDataAxis(DataAxis): """Data axis for binned data Parameters ---------- lo : `~astropy.units.Quantity` Lower bin edges hi : `~astropy.units.Quantity` Upper bin edges name : str, optional Axis name, default: 'Default' interpolation_mode : str {'linear', 'log'} Interpolation behaviour, default: 'linear' """ def __init__(self, lo, hi, **kwargs): self.lo = Quantity(lo) self.hi = Quantity(hi) super(BinnedDataAxis, self).__init__(None, **kwargs) @classmethod
[docs] def logspace(cls, emin, emax, nbins, unit=None, **kwargs): # TODO: splitout log space into a helper function vals = DataAxis.logspace(emin, emax, nbins + 1, unit)._data return cls(vals[:-1], vals[1:], **kwargs)
def __str__(self): ss = super(BinnedDataAxis, self).__str__() ss += '\nLower bounds {}'.format(self.lo) ss += '\nUpper bounds {}'.format(self.hi) return ss @property def bins(self): """Bin edges""" unit = self.lo.unit val = np.append(self.lo.value, self.hi.value[-1]) return val * unit @property def bin_width(self): """Bin width""" return self.hi - self.lo @property def nodes(self): """Evaluation nodes. Depending on the interpolation mode, either log or lin center are returned """ if self.interpolation_mode == 'log': return self.log_center() else: return self.lin_center()
[docs] def lin_center(self): """Linear bin centers""" return (self.lo + self.hi) / 2
[docs] def log_center(self): """Logarithmic bin centers""" return np.sqrt(self.lo * self.hi)
[docs]def sqrt_space(start, stop, num): """Return numbers spaced evenly on a square root scale. This function is similar to `numpy.linspace` and `numpy.logspace`. Parameters ---------- start : float start is the starting value of the sequence stop : float stop is the final value of the sequence num : int Number of samples to generate. Returns ------- samples : `~numpy.ndarray` 1D array with a square root scale Examples -------- >>> from gammapy.utils.nddata import sqrt_space >>> samples = sqrt_space(0, 2, 5) array([ 0. , 1. , 1.41421356, 1.73205081, 2. ]) """ samples2 = np.linspace(start ** 2, stop ** 2, num) samples = np.sqrt(samples2) return samples