Source code for gammapy.maps.wcsnd

# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import numpy as np
from astropy.io import fits
from collections import OrderedDict
from .utils import unpack_seq
from .geom import pix_tuple_to_idx, axes_pix_to_coord
from .utils import interp_to_order
from .wcsmap import WcsGeom
from .wcsmap import WcsMap
from .reproject import reproject_car_to_hpx, reproject_car_to_wcs

__all__ = [
    'WcsNDMap',
]


[docs]class WcsNDMap(WcsMap): """Representation of a N+2D map using WCS with two spatial dimensions and N non-spatial dimensions. This class uses an ND numpy array to store map values. For maps with non-spatial dimensions and variable pixel size it will allocate an array with dimensions commensurate with the largest image plane. Parameters ---------- geom : `~gammapy.maps.WcsGeom` WCS geometry object. data : `~numpy.ndarray` Data array. If none then an empty array will be allocated. dtype : str, optional Data type, default is float32 meta : `~collections.OrderedDict` Dictionary to store meta data. """ def __init__(self, geom, data=None, dtype='float32', meta=None): # TODO: Figure out how to mask pixels for integer data types shape = tuple([np.max(geom.npix[0]), np.max(geom.npix[1])] + [ax.nbin for ax in geom.axes]) if data is None: data = self._init_data(geom, shape, dtype) elif data.shape != shape[::-1]: raise ValueError('Wrong shape for input data array. Expected {} ' 'but got {}'.format(shape, data.shape)) super(WcsNDMap, self).__init__(geom, data, meta) def _init_data(self, geom, shape, dtype): # Check whether corners of each image plane are valid coords = [] if not geom.is_regular: for idx in np.ndindex(geom.shape): pix = (np.array([0.0, float(geom.npix[0][idx] - 1)]), np.array([0.0, float(geom.npix[1][idx] - 1)])) pix += tuple([np.array(2 * [t]) for t in idx]) coords += geom.pix_to_coord(pix) else: pix = (np.array([0.0, float(geom.npix[0] - 1)]), np.array([0.0, float(geom.npix[1] - 1)])) pix += tuple([np.array(2 * [0.0]) for i in range(geom.ndim - 2)]) coords += geom.pix_to_coord(pix) if np.all(np.isfinite(np.vstack(coords))): if geom.is_regular: data = np.zeros(shape, dtype=dtype).T else: data = np.nan * np.ones(shape, dtype=dtype).T for idx in np.ndindex(geom.shape): data[idx, slice(geom.npix[0][idx]), slice(geom.npix[1][idx])] = 0.0 else: data = np.full(shape, np.nan, dtype=dtype).T idx = geom.get_idx() m = np.all(np.stack([t != -1 for t in idx]), axis=0) data[m] = 0.0 return data
[docs] @classmethod def from_hdu(cls, hdu, hdu_bands=None): """Make a WcsNDMap object from a FITS HDU. Parameters ---------- hdu : `~astropy.io.fits.BinTableHDU` or `~astropy.io.fits.ImageHDU` The map FITS HDU. hdu_bands : `~astropy.io.fits.BinTableHDU` The BANDS table HDU. """ geom = WcsGeom.from_header(hdu.header, hdu_bands) shape = tuple([ax.nbin for ax in geom.axes]) shape_wcs = tuple([np.max(geom.npix[0]), np.max(geom.npix[1])]) meta = cls._get_meta_from_header(hdu.header) map_out = cls(geom, meta=meta) # TODO: Should we support extracting slices? if isinstance(hdu, fits.BinTableHDU): pix = hdu.data.field('PIX') pix = np.unravel_index(pix, shape_wcs[::-1]) vals = hdu.data.field('VALUE') if 'CHANNEL' in hdu.data.columns.names and shape: chan = hdu.data.field('CHANNEL') chan = np.unravel_index(chan, shape[::-1]) idx = chan + pix else: idx = pix map_out.set_by_idx(idx[::-1], vals) else: map_out.data = hdu.data return map_out
[docs] def get_by_idx(self, idx): idx = pix_tuple_to_idx(idx) return self._data.T[idx]
[docs] def interp_by_coord(self, coords, interp=None): if self.geom.is_regular: pix = self.geom.coord_to_pix(coords) return self.interp_by_pix(pix, interp=interp) else: return self._interp_by_coord_griddata(coords, interp=interp)
[docs] def interp_by_pix(self, pix, interp=None): """Interpolate map values at the given pixel coordinates. """ if not self.geom.is_regular: raise ValueError('Pixel-based interpolation not supported for ' 'non-regular geometries.') order = interp_to_order(interp) if order == 0 or order == 1: return self._interp_by_pix_linear_grid(pix, order=order) elif order == 2 or order == 3: return self._interp_by_pix_map_coordinates(pix, order=order) else: raise ValueError('Invalid interpolation order: {}'.format(order))
def _interp_by_pix_linear_grid(self, pix, order=1): # TODO: Cache interpolator method_lookup = {0: 'nearest', 1: 'linear'} method = method_lookup.get(order, None) if method is None: raise ValueError('Invalid interpolation method: {}'.format(interp)) from scipy.interpolate import RegularGridInterpolator grid_pix = [np.arange(n, dtype=float) for n in self.data.shape[::-1]] if np.any(np.isfinite(self.data)): data = self.data.copy().T data[~np.isfinite(data)] = 0.0 else: data = self.data.T fn = RegularGridInterpolator(grid_pix, data, fill_value=None, bounds_error=False, method=method) return fn(tuple(pix)) def _interp_by_pix_map_coordinates(self, pix, order=1): from scipy.ndimage import map_coordinates pix = tuple([np.array(x, ndmin=1) if not isinstance(x, np.ndarray) or x.ndim == 0 else x for x in pix]) return map_coordinates(self.data.T, pix, order=order, mode='nearest') def _interp_by_coord_griddata(self, coords, interp=None): order = interp_to_order(interp) method_lookup = {0: 'nearest', 1: 'linear', 3: 'cubic'} method = method_lookup.get(order, None) if method is None: raise ValueError('Invalid interpolation method: {}'.format(interp)) from scipy.interpolate import griddata grid_coords = self.geom.get_coord(flat=True) data = self.data[np.isfinite(self.data)] vals = griddata(grid_coords, data, coords, method=method) m = ~np.isfinite(vals) if np.any(m): vals_fill = griddata(grid_coords, data, tuple([c[m] for c in coords]), method='nearest') vals[m] = vals_fill return vals
[docs] def interp_image(self, coords, order=1): if self.geom.ndim == 2: raise ValueError('Operation only supported for maps with one or more ' 'non-spatial dimensions.') elif self.geom.ndim == 3: return self._interp_image_cube(coords, order) else: raise NotImplementedError
def _interp_image_cube(self, coords, order=1): """Interpolate an image plane of a cube.""" # TODO: consider re-writing to support maps with > 3 dimensions from scipy.interpolate import interp1d axis = self.geom.axes[0] idx = axis.coord_to_idx_interp(coords[0]) map_slice = slice(int(idx[0]), int(idx[-1]) + 1) pix_vals = [float(t) for t in idx] pix = axis.coord_to_pix(coords[0]) data = self.data[map_slice] if coords[0] < axis.center[0] or coords[0] > axis.center[-1]: kind = 'linear' if order >= 1 else 'nearest' fill_value = 'extrapolate' else: kind = order fill_value = None # TODO: Cache interpolating function? fn = interp1d(pix_vals, data, copy=False, axis=0, kind=kind, fill_value=fill_value) data_interp = fn(float(pix)) geom = self.geom.to_image() return self.__class__(geom, data_interp)
[docs] def fill_by_idx(self, idx, weights=None): idx = pix_tuple_to_idx(idx) msk = np.all(np.stack([t != -1 for t in idx]), axis=0) idx = [t[msk] for t in idx] if weights is not None: weights = np.asarray(weights, dtype=self.data.dtype) weights = weights[msk] idx = np.ravel_multi_index(idx, self.data.T.shape) idx, idx_inv = np.unique(idx, return_inverse=True) weights = np.bincount(idx_inv, weights=weights).astype(self.data.dtype) self.data.T.flat[idx] += weights
[docs] def set_by_idx(self, idx, vals): idx = pix_tuple_to_idx(idx) self.data.T[idx] = vals
[docs] def iter_by_image(self): for idx in np.ndindex(self.geom.shape): yield self.data[idx[::-1]], idx
[docs] def iter_by_pix(self, buffersize=1): pix = list(self.geom.get_idx(flat=True)) vals = self.data[np.isfinite(self.data)] return unpack_seq(np.nditer([vals] + pix, flags=['external_loop', 'buffered'], buffersize=buffersize))
[docs] def iter_by_coord(self, buffersize=1): coords = list(self.geom.get_coord(flat=True)) vals = self.data[np.isfinite(self.data)] return unpack_seq(np.nditer([vals] + coords, flags=['external_loop', 'buffered'], buffersize=buffersize))
[docs] def sum_over_axes(self): if self.geom.ndim == 2: return copy.deepcopy(self) map_out = self.__class__(self.geom.to_image()) if not self.geom.is_regular: vals = self.get_by_idx(self.geom.get_idx()) map_out.fill_by_coord(self.geom.get_coord()[:2], vals) else: axis = tuple(np.arange(self.data.ndim - 2).tolist()) map_out.data = np.sum(self.data, axis=axis) return map_out
def _reproject_wcs(self, geom, mode='interp', order=1): from reproject import reproject_interp, reproject_exact map_out = WcsNDMap(geom) axes_eq = np.all([ax0 == ax1 for ax0, ax1 in zip(geom.axes, self.geom.axes)]) for vals, idx in map_out.iter_by_image(): if self.geom.ndim == 2 or axes_eq: img = self.data[idx[::-1]] else: coords = axes_pix_to_coord(geom.axes, idx) img = self.interp_image(coords, order=order).data # FIXME: This is a temporary solution for handling maps # with undefined pixels if np.any(~np.isfinite(img)): img = img.copy() img[~np.isfinite(img)] = 0.0 # TODO: Create WCS object for image plane if # multi-resolution geom shape_out = geom.get_image_shape(idx)[::-1] if self.geom.projection == 'CAR' and self.geom.is_allsky: data, footprint = reproject_car_to_wcs((img, self.geom.wcs), geom.wcs, shape_out=shape_out) elif mode == 'interp': data, footprint = reproject_interp((img, self.geom.wcs), geom.wcs, shape_out=shape_out) elif mode == 'exact': data, footprint = reproject_exact((img, self.geom.wcs), geom.wcs, shape_out=shape_out) else: raise TypeError( "Invalid reprojection mode, either choose 'interp' or 'exact'") vals[...] = data return map_out def _reproject_hpx(self, geom, mode='interp', order=1): from reproject import reproject_to_healpix from .hpxnd import HpxNDMap map_out = HpxNDMap(geom) coordsys = 'galactic' if geom.coordsys == 'GAL' else 'icrs' axes_eq = np.all([ax0 == ax1 for ax0, ax1 in zip(geom.axes, self.geom.axes)]) for vals, idx in map_out.iter_by_image(): if self.geom.ndim == 2 or axes_eq: img = self.data[idx[::-1]] else: coords = axes_pix_to_coord(geom.axes, idx) img = self.interp_image(coords, order=order).data # TODO: For partial-sky HPX we need to map from full- to # partial-sky indices if self.geom.projection == 'CAR' and self.geom.is_allsky: data, footprint = reproject_car_to_hpx((img, self.geom.wcs), coordsys, nside=geom.nside, nested=geom.nest, order=order) else: data, footprint = reproject_to_healpix((img, self.geom.wcs), coordsys, nside=geom.nside, nested=geom.nest, order=order) vals[...] = data return map_out
[docs] def pad(self, pad_width, mode='constant', cval=0, order=1): if np.isscalar(pad_width): pad_width = (pad_width, pad_width) pad_width += (0,) * (self.geom.ndim - 2) geom = self.geom.pad(pad_width[:2]) if self.geom.is_regular and mode != 'interp': return self._pad_np(geom, pad_width, mode, cval) else: return self._pad_coadd(geom, pad_width, mode, cval, order)
def _pad_np(self, geom, pad_width, mode, cval): """Pad a map with `~np.pad`. This method only works for regular geometries but should be more efficient when working with large maps. """ kw = {} if mode == 'constant': kw['constant_values'] = cval pad_width = [(t, t) for t in pad_width] data = np.pad(self.data, pad_width[::-1], mode, **kw) map_out = self.__class__(geom, data, meta=copy.deepcopy(self.meta)) return map_out def _pad_coadd(self, geom, pad_width, mode, cval, order): """Pad a map manually by coadding the original map with the new map.""" idx_in = self.geom.get_idx(flat=True) idx_in = tuple([t + w for t, w in zip(idx_in, pad_width)])[::-1] idx_out = geom.get_idx(flat=True)[::-1] map_out = self.__class__(geom, meta=copy.deepcopy(self.meta)) map_out.coadd(self) if mode == 'constant': pad_msk = np.zeros_like(map_out.data, dtype=bool) pad_msk[idx_out] = True pad_msk[idx_in] = False map_out.data[pad_msk] = cval elif mode in ['edge', 'interp']: coords = geom.pix_to_coord(idx_out[::-1]) m = self.geom.contains(coords) coords = tuple([c[~m] for c in coords]) vals = self.interp_by_coord(coords, interp=0 if mode == 'edge' else order) map_out.set_by_coord(coords, vals) else: raise ValueError('Unrecognized pad mode: {}'.format(mode)) return map_out
[docs] def crop(self, crop_width): if np.isscalar(crop_width): crop_width = (crop_width, crop_width) geom = self.geom.crop(crop_width) if self.geom.is_regular: slices = [slice(crop_width[0], int(self.geom.npix[0] - crop_width[0])), slice(crop_width[1], int(self.geom.npix[1] - crop_width[1]))] for ax in self.geom.axes: slices += [slice(None)] data = self.data[slices[::-1]] map_out = self.__class__(geom, data, meta=copy.deepcopy(self.meta)) else: # FIXME: This could be done more efficiently by # constructing the appropriate slices for each image plane map_out = self.__class__(geom, meta=copy.deepcopy(self.meta)) map_out.coadd(self) return map_out
[docs] def upsample(self, factor, order=0, preserve_counts=True): from scipy.ndimage import map_coordinates geom = self.geom.upsample(factor) idx = geom.get_idx() pix = ((idx[0] - 0.5 * (factor - 1)) / factor, (idx[1] - 0.5 * (factor - 1)) / factor,) + idx[2:] data = map_coordinates(self.data.T, pix, order=order, mode='nearest') if preserve_counts: data /= factor**2 return self.__class__(geom, data, meta=copy.deepcopy(self.meta))
[docs] def downsample(self, factor, preserve_counts=True): from skimage.measure import block_reduce geom = self.geom.downsample(factor) block_size = tuple([factor, factor] + [1] * (self.geom.ndim - 2)) data = block_reduce(self.data, block_size[::-1], np.nansum) if not preserve_counts: data /= factor**2 return self.__class__(geom, data, meta=copy.deepcopy(self.meta))
[docs] def plot(self, ax=None, idx=None, **kwargs): """Quickplot method. Parameters ---------- norm : str Set the normalization scheme of the color map. idx : tuple Set the image slice to plot if this map has non-spatial dimensions. **kwargs : dict Keyword arguments passed to `~matplotlib.pyplot.imshow`. Returns ------- fig : `~matplotlib.figure.Figure` Figure object. ax : `~astropy.visualization.wcsaxes.WCSAxes` WCS axis object im : `~matplotlib.image.AxesImage` Image object. """ import matplotlib.pyplot as plt import matplotlib.colors as colors if ax is None: fig = plt.gcf() ax = fig.add_subplot(111, projection=self.geom.wcs) if idx is not None: slices = (slice(None), slice(None)) + idx data = self.data[slices[::-1]] else: data = self.data kwargs.setdefault('interpolation', 'nearest') kwargs.setdefault('origin', 'lower') kwargs.setdefault('norm', None) if kwargs['norm'] == 'log': kwargs['norm'] = colors.LogNorm() elif kwargs['norm'] == 'pow2': kwargs['norm'] = colors.PowerNorm(gamma=0.5) im = ax.imshow(data, **kwargs) ax.coords.grid(color='w', linestyle=':', linewidth=0.5) return fig, ax, im