# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Utility functions and classes for n-dimensional data and axes."""
from collections import OrderedDict
import numpy as np
from astropy.units import Quantity
from .array import array_stats_str
from .interpolation import ScaledRegularGridInterpolator
__all__ = ["NDDataArray", "DataAxis", "BinnedDataAxis", "sqrt_space"]
[docs]class NDDataArray:
"""ND Data Array Base class
for usage examples see :gp-notebook:`nddata_demo`
Parameters
----------
axes : list
List of `~gammapy.utils.nddata.DataAxis`
data : `~astropy.units.Quantity`
Data
meta : dict
Meta info
interp_kwargs : dict
TODO
"""
default_interp_kwargs = dict(bounds_error=False, values_scale="lin")
"""Default interpolation kwargs used to initialize the
`scipy.interpolate.RegularGridInterpolator`. The interpolation behaviour
of an individual axis ('log', 'linear') can be passed to the axis on
initialization."""
def __init__(self, axes, data=None, meta=None, interp_kwargs=None):
self._axes = axes
if data is not None:
self.data = data
if meta is not None:
self.meta = OrderedDict(meta)
self.interp_kwargs = interp_kwargs or self.default_interp_kwargs
self._regular_grid_interp = None
def __str__(self):
ss = "NDDataArray summary info\n"
for axis in self.axes:
ss += array_stats_str(axis.nodes, axis.name)
ss += array_stats_str(self.data, "Data")
return ss
@property
def axes(self):
"""Array holding the axes in correct order"""
return self._axes
[docs] def axis(self, name):
"""Return axis by name"""
try:
idx = [_.name for _ in self.axes].index(name)
except ValueError:
raise ValueError("Axis {} not found".format(name))
return self.axes[idx]
@property
def data(self):
"""Array holding the n-dimensional data."""
return self._data
@data.setter
def data(self, data):
"""Set data.
Some sanity checks are performed to avoid an invalid array.
Also, the interpolator is set to None to avoid unwanted behaviour.
Parameters
----------
data : `~astropy.units.Quantity`, array-like
Data array
"""
data = Quantity(data)
dimension = len(data.shape)
if dimension != self.dim:
raise ValueError(
"Overall dimensions to not match. "
"Data: {}, Hist: {}".format(dimension, self.dim)
)
for dim in np.arange(self.dim):
axis = self.axes[dim]
if axis.nbins != data.shape[dim]:
msg = "Data shape does not match in dimension {d}\n"
msg += "Axis {n} : {sa}, Data {sd}"
raise ValueError(
msg.format(d=dim, n=axis.name, sa=axis.nbins, sd=data.shape[dim])
)
self._regular_grid_interp = None
self._data = data
@property
def dim(self):
"""Dimension (number of axes)"""
return len(self.axes)
[docs] def find_node(self, **kwargs):
"""Find next node
Parameters
----------
kwargs : dict
Keys are the axis names, Values the evaluation points
"""
node = []
for axis in self.axes:
lookup_val = Quantity(kwargs.pop(axis.name))
temp = axis.find_node(lookup_val)
node.append(temp)
return node
[docs] def evaluate(self, method=None, **kwargs):
"""Evaluate NDData Array
This function provides a uniform interface to several interpolators.
The evaluation nodes are given as ``kwargs``.
Currently available:
`~scipy.interpolate.RegularGridInterpolator`, methods: linear, nearest
Parameters
----------
method : str {'linear', 'nearest'}, optional
Interpolation method
kwargs : dict
Keys are the axis names, Values the evaluation points
Returns
-------
array : `~astropy.units.Quantity`
Interpolated values, axis order is the same as for the NDData array
"""
values = []
for idx, axis in enumerate(self.axes):
# Extract values for each axis, default: nodes
shape = [1] * len(self.axes)
shape[idx] = -1
default = axis.nodes.reshape(tuple(shape))
temp = Quantity(kwargs.pop(axis.name, default))
values.append(np.atleast_1d(temp))
# This is to catch e.g. typos in axis names
if kwargs != {}:
raise ValueError("Input given for unknown axis: {}".format(kwargs))
if self._regular_grid_interp is None:
self._add_regular_grid_interp()
return self._regular_grid_interp(values, method=method, **kwargs)
def _add_regular_grid_interp(self, interp_kwargs=None):
"""Add `~scipy.interpolate.RegularGridInterpolator`
http://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.interpolate.RegularGridInterpolator.html
Parameters
----------
interp_kwargs : dict, optional
Interpolation kwargs
"""
if interp_kwargs is None:
interp_kwargs = self.interp_kwargs
points = [a.nodes for a in self.axes]
points_scale = [a.interpolation_mode for a in self.axes]
self._regular_grid_interp = ScaledRegularGridInterpolator(
points, self.data, points_scale=points_scale, **interp_kwargs
)
[docs]class DataAxis:
"""Data axis to be used with NDDataArray
Axis values are interpreted as nodes.
For binned data see `~gammapy.utils.nddata.BinnedDataAxis`.
Parameters
----------
nodes : `~astropy.units.Quantity`
Interpolation nodes
name : str, optional
Axis name, default: 'Default'
interpolation_mode : str {'linear', 'log'}
Interpolation behaviour, default: 'linear'
"""
def __init__(self, nodes, name="Default", interpolation_mode="linear"):
# Need this for subclassing (see BinnedDataAxis)
if nodes is not None:
self._data = Quantity(nodes)
if (self._data < 0).any() and interpolation_mode == "log":
raise ValueError(
"Interpolation scaling 'log' only support for positive node values."
)
self.name = name
self._interpolation_mode = interpolation_mode
def __str__(self):
ss = self.__class__.__name__
ss += "\nName: {}".format(self.name)
ss += "\nUnit: {}".format(self.unit)
ss += "\nNodes: {}".format(self.nbins)
ss += "\nInterpolation mode: {}".format(self.interpolation_mode)
return ss
@property
def unit(self):
"""Axis unit"""
return self.nodes.unit
[docs] @classmethod
def logspace(cls, vmin, vmax, nbins, unit=None, **kwargs):
"""Create axis with equally log-spaced nodes
if no unit is given, it will be taken from vmax,
log interpolation is enable by default.
Parameters
----------
vmin : `~astropy.units.Quantity`, float
Lowest value
vmax : `~astropy.units.Quantity`, float
Highest value
bins : int
Number of bins
unit : `~astropy.units.UnitBase`, str
Unit
"""
kwargs.setdefault("interpolation_mode", "log")
if unit is not None:
vmin = Quantity(vmin, unit)
vmax = Quantity(vmax, unit)
else:
vmin = Quantity(vmin)
vmax = Quantity(vmax)
unit = vmax.unit
vmin = vmin.to(unit)
x_min, x_max = np.log10([vmin.value, vmax.value])
vals = np.logspace(x_min, x_max, nbins)
return cls(vals * unit, **kwargs)
[docs] def find_node(self, val):
"""Find next node
Parameters
----------
val : `~astropy.units.Quantity`
Lookup value
"""
val = Quantity(val)
if not val.unit.is_equivalent(self.unit):
raise ValueError(
"Units mismatch: val.unit = {!r}, self.unit = {!r}".format(
val.unit, self.unit
)
)
val = val.to(self.nodes.unit)
val = np.atleast_1d(val)
x1 = np.array([val] * self.nbins).transpose()
x2 = np.array([self.nodes] * len(val))
temp = np.abs(x1 - x2)
idx = np.argmin(temp, axis=1)
return idx
@property
def nbins(self):
"""Number of bins"""
return len(self.nodes)
@property
def nodes(self):
"""Evaluation nodes"""
return self._data
@property
def interpolation_mode(self):
"""Interpolation mode
"""
return self._interpolation_mode
[docs]class BinnedDataAxis(DataAxis):
"""Data axis for binned data
Parameters
----------
lo : `~astropy.units.Quantity`
Lower bin edges
hi : `~astropy.units.Quantity`
Upper bin edges
name : str, optional
Axis name, default: 'Default'
interpolation_mode : str {'linear', 'log'}
Interpolation behaviour, default: 'linear'
"""
def __init__(self, lo, hi, **kwargs):
self.lo = Quantity(lo)
self.hi = Quantity(hi)
if ((self.lo < 0).any() or (self.hi < 0).any()) and kwargs.get(
"interpolation_mode"
) == "log":
raise ValueError(
"Interpolation scaling 'log' only support for positive node values."
)
super().__init__(None, **kwargs)
[docs] @classmethod
def logspace(cls, emin, emax, nbins, unit=None, **kwargs):
# TODO: splitout log space into a helper function
vals = DataAxis.logspace(emin, emax, nbins + 1, unit)._data
return cls(vals[:-1], vals[1:], **kwargs)
def __str__(self):
ss = super().__str__()
ss += "\nLower bounds {}".format(self.lo)
ss += "\nUpper bounds {}".format(self.hi)
return ss
@property
def bins(self):
"""Bin edges"""
unit = self.lo.unit
val = np.append(self.lo.value, self.hi.to_value(unit)[-1])
return val * unit
@property
def bin_width(self):
"""Bin width"""
return self.hi - self.lo
@property
def nodes(self):
"""Evaluation nodes.
Depending on the interpolation mode, either log or lin center are
returned
"""
if self.interpolation_mode == "log":
return self.log_center()
else:
return self.lin_center()
[docs] def lin_center(self):
"""Linear bin centers"""
return (self.lo + self.hi) / 2
[docs] def log_center(self):
"""Logarithmic bin centers"""
return np.sqrt(self.lo * self.hi)
[docs]def sqrt_space(start, stop, num):
"""Return numbers spaced evenly on a square root scale.
This function is similar to `numpy.linspace` and `numpy.logspace`.
Parameters
----------
start : float
start is the starting value of the sequence
stop : float
stop is the final value of the sequence
num : int
Number of samples to generate.
Returns
-------
samples : `~numpy.ndarray`
1D array with a square root scale
Examples
--------
>>> from gammapy.utils.nddata import sqrt_space
>>> samples = sqrt_space(0, 2, 5)
array([ 0. , 1. , 1.41421356, 1.73205081, 2. ])
"""
samples2 = np.linspace(start ** 2, stop ** 2, num)
samples = np.sqrt(samples2)
return samples