# Licensed under a 3-clause BSD style license - see LICENSE.rst
import copy
import logging
from functools import lru_cache

import numpy as np
from astropy import units as u
from astropy.coordinates import Angle, SkyCoord
from astropy.io import fits
from astropy.table import QTable, Table
from astropy.utils import lazyproperty
from astropy.visualization.wcsaxes import WCSAxes
from astropy.wcs.utils import (
    proj_plane_pixel_area,
    proj_plane_pixel_scales,
    wcs_to_celestial_frame,
)
from regions import (
    CompoundSkyRegion,
    PixCoord,
    PointSkyRegion,
    RectanglePixelRegion,
    Regions,
    SkyRegion,
)
import matplotlib.pyplot as plt

from gammapy.utils.regions import (
    compound_region_center,
    compound_region_to_regions,
    regions_to_compound_region,
)
from gammapy.visualization.utils import ARTIST_TO_LINE_PROPERTIES

from ..axes import MapAxes
from ..coord import MapCoord
from ..core import Map
from ..geom import Geom, pix_tuple_to_idx
from ..utils import _check_width
from ..wcs import WcsGeom
__all__ = ["RegionGeom"]
class RegionGeom(Geom):
    """Map geometry representing a region on the sky.

    The spatial component of the geometry is made up of a
    single pixel with an arbitrary shape and size. It can
    also have any number of non-spatial dimensions. This class
    represents the geometry for the `RegionNDMap` maps.

    Parameters
    ----------
    region : `~regions.SkyRegion`
        Region object.
    axes : list of `MapAxis`
        Non-spatial data axes.
    wcs : `~astropy.wcs.WCS`
        Optional WCS object to project the region if needed.
    binsz_wcs : `~astropy.units.Quantity`
        Angular bin size of the underlying `~WcsGeom` used to evaluate
        quantities in the region. Default is "0.1 deg". This default
        value is adequate for the majority of use cases. If a WCS object
        is provided, the input of ``binsz_wcs`` is overridden.
    """

    is_regular = True
    is_allsky = False
    is_hpx = False
    is_region = True

    # The first two pixel axes are the spatial (lon, lat) ones; the remaining
    # ones correspond to the non-spatial ``axes``.
    _slice_spatial_axes = slice(0, 2)
    _slice_non_spatial_axes = slice(2, None)
    # Projection presumably used for the internally created WCS (see __init__).
    projection = "TAN"
