Source code for gammapy.maps.region

import copy
import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.table import Table
from astropy.wcs import WCS
from astropy.wcs.utils import proj_plane_pixel_area, wcs_to_celestial_frame
from regions import FITSRegionParser, fits_region_objects_to_table
from gammapy.utils.regions import (
    compound_region_to_list,
    list_to_compound_region,
    make_region,
)
from .base import MapCoord
from .geom import Geom, MapAxis, make_axes, pix_tuple_to_idx
from .utils import edges_from_lo_hi
from .wcs import WcsGeom

# Names exported by ``from <this module> import *``.
__all__ = ["RegionGeom"]


class RegionGeom(Geom):
    """Map geometry representing a region on the sky.

    The spatial part of the geometry is a single 1 x 1 "pixel" standing for
    the whole region; any further dimensions come from the non-spatial axes.

    Parameters
    ----------
    region : `~regions.SkyRegion`
        Region object.
    axes : list of `MapAxis`
        Non-spatial data axes.
    wcs : `~astropy.wcs.WCS`
        Optional wcs object to project the region if needed.
    """

    is_image = False
    is_allsky = False
    is_hpx = False
    # Positions of the spatial / non-spatial entries in pix and coord tuples.
    _slice_spatial_axes = slice(0, 2)
    _slice_non_spatial_axes = slice(2, None)
    # Defaults used to build a WCS when none is passed to ``__init__``.
    projection = "TAN"
    binsz = 0.01

    def __init__(self, region, axes=None, wcs=None):
        self._region = region
        self._axes = make_axes(axes)

        # Only validate when at least one axis was actually given: an empty
        # sequence must not be indexed with ``axes[0]``.
        if axes is not None:
            n_axes = len(axes)
            if n_axes > 1 or (
                n_axes == 1 and axes[0].name not in ["energy", "energy_true"]
            ):
                raise ValueError("RegionGeom currently only supports an energy axes.")

        # Without an explicit WCS, build one centred on the region.
        if wcs is None and region is not None:
            wcs = WcsGeom.create(
                skydir=region.center,
                binsz=self.binsz,
                proj=self.projection,
                frame=self.frame,
            ).wcs

        self._wcs = wcs
        self.ndim = len(self.data_shape)

    @property
    def frame(self):
        """Name of the coordinate frame.

        Taken from the region center when available, otherwise derived
        from the WCS.
        """
        try:
            return self.region.center.frame.name
        except AttributeError:
            return wcs_to_celestial_frame(self.wcs).name

    @property
    def width(self):
        """Width of bounding box of the region"""
        if self.region is None:
            raise ValueError("Region definition required.")

        sky_regions = compound_region_to_list(self.region)
        pix_regions = [reg.to_pixel(self.wcs) for reg in sky_regions]

        # Grow the bounding box so that it encloses every sub-region.
        bbox = pix_regions[0].bounding_box
        for pix_region in pix_regions[1:]:
            bbox = bbox.union(pix_region.bounding_box)

        rectangle = bbox.to_region().to_sky(self.wcs)
        return u.Quantity([rectangle.width, rectangle.height])

    @property
    def region(self):
        """Region object passed at construction (may be None)."""
        return self._region

    @property
    def axes(self):
        """Non-spatial axes, as returned by ``make_axes``."""
        return self._axes

    @property
    def wcs(self):
        """WCS used to project the region (may be None)."""
        return self._wcs

    @property
    def center_coord(self):
        """(`astropy.coordinates.SkyCoord`)"""
        return self.pix_to_coord(self.center_pix)

    @property
    def center_pix(self):
        """Pixel coordinates of the geometry center, in reversed ``data_shape`` order."""
        return tuple((np.array(self.data_shape) - 1.0) / 2)[::-1]

    @property
    def center_skydir(self):
        """Center skydir"""
        try:
            return self.region.center
        except AttributeError:
            # No usable region center: fall back to the WCS reference pixel.
            xp, yp = self.wcs.wcs.crpix
            return SkyCoord.from_pixel(xp=xp, yp=yp, wcs=self.wcs)

    def contains(self, coords):
        """Check whether the given coordinates lie inside the region.

        Parameters
        ----------
        coords : tuple, dict or `MapCoord`
            Coordinates, converted with `MapCoord.create`.

        Returns
        -------
        Result of ``region.contains`` for the sky coordinates.

        Raises
        ------
        ValueError
            If the geometry has no region.
        """
        if self.region is None:
            raise ValueError("Region definition required.")

        coords = MapCoord.create(coords)
        return self.region.contains(coords.skycoord, self.wcs)

    def separation(self, position):
        """Angular separation between the geometry center and ``position``."""
        return self.center_skydir.separation(position)

    @property
    def data_shape(self):
        """Shape of the Numpy data array matching this geometry."""
        return self._shape[::-1]

    @property
    def _shape(self):
        """Geometry shape: a 1 x 1 spatial part followed by the axes' bin counts."""
        return tuple([1, 1] + [ax.nbin for ax in self.axes])

    def get_coord(self, frame=None):
        """Get map coordinates from the geometry.

        Parameters
        ----------
        frame : str or frame object, optional
            Frame the returned coordinates are converted to. Defaults to
            the frame of the geometry.

        Returns
        -------
        coord : `~MapCoord`
            Map coordinate object.
        """
        cdict = {"skycoord": self.center_skydir.reshape((1, 1))}

        if self.axes is not None:
            for ax in self.axes:
                cdict[ax.name] = ax.center.reshape((-1, 1, 1))

        if frame is None:
            frame = self.frame

        return MapCoord.create(cdict, frame=self.frame).to_frame(frame)

    def pad(self):
        """Not supported for `RegionGeom`; always raises NotImplementedError."""
        raise NotImplementedError("Padding of `RegionGeom` not supported")

    def crop(self):
        """Not supported for `RegionGeom`; always raises NotImplementedError."""
        raise NotImplementedError("Cropping of `RegionGeom` not supported")

    def solid_angle(self):
        """Solid angle of the region, in steradian.

        Computed as the pixel-space area of the region times the area of
        one WCS pixel.

        Raises
        ------
        ValueError
            If the geometry has no region.
        """
        if self.region is None:
            raise ValueError("Region definition required.")

        area = self.region.to_pixel(self.wcs).area
        solid_angle = area * proj_plane_pixel_area(self.wcs) * u.deg ** 2
        return solid_angle.to("sr")

    def bin_volume(self):
        """Solid angle times the bin width of the first non-spatial axis."""
        return self.solid_angle() * self.axes[0].bin_width.reshape((-1, 1, 1))

    def to_cube(self, axes):
        """Return a copy of the geometry with ``axes`` appended to the existing axes."""
        axes = copy.deepcopy(self.axes) + axes
        return self._init_copy(axes=axes)

    def to_image(self):
        """Return a copy of the geometry without non-spatial axes."""
        return self._init_copy(axes=None)

    def upsample(self, factor, axis):
        """Return a copy with the axis named ``axis`` upsampled by ``factor``."""
        axes = copy.deepcopy(self.axes)
        idx = self.get_axis_index_by_name(axis)
        axes[idx] = axes[idx].upsample(factor)
        return self._init_copy(axes=axes)

    def downsample(self, factor, axis):
        """Return a copy with the axis named ``axis`` downsampled by ``factor``."""
        axes = copy.deepcopy(self.axes)
        idx = self.get_axis_index_by_name(axis)
        axes[idx] = axes[idx].downsample(factor)
        return self._init_copy(axes=axes)

    def pix_to_coord(self, pix):
        """Convert pixel coordinates to map coordinates.

        Longitude and latitude are those of the region center where the
        corresponding pixel coordinate lies strictly between -0.5 and 0.5,
        and NaN elsewhere. Remaining entries are converted by the axes.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates, two spatial entries followed by one entry
            per non-spatial axis.

        Returns
        -------
        coords : tuple
            ``(lon, lat, ...)`` coordinates.
        """
        lon = np.where(
            (-0.5 < pix[0]) & (pix[0] < 0.5),
            self.center_skydir.data.lon,
            np.nan * u.deg,
        )
        lat = np.where(
            (-0.5 < pix[1]) & (pix[1] < 0.5),
            self.center_skydir.data.lat,
            np.nan * u.deg,
        )
        coords = (lon, lat)

        for p, ax in zip(pix[self._slice_non_spatial_axes], self.axes):
            coords += (ax.pix_to_coord(p),)

        return coords

    def pix_to_idx(self, pix, clip=False):
        """Convert pixel coordinates to pixel indices.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates.
        clip : bool
            If True, clip non-spatial indices to the valid range of their
            axis; otherwise out-of-range indices are set to -1.

        Returns
        -------
        idxs : tuple
            Pixel indices.
        """
        idxs = list(pix_tuple_to_idx(pix))

        # The index arrays are modified in place, so ``idxs`` sees the result.
        for i, idx in enumerate(idxs[self._slice_non_spatial_axes]):
            nbin = self.axes[i].nbin
            if clip:
                np.clip(idx, 0, nbin - 1, out=idx)
            else:
                np.putmask(idx, (idx < 0) | (idx >= nbin), -1)

        return tuple(idxs)

    def coord_to_pix(self, coords):
        """Convert map coordinates to pixel coordinates.

        The spatial pixel coordinates are 0 for positions inside the region
        and NaN for positions outside of it.

        Parameters
        ----------
        coords : tuple, dict or `MapCoord`
            Coordinates, converted with `MapCoord.create`.

        Returns
        -------
        pix : tuple
            ``(x, y, ...)`` pixel coordinates.

        Raises
        ------
        ValueError
            If the geometry has no region.
        """
        if self.region is None:
            raise ValueError("Region definition required.")

        coords = MapCoord.create(coords, frame=self.frame)
        in_region = self.region.contains(coords.skycoord, wcs=self.wcs)

        x = np.zeros(coords.shape)
        x[~in_region] = np.nan

        y = np.zeros(coords.shape)
        y[~in_region] = np.nan

        pix = (x, y)
        for coord, ax in zip(coords[self._slice_non_spatial_axes], self.axes):
            pix += (ax.coord_to_pix(coord),)

        return pix

    def get_idx(self):
        """Get pixel indices of the geometry as mutually broadcast arrays."""
        idxs = (0, 0)

        if self.axes is not None:
            for ax in self.axes:
                idxs += (np.arange(ax.nbin).reshape((-1, 1, 1)),)

        return np.broadcast_arrays(*idxs)

    def _make_bands_cols(self):
        pass

    @classmethod
    def create(cls, region, **kwargs):
        """Create region.

        Parameters
        ----------
        region : str or `~regions.SkyRegion`
            Region
        axes : list of `MapAxis`
            Non spatial axes.

        Returns
        -------
        geom : `RegionGeom`
            Region geometry
        """
        if isinstance(region, str):
            region = make_region(region)

        return cls(region, **kwargs)

    def __repr__(self):
        axes = ["lon", "lat"] + [ax.name for ax in self.axes]
        try:
            frame = self.center_skydir.frame.name
            lon = self.center_skydir.data.lon.deg
            lat = self.center_skydir.data.lat.deg
        except AttributeError:
            # Neither a region center nor a WCS is available.
            frame, lon, lat = "", np.nan, np.nan

        return (
            f"{self.__class__.__name__}\n\n"
            f"\tregion : {self.region.__class__.__name__}\n"
            f"\taxes : {axes}\n"
            f"\tshape : {self.data_shape[::-1]}\n"
            f"\tndim : {self.ndim}\n"
            f"\tframe : {frame}\n"
            f"\tcenter : {lon:.1f} deg, {lat:.1f} deg\n"
        )

    def __eq__(self, other):
        # Let Python handle comparisons with unrelated objects instead of
        # failing with AttributeError on ``other.data_shape``.
        if not isinstance(other, RegionGeom):
            return NotImplemented

        # check overall shape and axes compatibility
        if self.data_shape != other.data_shape:
            return False

        for axis, otheraxis in zip(self.axes, other.axes):
            if axis != otheraxis:
                return False

        # TODO: compare regions
        return True

    def _to_region_table(self):
        """Export region to a FITS region table."""
        if self.region is None:
            raise ValueError("Region definition required.")

        # TODO: make this a to_hdulist() method
        region_list = compound_region_to_list(self.region)
        pixel_region_list = [reg.to_pixel(self.wcs) for reg in region_list]
        table = fits_region_objects_to_table(pixel_region_list)
        # Store the WCS in the table header so the pixel regions can be
        # converted back to sky regions on read.
        table.meta.update(self.wcs.to_header())
        return table

    @classmethod
    def from_hdulist(cls, hdulist, format="ogip"):
        """Read region table and convert it to region list.

        Parameters
        ----------
        hdulist : HDU list
            Must contain an "EBOUNDS" extension; a "REGION" extension is
            optional.
        format : str
            Not used by this method.

        Returns
        -------
        geom : `RegionGeom`
            Region geometry with one log-interpolated "energy" axis. Region
            and wcs are None if there is no "REGION" extension.
        """
        if "REGION" in hdulist:
            region_table = Table.read(hdulist["REGION"])
            parser = FITSRegionParser(region_table)
            pix_region = parser.shapes.to_regions()
            wcs = WCS(region_table.meta)

            regions = [reg.to_sky(wcs) for reg in pix_region]
            region = list_to_compound_region(regions)
        else:
            region, wcs = None, None

        ebounds = Table.read(hdulist["EBOUNDS"])
        emin = ebounds["E_MIN"].quantity
        emax = ebounds["E_MAX"].quantity

        edges = edges_from_lo_hi(emin, emax)
        axis = MapAxis.from_edges(edges, interp="log", name="energy")
        return cls(region=region, wcs=wcs, axes=[axis])

    def union(self, other):
        """Stack a RegionGeom by making the union

        The region of this geometry is replaced in place; nothing is
        returned.

        Parameters
        ----------
        other : `RegionGeom`
            Geometry whose region is merged into this one.

        Raises
        ------
        ValueError
            If the two geometries do not compare equal.
        """
        if not self == other:
            raise ValueError("Can only make union if extra axes are equivalent.")

        if other.region:
            if self.region:
                self._region = self.region.union(other.region)
            else:
                self._region = other.region