# Source code for gammapy.maps.region.geom

import copy
import logging
from functools import lru_cache
import numpy as np
from astropy import units as u
from astropy.coordinates import Angle, SkyCoord
from astropy.io import fits
from astropy.table import QTable, Table
from astropy.utils import lazyproperty
from astropy.visualization.wcsaxes import WCSAxes
from astropy.wcs.utils import (
    proj_plane_pixel_area,
    proj_plane_pixel_scales,
    wcs_to_celestial_frame,
)
from regions import (
    CompoundSkyRegion,
    PixCoord,
    PointSkyRegion,
    RectanglePixelRegion,
    Regions,
    SkyRegion,
)
import matplotlib.pyplot as plt
from gammapy.utils.regions import (
    compound_region_center,
    compound_region_to_regions,
    regions_to_compound_region,
)
from gammapy.visualization.utils import ARTIST_TO_LINE_PROPERTIES
from ..axes import MapAxes
from ..coord import MapCoord
from ..core import Map
from ..geom import Geom, pix_tuple_to_idx
from ..utils import _check_width
from ..wcs import WcsGeom

__all__ = ["RegionGeom"]


class RegionGeom(Geom):
    """Map geometry representing a region on the sky.

    The spatial component of the geometry is made up of a
    single pixel with an arbitrary shape and size. It can
    also have any number of non-spatial dimensions. This class
    represents the geometry for the `RegionNDMap` maps.

    Parameters
    ----------
    region : `~regions.SkyRegion`
        Region object.
    axes : list of `MapAxis`
        Non-spatial data axes.
    wcs : `~astropy.wcs.WCS`
        Optional wcs object to project the region if needed.
    binsz_wcs : str or `~astropy.units.Quantity`
        Angular bin size of the underlying `~WcsGeom` used to evaluate
        quantities in the region. Default size is 0.1 deg. This default
        value is adequate for the majority of use cases. If a wcs object
        is provided, the input of binsz_wcs is overridden.
    """

    is_regular = True
    is_allsky = False
    is_hpx = False
    is_region = True

    _slice_spatial_axes = slice(0, 2)
    _slice_non_spatial_axes = slice(2, None)
    projection = "TAN"

    def __init__(self, region, axes=None, wcs=None, binsz_wcs="0.1 deg"):
        self._region = region
        self._axes = MapAxes.from_default(axes, n_spatial_axes=2)
        self._binsz_wcs = u.Quantity(binsz_wcs)

        # Without an explicit wcs, build a TAN projection centred on the region
        if wcs is None and region is not None:
            if isinstance(region, CompoundSkyRegion):
                self._center = compound_region_center(region)
            else:
                self._center = region.center

            wcs = WcsGeom.create(
                binsz=binsz_wcs,
                skydir=self._center,
                proj=self.projection,
                frame=self._center.frame.name,
            ).wcs

        self._wcs = wcs
        self.ndim = len(self.data_shape)

        # define cached methods
        self.get_wcs_coord_and_weights = lru_cache()(self.get_wcs_coord_and_weights)

    def __setstate__(self, state):
        # Re-wrap the per-instance cached method after unpickling / copying
        for key, value in state.items():
            if key in ["get_wcs_coord_and_weights"]:
                state[key] = lru_cache()(value)
        self.__dict__ = state

    @property
    def frame(self):
        """Coordinate system, either Galactic ("galactic") or Equatorial ("icrs")."""
        if self.region is None:
            return "icrs"
        try:
            return self.region.center.frame.name
        except AttributeError:
            # e.g. compound regions have no ``center``; fall back to the wcs frame
            return wcs_to_celestial_frame(self.wcs).name

    @property
    def binsz_wcs(self):
        """Angular bin size of the underlying `~WcsGeom`

        Returns
        -------
        binsz_wcs: `~astropy.coordinates.Angle`
        """
        return Angle(proj_plane_pixel_scales(self.wcs), unit="deg")

    @lazyproperty
    def _rectangle_bbox(self):
        """Sky rectangle enclosing the pixel bounding boxes of all sub-regions."""
        if self.region is None:
            raise ValueError("Region definition required.")

        regions = compound_region_to_regions(self.region)
        regions_pix = [_.to_pixel(self.wcs) for _ in regions]

        bbox = regions_pix[0].bounding_box

        for region_pix in regions_pix[1:]:
            bbox = bbox.union(region_pix.bounding_box)

        try:
            rectangle_pix = bbox.to_region()
        except ValueError:
            # NOTE(review): presumably raised for a degenerate bounding box;
            # fall back to a 1x1 pixel rectangle. ``bbox.center`` is reversed
            # because PixCoord expects (x, y).
            rectangle_pix = RectanglePixelRegion(
                center=PixCoord(*bbox.center[::-1]), width=1, height=1
            )

        return rectangle_pix.to_sky(self.wcs)

    @property
    def width(self):
        """Width of bounding box of the region.

        Returns
        -------
        width : `~astropy.units.Quantity`
            Dimensions of the region in both spatial dimensions.
            Units: ``deg``
        """
        rectangle = self._rectangle_bbox
        return u.Quantity([rectangle.width.to("deg"), rectangle.height.to("deg")])

    @property
    def region(self):
        """`~regions.SkyRegion` object that defines the spatial component of the region geometry"""
        return self._region

    @property
    def is_all_point_sky_regions(self):
        """Whether regions are all point regions"""
        regions = compound_region_to_regions(self.region)
        return all(isinstance(_, PointSkyRegion) for _ in regions)

    @property
    def axes(self):
        """List of non-spatial axes."""
        return self._axes

    @property
    def axes_names(self):
        """All axes names"""
        return ["lon", "lat"] + self.axes.names

    @property
    def wcs(self):
        """WCS projection object."""
        return self._wcs

    @property
    def center_coord(self):
        """(`astropy.coordinates.SkyCoord`)"""
        return self.pix_to_coord(self.center_pix)

    @property
    def center_pix(self):
        """Pixel values corresponding to the center of the region"""
        return tuple((np.array(self.data_shape) - 1.0) / 2)[::-1]

    @lazyproperty
    def center_skydir(self):
        """Sky coordinate of the center of the region"""
        if self.region is None:
            return SkyCoord(np.nan * u.deg, np.nan * u.deg)

        return self._rectangle_bbox.center

    @property
    def npix(self):
        """Number of spatial pixels"""
        return (1, 1)

    def contains(self, coords):
        """Check if a given map coordinate is contained in the region.

        Requires the `.region` attribute to be set. For `PointSkyRegion`
        the method always returns true.

        Parameters
        ----------
        coords : tuple, dict, `MapCoord` or `~astropy.coordinates.SkyCoord`
            Object containing coordinate arrays we wish to check for inclusion
            in the region.

        Returns
        -------
        mask : `~numpy.ndarray`
            Boolean Numpy array with the same shape as the input that indicates
            which coordinates are inside the region.
        """
        if self.region is None:
            raise ValueError("Region definition required.")

        coords = MapCoord.create(coords, frame=self.frame, axis_names=self.axes.names)

        if self.is_all_point_sky_regions:
            return np.ones(coords.skycoord.shape, dtype=bool)

        return self.region.contains(coords.skycoord, self.wcs)

    def contains_wcs_pix(self, pix):
        """Check if a given wcs pixel coordinate is contained in the region.

        For `PointSkyRegion` the method always returns true.

        Parameters
        ----------
        pix : tuple
            Tuple of pixel coordinates.

        Returns
        -------
        containment : `~numpy.ndarray`
            Bool array.
        """
        if self.is_all_point_sky_regions:
            return np.ones(pix[0].shape, dtype=bool)

        region_pix = self.region.to_pixel(self.wcs)
        return region_pix.contains(PixCoord(pix[0], pix[1]))

    def separation(self, position):
        """Angular distance between the center of the region and the given position.

        Parameters
        ----------
        position : `astropy.coordinates.SkyCoord`
            Sky coordinate we want the angular distance to.

        Returns
        -------
        sep : `~astropy.coordinates.Angle`
            The on-sky separation between the given coordinate and the region center.
        """
        return self.center_skydir.separation(position)

    @property
    def data_shape(self):
        """Shape of the Numpy data array matching this geometry."""
        return self._shape[::-1]

    @property
    def data_shape_axes(self):
        """Shape of data of the non-spatial axes and unit spatial axes."""
        return self.axes.shape[::-1] + (1, 1)

    @property
    def _shape(self):
        """Number of bins in each dimension.

        The spatial dimension is always (1, 1), as a
        `RegionGeom` is not pixelized further
        """
        return tuple((1, 1) + self.axes.shape)

    def get_coord(self, mode="center", frame=None, sparse=False, axis_name=None):
        """Get map coordinates from the geometry.

        Parameters
        ----------
        mode : {'center', 'edges'}
            Get center or edge coordinates for the non-spatial axes.
        frame : str or `~astropy.coordinates.Frame`
            Coordinate frame
        sparse : bool
            Compute sparse coordinates. Accepted for interface compatibility;
            not used by this implementation.
        axis_name : str
            If mode = "edges", the edges will be returned for this axis only.

        Returns
        -------
        coord : `~MapCoord`
            Map coordinate object.
        """
        if mode == "edges" and axis_name is None:
            raise ValueError("Mode 'edges' requires axis name")

        coords = self.axes.get_coord(mode=mode, axis_name=axis_name)
        # the single spatial pixel is represented by the region center
        coords["skycoord"] = self.center_skydir.reshape((1, 1))

        if frame is None:
            frame = self.frame

        return MapCoord.create(coords, frame=self.frame).to_frame(frame)

    def _pad_spatial(self, pad_width):
        raise NotImplementedError("Spatial padding of `RegionGeom` not supported")

    def crop(self):
        """Not supported for `RegionGeom`; always raises `NotImplementedError`."""
        raise NotImplementedError("Cropping of `RegionGeom` not supported")

    def solid_angle(self):
        """Get solid angle of the region.

        Returns
        -------
        angle : `~astropy.units.Quantity`
            Solid angle of the region. In sr.
            Units: ``sr``
        """
        if self.region is None:
            raise ValueError("Region definition required.")

        # compound regions do not implement area()
        # so we use the mask representation and estimate the area
        # from the pixels in the mask using oversampling
        if isinstance(self.region, CompoundSkyRegion):
            # oversample by a factor of ten
            oversampling = 10.0
            wcs = self.to_binsz_wcs(self.binsz_wcs / oversampling).wcs
            pixel_region = self.region.to_pixel(wcs)
            mask = pixel_region.to_mask()
            # convert the oversampled pixel count back to ``self.wcs`` pixels
            area = np.count_nonzero(mask) / oversampling**2
        else:
            # all other types of regions should implement area
            area = self.region.to_pixel(self.wcs).area

        solid_angle = area * proj_plane_pixel_area(self.wcs) * u.deg**2
        return solid_angle.to("sr")

    def bin_volume(self):
        """If the RegionGeom has a non-spatial axis, it returns the volume of the region.

        If not, it just returns the solid angle size.

        Returns
        -------
        volume : `~astropy.units.Quantity`
            Volume of the region.
        """
        bin_volume = self.solid_angle() * np.ones(self.data_shape)

        for idx, ax in enumerate(self.axes):
            # data_shape is (..., axis_1, axis_0, 1, 1): axis ``idx`` sits at
            # position -(idx + 3), counting from the spatial dims at the end
            shape = self.ndim * [1]
            shape[-(idx + 3)] = -1
            bin_volume = bin_volume * ax.bin_width.reshape(tuple(shape))

        return bin_volume

    def to_wcs_geom(self, width_min=None):
        """Get the minimal equivalent geometry which contains the region.

        Parameters
        ----------
        width_min : `~astropy.quantity.Quantity`
            Minimal width for the resulting geometry.
            Can be a single number or two, for
            different minimum widths in each spatial dimension.

        Returns
        -------
        wcs_geom : `~WcsGeom`
            A WCS geometry object.
        """
        if width_min is not None:
            width = np.max(
                [self.width.to_value("deg"), _check_width(width_min)], axis=0
            )
        else:
            width = self.width

        wcs_geom_region = WcsGeom(wcs=self.wcs, npix=self.wcs.array_shape)
        wcs_geom = wcs_geom_region.cutout(position=self.center_skydir, width=width)
        wcs_geom = wcs_geom.to_cube(self.axes)
        return wcs_geom

    def to_binsz_wcs(self, binsz):
        """Change the bin size of the underlying WCS geometry.

        Parameters
        ----------
        binsz : float, string or `~astropy.quantity.Quantity`
            New angular bin size of the underlying WCS.

        Returns
        -------
        region : `~RegionGeom`
            A RegionGeom with the same axes and region as the input,
            but different wcs pixelization.
        """
        new_geom = RegionGeom(self.region, axes=self.axes, binsz_wcs=binsz)
        return new_geom

    def get_wcs_coord_and_weights(self, factor=10):
        """Get the array of spatial coordinates and corresponding weights

        The coordinates are the center of a pixel that intersects the region and
        the weights that represent which fraction of the pixel is contained
        in the region.

        Parameters
        ----------
        factor : int
            Oversampling factor to compute the weights

        Returns
        -------
        region_coord : `~MapCoord`
            MapCoord object with the coordinates inside
            the region.
        weights : `~np.array`
            Weights representing the fraction of each pixel
            contained in the region.
        """
        wcs_geom = self.to_wcs_geom()

        weights = wcs_geom.to_image().region_weights(
            regions=[self.region], oversampling_factor=factor
        )

        # keep only pixels that overlap the region at all
        mask = weights.data > 0
        weights = weights.data[mask]

        # Get coordinates
        coords = wcs_geom.get_coord(sparse=True).apply_mask(mask)
        return coords, weights

    def to_binsz(self, binsz):
        """Returns self"""
        return self

    def to_cube(self, axes):
        """Append non-spatial axes to create a higher-dimensional geometry.

        Returns
        -------
        region : `~RegionGeom`
            RegionGeom with the added axes.
        """
        axes = copy.deepcopy(self.axes) + axes
        return self._init_copy(region=self.region, wcs=self.wcs, axes=axes)

    def to_image(self):
        """Remove non-spatial axes to create a 2D region.

        Returns
        -------
        region : `~RegionGeom`
            RegionGeom without any non-spatial axes.
        """
        return self._init_copy(region=self.region, wcs=self.wcs, axes=None)

    def upsample(self, factor, axis_name=None):
        """Upsample a non-spatial dimension of the region by a given factor.

        Returns
        -------
        region : `~RegionGeom`
            RegionGeom with the upsampled axis.
        """
        axes = self.axes.upsample(factor=factor, axis_name=axis_name)
        return self._init_copy(region=self.region, wcs=self.wcs, axes=axes)

    def downsample(self, factor, axis_name):
        """Downsample a non-spatial dimension of the region by a given factor.

        Returns
        -------
        region : `~RegionGeom`
            RegionGeom with the downsampled axis.
        """
        axes = self.axes.downsample(factor=factor, axis_name=axis_name)
        return self._init_copy(region=self.region, wcs=self.wcs, axes=axes)

    def pix_to_coord(self, pix):
        # inherited docstring
        # Only the single spatial pixel (-0.5, 0.5) maps to the region center;
        # anything outside it maps to NaN.
        lon = np.where(
            (-0.5 < pix[0]) & (pix[0] < 0.5),
            self.center_skydir.data.lon,
            np.nan * u.deg,
        )
        lat = np.where(
            (-0.5 < pix[1]) & (pix[1] < 0.5),
            self.center_skydir.data.lat,
            np.nan * u.deg,
        )
        coords = (lon, lat)

        for p, ax in zip(pix[self._slice_non_spatial_axes], self.axes):
            coords += (ax.pix_to_coord(p),)

        return coords

    def pix_to_idx(self, pix, clip=False):
        # inherited docstring
        idxs = list(pix_tuple_to_idx(pix))

        # only the non-spatial indices are clipped / invalidated (set to -1)
        for i, idx in enumerate(idxs[self._slice_non_spatial_axes]):
            if clip:
                np.clip(idx, 0, self.axes[i].nbin - 1, out=idx)
            else:
                np.putmask(idx, (idx < 0) | (idx >= self.axes[i].nbin), -1)

        return tuple(idxs)

    def coord_to_pix(self, coords):
        # inherited docstring
        if isinstance(coords, tuple) and len(coords) == len(self.axes):
            # only non-spatial coordinates given: prepend the region center
            skydir = self.center_skydir.transform_to(self.frame)
            coords = (skydir.data.lon, skydir.data.lat) + coords
        elif isinstance(coords, dict):
            valid_keys = ["lon", "lat", "skycoord"]
            if not any(_ in coords for _ in valid_keys):
                # work on a copy so the caller's dict is not modified
                coords = dict(coords)
                coords.setdefault("skycoord", self.center_skydir)

        coords = MapCoord.create(coords, frame=self.frame, axis_names=self.axes.names)

        if self.region is None:
            pix = (0, 0)
        else:
            # spatial pixel is 0 inside the region and NaN outside
            in_region = self.contains(coords.skycoord)

            x = np.zeros(coords.skycoord.shape)
            x[~in_region] = np.nan

            y = np.zeros(coords.skycoord.shape)
            y[~in_region] = np.nan

            pix = (x, y)

        pix += self.axes.coord_to_pix(coords)
        return pix

    def get_idx(self):
        # inherited docstring
        idxs = [np.arange(n, dtype=float) for n in self.data_shape[::-1]]
        return np.meshgrid(*idxs[::-1], indexing="ij")[::-1]

    def _make_bands_cols(self):
        return []

    @classmethod
    def create(cls, region, **kwargs):
        """Create region geometry.

        The input region can be passed in the form of a ds9 string and will be parsed
        internally by `~regions.Regions.parse`. See:

        * https://astropy-regions.readthedocs.io/en/stable/region_io.html
        * http://ds9.si.edu/doc/ref/region.html

        Parameters
        ----------
        region : str or `~regions.SkyRegion`
            Region definition
        **kwargs : dict
            Keyword arguments passed to `RegionGeom.__init__`

        Returns
        -------
        geom : `RegionGeom`
            Region geometry
        """
        return cls.from_regions(regions=region, **kwargs)

    def __repr__(self):
        axes = ["lon", "lat"] + [_.name for _ in self.axes]

        try:
            frame = self.center_skydir.frame.name
            lon = self.center_skydir.data.lon.deg
            lat = self.center_skydir.data.lat.deg
        except AttributeError:
            frame, lon, lat = "", np.nan, np.nan

        return (
            f"{self.__class__.__name__}\n\n"
            f"\tregion : {self.region.__class__.__name__}\n"
            f"\taxes : {axes}\n"
            f"\tshape : {self.data_shape[::-1]}\n"
            f"\tndim : {self.ndim}\n"
            f"\tframe : {frame}\n"
            f"\tcenter : {lon:.1f} deg, {lat:.1f} deg\n"
        )

    def is_allclose(self, other, rtol_axes=1e-6, atol_axes=1e-6):
        """Compare two region geometries for equivalency.

        Only the data shape and the non-spatial axes are compared;
        the regions themselves are not compared yet.

        Parameters
        ----------
        other : `RegionGeom`
            Geom to compare against.
        rtol_axes : float
            Relative tolerance for the axes comparison.
        atol_axes : float
            Absolute tolerance for the axes comparison.

        Returns
        -------
        is_allclose : bool
            Whether the geometry is all close.

        Raises
        ------
        TypeError
            If ``other`` is not a `RegionGeom`.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        if self.data_shape != other.data_shape:
            return False

        axes_eq = self.axes.is_allclose(other.axes, rtol=rtol_axes, atol=atol_axes)
        # TODO: compare regions based on masks...
        regions_eq = True
        return axes_eq and regions_eq

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        return self.is_allclose(other=other)

    def _to_region_table(self):
        """Export region to a FITS region table."""
        if self.region is None:
            raise ValueError("Region definition required.")

        region_list = compound_region_to_regions(self.region)

        pixel_region_list = [reg.to_pixel(self.wcs) for reg in region_list]

        table = Regions(pixel_region_list).serialize(format="fits")

        # store the wcs so the pixel regions can be converted back to sky
        header = WcsGeom(wcs=self.wcs, npix=self.wcs.array_shape).to_header()
        table.meta.update(header)
        return table

    def to_hdulist(self, format="ogip", hdu_bands=None, hdu_region=None):
        """Convert geom to hdulist

        Parameters
        ----------
        format : {"gadf", "ogip", "ogip-sherpa"}
            HDU format
        hdu_bands : str
            Name of the HDU with the axes (bands) table. Default is "HDU_BANDS".
        hdu_region : str
            Name of the region HDU. Default is "HDU_REGION"; for any format
            other than "gadf" it is always "REGION".

        Returns
        -------
        hdulist : `~astropy.io.fits.HDUList`
            HDU list
        """
        if hdu_bands is None:
            hdu_bands = "HDU_BANDS"

        if hdu_region is None:
            hdu_region = "HDU_REGION"

        if format != "gadf":
            hdu_region = "REGION"

        hdulist = fits.HDUList()

        hdulist.append(self.axes.to_table_hdu(hdu_bands=hdu_bands, format=format))

        # region HDU
        if self.region:
            region_table = self._to_region_table()

            region_hdu = fits.BinTableHDU(region_table, name=hdu_region)
            hdulist.append(region_hdu)

        return hdulist

    @classmethod
    def from_regions(cls, regions, **kwargs):
        """Create region geom from list of regions

        The regions are combined with union to a compound region.

        Parameters
        ----------
        regions : list of `~regions.SkyRegion` or str
            Regions. An empty input results in a geometry without region.
        **kwargs: dict
            Keyword arguments forwarded to `RegionGeom`

        Returns
        -------
        geom : `RegionGeom`
            Region map geometry
        """
        if isinstance(regions, str):
            regions = Regions.parse(data=regions, format="ds9")
        elif isinstance(regions, SkyRegion):
            regions = [regions]
        elif isinstance(regions, SkyCoord):
            regions = [PointSkyRegion(center=regions)]

        # any empty container (list, tuple, parsed `Regions`) means "no region"
        if regions:
            regions = regions_to_compound_region(regions)
        else:
            regions = None

        return cls(region=regions, **kwargs)

    @classmethod
    def from_hdulist(cls, hdulist, format="ogip", hdu=None):
        """Read region table and convert it to region list.

        Parameters
        ----------
        hdulist : `~astropy.io.fits.HDUList`
            HDU list
        format : {"ogip", "ogip-arf", "gadf"}
            HDU format
        hdu : str
            Base HDU name, required for the "gadf" format: the region and
            bands HDUs are read from ``hdu + "_REGION"`` and ``hdu + "_BANDS"``.

        Returns
        -------
        geom : `RegionGeom`
            Region map geometry
        """
        region_hdu = "REGION"

        if format == "gadf" and hdu:
            region_hdu = hdu + "_" + region_hdu

        if region_hdu in hdulist:
            try:
                region_table = QTable.read(hdulist[region_hdu])
                regions_pix = Regions.parse(data=region_table, format="fits")
            except TypeError:
                # TODO: this is needed to support regions=0.5
                region_table = Table.read(hdulist[region_hdu])
                regions_pix = Regions.parse(data=region_table, format="fits")

            wcs = WcsGeom.from_header(region_table.meta).wcs

            regions = []

            for region_pix in regions_pix:
                # TODO: remove workaround once regions issue with fits serialization is sorted out
                # see https://github.com/astropy/regions/issues/400
                region_pix.meta["include"] = True
                regions.append(region_pix.to_sky(wcs))

            region = regions_to_compound_region(regions)
        else:
            region, wcs = None, None

        if format == "ogip":
            hdu_bands = "EBOUNDS"
        elif format == "ogip-arf":
            hdu_bands = "SPECRESP"
        elif format == "gadf":
            hdu_bands = hdu + "_BANDS"
        else:
            raise ValueError(f"Unknown format {format}")

        axes = MapAxes.from_table_hdu(hdulist[hdu_bands], format=format)
        return cls(region=region, wcs=wcs, axes=axes)

    def union(self, other):
        """Stack a RegionGeom by making the union

        Modifies ``self`` in place.

        Parameters
        ----------
        other : `RegionGeom`
            Geometry to stack; its non-spatial axes must be equivalent.

        Raises
        ------
        ValueError
            If the non-spatial axes are not equivalent.
        """
        if not self == other:
            raise ValueError("Can only make union if extra axes are equivalent.")

        if other.region:
            if self.region:
                self._region = self.region.union(other.region)
            else:
                self._region = other.region

            # a geom created without region has no wcs to project the new one
            if self._wcs is None:
                self._wcs = other.wcs

            # bounding box, center and weights were derived from the previous
            # region: drop the cached values so they are recomputed on access
            self.__dict__.pop("_rectangle_bbox", None)
            self.__dict__.pop("center_skydir", None)
            cache_clear = getattr(self.get_wcs_coord_and_weights, "cache_clear", None)
            if cache_clear is not None:
                cache_clear()

    def plot_region(self, ax=None, kwargs_point=None, path_effect=None, **kwargs):
        """Plot region in the sky.

        Parameters
        ----------
        ax : `~astropy.visualization.WCSAxes`
            Axes to plot on. If no axes are given,
            the region is shown using the minimal
            equivalent WCS geometry.
        kwargs_point : dict
            Keyword arguments passed to `~matplotlib.lines.Line2D` for plotting
            of point sources. The dict passed in is not modified.
        path_effect : `~matplotlib.patheffects.AbstractPathEffect` or list
            Path effect(s) applied to artists and lines, in addition to any
            path effects the artists already have.
        **kwargs : dict
            Keyword arguments forwarded to `~regions.PixelRegion.as_artist`

        Returns
        -------
        ax : `~astropy.visualization.WCSAxes`
            Axes to plot on.
        """
        if self.region:
            # copy so the caller's dict is not modified by ``setdefault`` below
            kwargs_point = dict(kwargs_point or {})

            if ax is None:
                ax = plt.gca()

            if not isinstance(ax, WCSAxes):
                # replace plain matplotlib axes by WCS axes of an empty map
                ax.remove()
                wcs_geom = self.to_wcs_geom()
                m = Map.from_geom(geom=wcs_geom.to_image())
                ax = m.plot(add_cbar=False, vmin=-1, vmax=0)

            kwargs.setdefault("facecolor", "None")
            kwargs.setdefault("edgecolor", "tab:blue")
            kwargs_point.setdefault("marker", "*")

            # translate patch properties to their Line2D equivalents for points
            for key, value in kwargs.items():
                key_point = ARTIST_TO_LINE_PROPERTIES.get(key, None)
                if key_point:
                    kwargs_point[key_point] = value

            if path_effect:
                if isinstance(path_effect, (list, tuple)):
                    extra_effects = list(path_effect)
                else:
                    extra_effects = [path_effect]
            else:
                extra_effects = []

            for region in compound_region_to_regions(self.region):
                region_pix = region.to_pixel(wcs=ax.wcs)

                if isinstance(region, PointSkyRegion):
                    artist = region_pix.as_artist(**kwargs_point)
                else:
                    artist = region_pix.as_artist(**kwargs)

                if extra_effects:
                    # matplotlib artists only offer set/get_path_effects
                    artist.set_path_effects(
                        list(artist.get_path_effects()) + extra_effects
                    )

                ax.add_artist(artist)

            return ax
        else:
            logging.getLogger(__name__).info("Region definition required.")