def __init__(self, region, axes=None, wcs=None, binsz_wcs="0.1 deg"):
"""Initialize the geometry; see the class docstring for the parameters."""
# NOTE(review): this view of __init__ appears to be missing lines: an `else:`
# before the second `self._center` assignment, the arguments and closing
# parentheses of both `WcsGeom.create(` calls, and whatever scopes the second
# `WcsGeom.create`. Confirm against the full file before changing the code.
self._region = region
# Non-spatial axes; n_spatial_axes=2 reserves the two spatial (lon, lat) axes.
self._axes = MapAxes.from_default(axes, n_spatial_axes=2)
self._binsz_wcs = u.Quantity(binsz_wcs)
# A WCS is only built here when none was passed in and a region exists.
if wcs is None and region is not None:
# Compound regions get their center from `compound_region_center`; other
# regions presumably use `region.center` (the `else:` seems to be missing).
if isinstance(region, CompoundSkyRegion):
self._center = compound_region_center(region)
self._center = region.center
wcs = WcsGeom.create(
# Set a first WCS so that `self.width` (which projects the region with
# `self.wcs`) can be evaluated before building the final WCS.
self._wcs = wcs
# TODO : can we get the width before defining the wcs ?
wcs = WcsGeom.create(
self._wcs = wcs
self.ndim = len(self.data_shape)
# define cached methods
# Wrap the bound method in a per-instance LRU cache; `__setstate__` re-applies
# the same wrapper after unpickling.
self.get_wcs_coord_and_weights = lru_cache()(self.get_wcs_coord_and_weights)
def __setstate__(self, state):
    """Restore the instance state after unpickling.

    Methods that ``__init__`` wraps with ``lru_cache`` are wrapped again, so
    the restored instance keeps its method caches.

    Parameters
    ----------
    state : dict
        Instance ``__dict__`` as stored by pickle.
    """
    cached_method_names = ("get_wcs_coord_and_weights",)
    for name in cached_method_names:
        if name in state:
            state[name] = lru_cache()(state[name])
    self.__dict__ = state
@property
def frame(self):
    """Name of the coordinate frame of the geometry, e.g. "galactic" or "icrs".

    Returns "icrs" when no region is defined. Otherwise the frame of the
    region center is used; if the region has no ``center`` attribute the
    celestial frame of ``self.wcs`` is used instead.
    """
    if self.region is None:
        return "icrs"

    try:
        return self.region.center.frame.name
    except AttributeError:
        # Regions without a ``center`` (presumably compound regions) fall
        # back to the frame of the WCS.
        return wcs_to_celestial_frame(self.wcs).name
@property
def binsz_wcs(self):
    """Angular bin size of the underlying `~WcsGeom`.

    The value is derived from ``self.wcs``, so it reflects a WCS passed at
    construction time rather than the ``binsz_wcs`` argument.

    Returns
    -------
    binsz_wcs : `~astropy.coordinates.Angle`
        Angular bin size, one value per spatial pixel axis of ``self.wcs``.
    """
    return Angle(proj_plane_pixel_scales(self.wcs), unit="deg")
@lazyproperty
def _rectangle_bbox(self):
    """Bounding rectangle of the region as a sky region (computed once, then cached).

    All sub-regions are projected with ``self.wcs`` and the union of their
    pixel bounding boxes is converted back to the sky.

    Returns
    -------
    rectangle : `~regions.RectangleSkyRegion`
        Bounding rectangle on the sky.

    Raises
    ------
    ValueError
        If no region is defined.
    """
    if self.region is None:
        raise ValueError("Region definition required.")

    regions = compound_region_to_regions(self.region)
    regions_pix = [_.to_pixel(self.wcs) for _ in regions]

    bbox = regions_pix[0].bounding_box
    for region_pix in regions_pix[1:]:
        bbox = bbox.union(region_pix.bounding_box)

    try:
        rectangle_pix = bbox.to_region()
    except ValueError:
        # Fall back to a one-pixel rectangle at the box center; the center is
        # reversed to the (x, y) order expected by `PixCoord`.
        rectangle_pix = RectanglePixelRegion(
            center=PixCoord(*bbox.center[::-1]), width=1, height=1
        )
    return rectangle_pix.to_sky(self.wcs)
@property
def width(self):
    """Width of bounding box of the region.

    Returns
    -------
    width : `~astropy.units.Quantity`
        Dimensions of the region in both spatial dimensions
        (bounding-box width, then height).
        Units: ``deg``
    """
    rectangle = self._rectangle_bbox
    return u.Quantity([rectangle.width.to("deg"), rectangle.height.to("deg")])
@property
def region(self):
    """The spatial component of the region geometry as a `~regions.SkyRegion`.

    None if the geometry was created without a region.
    """
    return self._region
@property
def is_all_point_sky_regions(self):
    """Whether all sub-regions of the region are `PointSkyRegion` instances.

    Must be a property: callers test it directly in ``if`` statements, and a
    bound method would always be truthy.
    """
    regions = compound_region_to_regions(self.region)
    return np.all([isinstance(_, PointSkyRegion) for _ in regions])
@property
def axes(self):
    """List of non-spatial axes (a `MapAxes` object)."""
    return self._axes
@property
def axes_names(self):
    """All axes names: "lon", "lat", then the names of the non-spatial axes."""
    return ["lon", "lat"] + self.axes.names
@property
def wcs(self):
    """WCS projection object used to project the region, or None."""
    return self._wcs
@property
def center_coord(self):
    """Center coordinates of the geometry.

    Returns
    -------
    coords : tuple
        ``self.pix_to_coord(self.center_pix)``, i.e. a tuple
        ``(lon, lat, *axis_coords)`` with the region center followed by the
        coordinate at the center pixel of each non-spatial axis.
    """
    return self.pix_to_coord(self.center_pix)
@property
def center_pix(self):
    """Pixel values corresponding to the center of the region.

    Returned in (lon, lat, *non-spatial axes) order, i.e. reversed with
    respect to ``data_shape``.
    """
    return tuple((np.array(self.data_shape) - 1.0) / 2)[::-1]
@property
def center_skydir(self):
    """Sky coordinate of the center of the region's bounding box.

    Returns a NaN `~astropy.coordinates.SkyCoord` if no region is defined.
    """
    if self.region is None:
        return SkyCoord(np.nan * u.deg, np.nan * u.deg)

    return self._rectangle_bbox.center
@property
def npix(self):
    """Number of spatial pixels: always one in each spatial dimension."""
    return ([1], [1])
def contains(self, coords):
    """Check if a given map coordinate is contained in the region.

    Requires the `.region` attribute to be set.
    If all sub-regions are `PointSkyRegion`, the method always returns True.

    Parameters
    ----------
    coords : tuple, dict, `MapCoord` or `~astropy.coordinates.SkyCoord`
        Object containing coordinate arrays we wish to check for inclusion
        in the region.

    Returns
    -------
    mask : `~numpy.ndarray`
        Boolean array with the same shape as the input that indicates
        which coordinates are inside the region.

    Raises
    ------
    ValueError
        If no region is defined.
    """
    if self.region is None:
        raise ValueError("Region definition required.")

    coords = MapCoord.create(coords, frame=self.frame, axis_names=self.axes.names)

    if self.is_all_point_sky_regions:
        return np.ones(coords.skycoord.shape, dtype=bool)

    return self.region.contains(coords.skycoord, self.wcs)
def contains_wcs_pix(self, pix):
    """Check if a given WCS pixel coordinate is contained in the region.

    If all sub-regions are `PointSkyRegion`, the method always returns True.

    Parameters
    ----------
    pix : tuple
        Tuple of pixel coordinates ``(x, y)`` in the pixel frame of ``self.wcs``.

    Returns
    -------
    containment : `~numpy.ndarray`
        Boolean array.
    """
    if self.is_all_point_sky_regions:
        return np.ones(pix[0].shape, dtype=bool)

    region_pix = self.region.to_pixel(self.wcs)
    return region_pix.contains(PixCoord(pix[0], pix[1]))
def separation(self, position):
    """Angular distance between the center of the region and the given position.

    Parameters
    ----------
    position : `astropy.coordinates.SkyCoord`
        Sky coordinate we want the angular distance to.

    Returns
    -------
    sep : `~astropy.coordinates.Angle`
        The on-sky separation between the given coordinate and the region center.
    """
    return self.center_skydir.separation(position)
@property
def data_shape(self):
    """Shape of the `~numpy.ndarray` matching this geometry.

    This is ``_shape`` reversed: non-spatial axes first, spatial axes last.
    """
    return self._shape[::-1]
@property
def data_shape_axes(self):
    """Shape of data of the non-spatial axes and unit spatial axes."""
    return self.axes.shape[::-1] + (1, 1)
@property
def _shape(self):
    """Number of bins in each dimension.

    The spatial dimension is always (1, 1), as a
    `RegionGeom` is not pixelized further.
    """
    return tuple((1, 1) + self.axes.shape)
def get_coord(self, mode="center", frame=None, sparse=False, axis_name=None):
    """Get map coordinates from the geometry.

    The spatial coordinate is always the region center (``center_skydir``).

    Parameters
    ----------
    mode : {'center', 'edges'}, optional
        Get center or edge coordinates for the non-spatial axes.
        Default is "center".
    frame : str or `~astropy.coordinates.Frame`, optional
        Coordinate frame. Default is None, which uses ``self.frame``.
    sparse : bool, optional
        Compute sparse coordinates. Default is False.
    axis_name : str, optional
        If mode = "edges", the edges will be returned for this axis only.
        Default is None.

    Returns
    -------
    coord : `~MapCoord`
        Map coordinate object.

    Raises
    ------
    ValueError
        If ``mode`` is "edges" and no ``axis_name`` is given.
    """
    if mode == "edges" and axis_name is None:
        raise ValueError("Mode 'edges' requires axis name")

    coords = self.axes.get_coord(mode=mode, axis_name=axis_name)
    coords["skycoord"] = self.center_skydir.reshape((1, 1))

    if frame is None:
        frame = self.frame

    coords = MapCoord.create(coords, frame=self.frame).to_frame(frame)

    if not sparse:
        coords = coords.broadcasted

    return coords
def _pad_spatial(self, pad_width):
    """Not supported: the spatial part of a `RegionGeom` is a single pixel.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "Spatial padding of `RegionGeom` not supported"
    raise NotImplementedError(message)
def crop(self):
    """Not supported for `RegionGeom`.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "Cropping of `RegionGeom` not supported"
    raise NotImplementedError(message)
def solid_angle(self):
    """Get solid angle of the region.

    Returns
    -------
    angle : `~astropy.units.Quantity`
        Solid angle of the region in steradians.
        Units: ``sr``

    Raises
    ------
    ValueError
        If no region is defined.
    """
    if self.region is None:
        raise ValueError("Region definition required.")

    # compound regions do not implement area()
    # so we use the mask representation and estimate the area
    # from the pixels in the mask using oversampling
    if isinstance(self.region, CompoundSkyRegion):
        # oversample by a factor of ten
        oversampling = 10.0
        wcs = self.to_binsz_wcs(self.binsz_wcs / oversampling).wcs
        pixel_region = self.region.to_pixel(wcs)
        mask = pixel_region.to_mask()
        # Each fine pixel covers 1 / oversampling**2 of a pixel of self.wcs.
        area = np.count_nonzero(mask) / oversampling**2
    else:
        # all other types of regions should implement area
        area = self.region.to_pixel(self.wcs).area

    # `area` is in units of pixels of self.wcs.
    solid_angle = area * proj_plane_pixel_area(self.wcs) * u.deg**2
    return solid_angle.to("sr")
def bin_volume(self):
    """If the `RegionGeom` has a non-spatial axis, it returns the volume of the region.

    If not, it returns the solid angle size.

    Returns
    -------
    volume : `~astropy.units.Quantity`
        Volume of the region, with shape ``data_shape``: the solid angle
        multiplied by the bin width of every non-spatial axis.
    """
    bin_volume = self.solid_angle() * np.ones(self.data_shape)

    for idx, ax in enumerate(self.axes):
        # Non-spatial axis `idx` sits at position -(idx + 3) of data_shape,
        # in front of the two trailing spatial axes.
        shape = self.ndim * [1]
        shape[-(idx + 3)] = -1
        bin_volume = bin_volume * ax.bin_width.reshape(tuple(shape))

    return bin_volume
def to_wcs_geom(self, width_min=None):
    """Get the minimal equivalent geometry which contains the region.

    Parameters
    ----------
    width_min : `~astropy.quantity.Quantity`, optional
        Minimum width for the resulting geometry. Can be a single number or two,
        for different minimum widths in each spatial dimension.
        Default is None.

    Returns
    -------
    wcs_geom : `~WcsGeom`
        A WCS geometry object, with the non-spatial axes of this geometry.
    """
    if width_min is not None:
        # Per-dimension maximum of the region width and the requested minimum (deg).
        width = np.max(
            [self.width.to_value("deg"), _check_width(width_min)], axis=0
        )
    else:
        width = self.width

    wcs_geom_region = WcsGeom(wcs=self.wcs)
    wcs_geom = wcs_geom_region.cutout(position=self.center_skydir, width=width)
    wcs_geom = wcs_geom.to_cube(self.axes)
    return wcs_geom
def to_binsz_wcs(self, binsz):
    """Change the bin size of the underlying WCS geometry.

    A new WCS is built for the new bin size; the current ``wcs`` is not reused.

    Parameters
    ----------
    binsz : float, str or `~astropy.quantity.Quantity`
        Bin size.

    Returns
    -------
    region : `~RegionGeom`
        Region geometry with the same axes and region as the input,
        but different WCS pixelization.
    """
    new_geom = RegionGeom(self.region, axes=self.axes, binsz_wcs=binsz)
    return new_geom
def get_wcs_coord_and_weights(self, factor=10):
    """Get the array of spatial coordinates and corresponding weights.

    The coordinates are the center of a pixel that intersects the region and
    the weights that represent which fraction of the pixel is contained
    in the region. ``__init__`` wraps this method in ``lru_cache``, so results
    are cached per ``factor``.

    Parameters
    ----------
    factor : int, optional
        Oversampling factor to compute the weights.
        Default is 10.

    Returns
    -------
    region_coord : `~MapCoord`
        MapCoord object with the coordinates inside
        the region.
    weights : `~numpy.ndarray`
        Weights representing the fraction of each pixel
        contained in the region.
    """
    wcs_geom = self.to_wcs_geom()

    weights = wcs_geom.to_image().region_weights(
        regions=[self.region], oversampling_factor=factor
    )

    # Keep only pixels that overlap the region.
    mask = weights.data > 0
    weights = weights.data[mask]

    # Get coordinates
    coords = wcs_geom.get_coord(sparse=True).apply_mask(mask)
    return coords, weights
def to_binsz(self, binsz):
    """Return the geometry itself, unchanged.

    The spatial part of a region geometry is a single pixel, so there is no
    spatial binning to change; ``binsz`` is not used.
    """
    unchanged = self
    return unchanged
def to_cube(self, axes):
    """Append non-spatial axes to create a higher-dimensional geometry.

    Parameters
    ----------
    axes : list of `MapAxis`
        Axes to append after the existing non-spatial axes.

    Returns
    -------
    region : `~RegionGeom`
        Region geometry with the added axes.
    """
    # Deep copy so the new geometry does not share axis objects with self.
    axes = copy.deepcopy(self.axes) + axes
    return self._init_copy(region=self.region, wcs=self.wcs, axes=axes)
def to_image(self):
    """Remove non-spatial axes to create a 2D region.

    Returns
    -------
    region : `~RegionGeom`
        Region geometry without any non-spatial axes, with the same region
        and WCS.
    """
    return self._init_copy(region=self.region, wcs=self.wcs, axes=None)
def upsample(self, factor, axis_name=None):
    """Upsample a non-spatial dimension of the region by a given factor.

    Parameters
    ----------
    factor : int
        Upsampling factor.
    axis_name : str, optional
        Name of the axis to upsample; passed on to ``MapAxes.upsample``.
        Default is None.

    Returns
    -------
    region : `~RegionGeom`
        Region geometry with the upsampled axis.
    """
    axes = self.axes.upsample(factor=factor, axis_name=axis_name)
    return self._init_copy(region=self.region, wcs=self.wcs, axes=axes)
def downsample(self, factor, axis_name):
    """Downsample a non-spatial dimension of the region by a given factor.

    Parameters
    ----------
    factor : int
        Downsampling factor.
    axis_name : str
        Name of the axis to downsample.

    Returns
    -------
    region : `~RegionGeom`
        Region geometry with the downsampled axis.
    """
    axes = self.axes.downsample(factor=factor, axis_name=axis_name)
    return self._init_copy(region=self.region, wcs=self.wcs, axes=axes)
def pix_to_coord(self, pix):
    # inherited docstring
    # The single spatial pixel is centered on 0 (see `center_pix`) and spans
    # (-0.5, 0.5). Positions inside it map to the region center; positions
    # outside map to NaN. This mirrors `coord_to_pix`, which returns 0 inside
    # the region and NaN outside.
    lon = np.where(
        (-0.5 < pix[0]) & (pix[0] < 0.5),
        self.center_skydir.data.lon,
        np.nan * u.deg,
    )
    lat = np.where(
        (-0.5 < pix[1]) & (pix[1] < 0.5),
        self.center_skydir.data.lat,
        np.nan * u.deg,
    )
    coords = (lon, lat)

    for p, ax in zip(pix[self._slice_non_spatial_axes], self.axes):
        coords += (ax.pix_to_coord(p),)

    return coords
def pix_to_idx(self, pix, clip=False):
    # inherited docstring
    idxs = list(pix_tuple_to_idx(pix))
    non_spatial_idxs = idxs[self._slice_non_spatial_axes]

    for axis_index, idx in enumerate(non_spatial_idxs):
        nbin = self.axes[axis_index].nbin
        if clip:
            # Move out-of-range indices onto the first/last bin, in place.
            np.clip(idx, 0, nbin - 1, out=idx)
        else:
            # Mark out-of-range indices with -1, in place.
            np.putmask(idx, (idx < 0) | (idx >= nbin), -1)

    return tuple(idxs)
def coord_to_pix(self, coords):
    # inherited docstring
    if isinstance(coords, tuple) and len(coords) == len(self.axes):
        # Only non-spatial coordinates were given: prepend the region center.
        skydir = self.center_skydir.transform_to(self.frame)
        coords = (skydir.data.lon, skydir.data.lat) + coords
    elif isinstance(coords, dict):
        valid_keys = ["lon", "lat", "skycoord"]
        if not any(key in coords for key in valid_keys):
            # Build a new dict instead of mutating the caller's input.
            coords = {**coords, "skycoord": self.center_skydir}

    coords = MapCoord.create(coords, frame=self.frame, axis_names=self.axes.names)

    if self.region is None:
        pix = (0, 0)
    else:
        # Spatial pixel is 0 inside the region and NaN outside.
        in_region = self.contains(coords.skycoord)

        x = np.zeros(coords.skycoord.shape)
        x[~in_region] = np.nan

        y = np.zeros(coords.skycoord.shape)
        y[~in_region] = np.nan

        pix = (x, y)

    pix += self.axes.coord_to_pix(coords)
    return pix
def get_idx(self):
    """Index grids (as floats) for every dimension of the geometry.

    The grids have shape ``data_shape`` and are returned in reversed
    (``_shape``) order.
    """
    ranges = [np.arange(size, dtype=float) for size in self.data_shape]
    grids = np.meshgrid(*ranges, indexing="ij")
    return grids[::-1]
def _make_bands_cols(self):
    """Return an empty list; a region geometry contributes no columns here."""
    columns = list()
    return columns
@classmethod
def create(cls, region, **kwargs):
    """Create region geometry.

    The input region can be passed in the form of a ds9 string and will be parsed
    internally by `~regions.Regions.parse`. See:

    * https://astropy-regions.readthedocs.io/en/stable/region_io.html
    * http://ds9.si.edu/doc/ref/region.html

    Parameters
    ----------
    region : str or `~regions.SkyRegion`
        Region definition.
    **kwargs : dict
        Keyword arguments passed to `RegionGeom.__init__`.

    Returns
    -------
    geom : `RegionGeom`
        Region geometry.
    """
    return cls.from_regions(regions=region, **kwargs)
def __str__(self):
    axes = ["lon", "lat"] + [_.name for _ in self.axes]

    try:
        frame = self.center_skydir.frame.name
        lon = self.center_skydir.data.lon.deg
        lat = self.center_skydir.data.lat.deg
    except AttributeError:
        # No usable center (e.g. missing frame/data): print placeholders.
        frame, lon, lat = "", np.nan, np.nan

    return (
        f"\tregion : {self.region.__class__.__name__}\n"
        f"\taxes : {axes}\n"
        f"\tshape : {self.data_shape[::-1]}\n"
        f"\tndim : {self.ndim}\n"
        f"\tframe : {frame}\n"
        f"\tcenter : {lon:.1f} deg, {lat:.1f} deg\n"
    )
def is_allclose(self, other, rtol_axes=1e-6, atol_axes=1e-6):
    """Compare two region geometries for equivalency.

    Only the data shape and the non-spatial axes are compared; the regions
    themselves are currently not compared.

    Parameters
    ----------
    other : `RegionGeom`
        Region geometry to compare against.
    rtol_axes : float, optional
        Relative tolerance for the axes comparison.
        Default is 1e-6.
    atol_axes : float, optional
        Absolute tolerance for the axes comparison.
        Default is 1e-6.

    Returns
    -------
    is_allclose : bool
        Whether the geometry is all close.

    Raises
    ------
    TypeError
        If ``other`` is not an instance of this geometry's class.
    """
    if not isinstance(other, self.__class__):
        raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

    if self.data_shape != other.data_shape:
        return False

    axes_eq = self.axes.is_allclose(other.axes, rtol=rtol_axes, atol=atol_axes)
    # TODO: compare regions based on masks...
    regions_eq = True
    return axes_eq and regions_eq
def __eq__(self, other):
    """Equality via `is_allclose` with default tolerances; False for other types."""
    same_type = isinstance(other, self.__class__)
    return self.is_allclose(other=other) if same_type else False
def _to_region_table(self):
    """Export region to a FITS region table.

    Each sub-region is converted to pixel coordinates with ``self.wcs``. The
    WCS header is stored in the table metadata, where `from_hdulist` reads it
    to convert the pixel regions back to the sky.

    Returns
    -------
    table : `~astropy.table.Table`
        FITS region table.

    Raises
    ------
    ValueError
        If no region is defined.
    """
    if self.region is None:
        raise ValueError("Region definition required.")

    region_list = compound_region_to_regions(self.region)
    pixel_region_list = [reg.to_pixel(self.wcs) for reg in region_list]

    table = Regions(pixel_region_list).serialize(format="fits")

    header = WcsGeom(wcs=self.wcs).to_header()
    table.meta.update(header)
    return table
def to_hdulist(self, format="ogip", hdu_bands=None, hdu_region=None):
    """Convert geometry to HDU list.

    Parameters
    ----------
    format : {"ogip", "gadf", "ogip-sherpa"}
        HDU format. Default is "ogip".
    hdu_bands : str, optional
        Name or index of the HDU with the BANDS table.
        Default is None, which uses "HDU_BANDS".
    hdu_region : str, optional
        Name or index of the HDU with the region table.
        Only used for the "gadf" format (default "HDU_REGION"); all other
        formats always write the region table to the "REGION" HDU.
        Default is None.

    Returns
    -------
    hdulist : `~astropy.io.fits.HDUList`
        HDU list containing the bands HDU and, if a region is defined,
        the region HDU.
    """
    if hdu_bands is None:
        hdu_bands = "HDU_BANDS"
    if hdu_region is None:
        hdu_region = "HDU_REGION"
    if format != "gadf":
        hdu_region = "REGION"

    hdulist = fits.HDUList()
    hdulist.append(self.axes.to_table_hdu(hdu_bands=hdu_bands, format=format))

    # region HDU
    if self.region:
        region_table = self._to_region_table()
        region_hdu = fits.BinTableHDU(region_table, name=hdu_region)
        hdulist.append(region_hdu)

    return hdulist
@classmethod
def from_regions(cls, regions, **kwargs):
    """Create region geometry from list of regions.

    The regions are combined with union to a compound region.

    Parameters
    ----------
    regions : list of `~regions.SkyRegion`, `~regions.SkyRegion`, `~astropy.coordinates.SkyCoord` or str
        Regions to combine. A str is parsed as a ds9 region definition, a
        single `~regions.SkyRegion` is wrapped in a list, a
        `~astropy.coordinates.SkyCoord` is converted to a `PointSkyRegion`,
        and an empty list gives a geometry without region.
    **kwargs : dict
        Keyword arguments forwarded to `RegionGeom`.

    Returns
    -------
    geom : `RegionGeom`
        Region map geometry.
    """
    if isinstance(regions, str):
        regions = Regions.parse(data=regions, format="ds9")
    elif isinstance(regions, SkyRegion):
        regions = [regions]
    elif isinstance(regions, SkyCoord):
        regions = [PointSkyRegion(center=regions)]
    elif isinstance(regions, list) and len(regions) == 0:
        regions = None

    if regions:
        regions = regions_to_compound_region(regions)

    return cls(region=regions, **kwargs)
@classmethod
def from_hdulist(cls, hdulist, format="ogip", hdu=None):
    """Create region geometry from an HDU list.

    The region table, if present, is read and its pixel regions are converted
    back to sky regions using the WCS stored in the table metadata.

    Parameters
    ----------
    hdulist : `~astropy.io.fits.HDUList`
        HDU list.
    format : {"ogip", "ogip-arf", "gadf"}
        HDU format. Default is "ogip".
    hdu : str, optional
        Name of the HDU. For the "gadf" format it is used as prefix of the
        region ("<hdu>_REGION") and bands ("<hdu>_BANDS") HDU names.
        Default is None.

    Returns
    -------
    geom : `RegionGeom`
        Region map geometry.

    Raises
    ------
    ValueError
        If ``format`` is not supported.
    """
    region_hdu = "REGION"

    if format == "gadf" and hdu:
        region_hdu = f"{hdu}_{region_hdu}"

    if region_hdu in hdulist:
        try:
            region_table = QTable.read(hdulist[region_hdu])
            regions_pix = Regions.parse(data=region_table, format="fits")
        except TypeError:
            # NOTE(review): fall back to a plain Table when parsing the QTable
            # fails; presumably a units-related workaround — confirm.
            region_table = Table.read(hdulist[region_hdu])
            regions_pix = Regions.parse(data=region_table, format="fits")

        wcs = WcsGeom.from_header(region_table.meta).wcs

        regions = []
        for region_pix in regions_pix:
            # TODO: remove workaround once regions issue with fits serialization is sorted out
            # see https://github.com/astropy/regions/issues/400
            # requires update regions to >=v0.6
            region_pix.meta["include"] = True
            regions.append(region_pix.to_sky(wcs))

        region = regions_to_compound_region(regions)
    else:
        region, wcs = None, None

    if format == "ogip":
        hdu_bands = "EBOUNDS"
    elif format == "ogip-arf":
        hdu_bands = "SPECRESP"
    elif format == "gadf":
        hdu_bands = hdu + "_BANDS"
    else:
        raise ValueError(f"Unknown format {format}")

    axes = MapAxes.from_table_hdu(hdulist[hdu_bands], format=format)
    return cls(region=region, wcs=wcs, axes=axes)
def union(self, other):
    """Stack a `RegionGeom` by making the union.

    The region of ``other`` is merged into this geometry in place: if both
    have a region they are combined with ``region.union``; if only ``other``
    has one it is adopted as is.

    Parameters
    ----------
    other : `RegionGeom`
        Geometry to stack; must compare equal to ``self`` (same data shape
        and non-spatial axes).

    Raises
    ------
    ValueError
        If the geometries do not compare equal.
    """
    if not self == other:
        raise ValueError("Can only make union if extra axes are equivalent.")

    if other.region:
        if self.region:
            self._region = self.region.union(other.region)
        else:
            self._region = other.region
def plot_region(self, ax=None, kwargs_point=None, path_effect=None, **kwargs):
    """Plot region in the sky.

    Parameters
    ----------
    ax : `~astropy.visualization.WCSAxes`, optional
        Axes to plot on. If no axes are given,
        the region is shown using the minimal
        equivalent WCS geometry.
        Default is None.
    kwargs_point : dict, optional
        Keyword arguments passed to `~matplotlib.lines.Line2D` for plotting
        of point sources. Default is None. The dict is not modified.
    path_effect : `~matplotlib.patheffects.PathEffect`, optional
        Path effect applied to artists and lines.
        Default is None.
    **kwargs : dict
        Keyword arguments forwarded to `~regions.PixelRegion.as_artist`.

    Returns
    -------
    ax : `~astropy.visualization.WCSAxes`
        Axes to plot on. None if no region is defined, in which case only an
        info message is logged.
    """
    if not self.region:
        logging.info("Region definition required.")
        return None

    # Copy so that the caller's dict is not modified below.
    kwargs_point = dict(kwargs_point or {})

    if ax is None:
        ax = plt.gca()

    if not isinstance(ax, WCSAxes):
        # Plot an empty map of the minimal WCS geometry and use its axes.
        wcs_geom = self.to_wcs_geom()
        m = Map.from_geom(geom=wcs_geom.to_image())
        ax = m.plot(add_cbar=False, vmin=-1, vmax=0)

    kwargs.setdefault("facecolor", "None")
    kwargs.setdefault("edgecolor", "tab:blue")
    kwargs_point.setdefault("marker", "*")

    # Propagate artist properties that have a Line2D counterpart (per
    # ARTIST_TO_LINE_PROPERTIES) to the point-source style.
    for key, value in kwargs.items():
        key_point = ARTIST_TO_LINE_PROPERTIES.get(key, None)
        if key_point:
            kwargs_point[key_point] = value

    for region in compound_region_to_regions(self.region):
        region_pix = region.to_pixel(wcs=ax.wcs)

        if isinstance(region, PointSkyRegion):
            artist = region_pix.as_artist(**kwargs_point)
        else:
            artist = region_pix.as_artist(**kwargs)

        if path_effect:
            artist.set_path_effects([path_effect])

        ax.add_artist(artist)

    return ax