# Licensed under a 3-clause BSD style license - see LICENSE.rst
import copy
import inspect
from collections.abc import Sequence
import numpy as np
import scipy
import astropy.units as u
from astropy.io import fits
from astropy.table import Column, Table, hstack
from astropy.time import Time
from astropy.utils import lazyproperty
import matplotlib.pyplot as plt
from gammapy.utils.interpolation import interpolation_scale
from gammapy.utils.time import time_ref_from_dict, time_ref_to_dict
from .utils import INVALID_INDEX, edges_from_lo_hi

__all__ = ["MapAxes", "MapAxis", "TimeMapAxis", "LabelMapAxis"]


def flat_if_equal(array):
    """Collapse a 2D array to its first row if all rows are identical.

    Parameters
    ----------
    array : `~numpy.ndarray` or `~astropy.units.Quantity`
        Input array.

    Returns
    -------
    array : `~numpy.ndarray` or `~astropy.units.Quantity`
        ``array[0]`` if ``array`` is 2D and every row equals the first one,
        otherwise ``array`` unchanged.
    """
    if array.ndim == 2 and np.all(array == array[0]):
        return array[0]
    return array


class AxisCoordInterpolator:
    """Axis coord interpolator.

    Maps between axis coordinates and fractional node indices by linear
    (or, for a single node, zero-order) interpolation in the scaled space
    given by ``interp``. Values outside the node range are extrapolated.

    Parameters
    ----------
    edges : `~numpy.ndarray`
        Node values of the axis.
    interp : str
        Interpolation scale name passed to ``interpolation_scale``.
    """

    def __init__(self, edges, interp="lin"):
        self.scale = interpolation_scale(interp)
        # Node values in scaled space, and their node indices 0..N-1.
        self.x = self.scale(edges)
        self.y = np.arange(len(edges), dtype=float)
        self.fill_value = "extrapolate"

        # A single node cannot be linearly interpolated; fall back to order 0.
        if len(edges) == 1:
            self.kind = 0
        else:
            self.kind = 1

    def coord_to_pix(self, coord):
        """Coord to pix: convert axis coordinates to fractional node indices."""
        interp_fn = scipy.interpolate.interp1d(
            x=self.x, y=self.y, kind=self.kind, fill_value=self.fill_value
        )
        return interp_fn(self.scale(coord))

    def pix_to_coord(self, pix):
        """Pix to coord: convert fractional node indices to axis coordinates."""
        interp_fn = scipy.interpolate.interp1d(
            x=self.y, y=self.x, kind=self.kind, fill_value=self.fill_value
        )
        return self.scale.inverse(interp_fn(pix))


# Human readable axis labels used for plotting, keyed by axis name.
PLOT_AXIS_LABEL = {
    "energy": "Energy",
    "energy_true": "True Energy",
    "offset": "FoV Offset",
    "rad": "Source Offset",
    "migra": "Energy / True Energy",
    "fov_lon": "FoV Lon.",
    "fov_lat": "FoV Lat.",
    "time": "Time",
}

DEFAULT_LABEL_TEMPLATE = "{quantity} [{unit}]"
class MapAxis:
    """Class representing an axis of a map.

    Provides methods for transforming to/from axis and pixel coordinates.
    An axis is defined by a sequence of node values that lie either at the
    center of each bin (``node_type='center'``) or at the bin edges
    (``node_type='edges'``). The pixel coordinate of each bin center is
    equal to its bin index (0, 1, ..). Bin edges are offset by 0.5 in pixel
    coordinates from the bin centers, such that the lower/upper edge of the
    first bin is (-0.5, 0.5).

    Parameters
    ----------
    nodes : `~numpy.ndarray` or `~astropy.units.Quantity`
        Array of node values. These will be interpreted as either bin
        edges or centers according to ``node_type``.
    interp : str
        Interpolation method used to transform between axis and pixel
        coordinates. Valid options are 'log', 'lin', and 'sqrt'.
    name : str
        Axis name
    node_type : str
        Flag indicating whether coordinate nodes correspond to pixel
        edges (node_type = 'edges') or pixel centers (node_type =
        'center'). 'center' should be used where the map values are
        defined at a specific coordinate (e.g. differential
        quantities). 'edges' should be used where map values are
        defined by an integral over coordinate intervals (e.g. a
        counts histogram).
    unit : str
        String specifying the coordinate units. Ignored if ``nodes`` is a
        `~astropy.units.Quantity`, whose own unit is used instead.
    """

    # TODO: Cache an interpolation object?
    def __init__(self, nodes, interp="lin", name="", node_type="edges", unit=""):
        if not isinstance(name, str):
            raise TypeError(f"Name must be a string, got: {type(name)!r}")

        if len(nodes) != len(np.unique(nodes)):
            raise ValueError("MapAxis: node values must be unique")

        # Nodes may be sorted in ascending or in descending order.
        sorted_nodes = np.sort(nodes)
        if not (np.all(nodes == sorted_nodes) or np.all(nodes[::-1] == sorted_nodes)):
            raise ValueError("MapAxis: node values must be sorted")

        if isinstance(nodes, u.Quantity):
            unit = nodes.unit if nodes.unit is not None else ""
            nodes = nodes.value
        else:
            nodes = np.array(nodes)

        self._name = name
        self._unit = u.Unit(unit)
        self._nodes = nodes.astype(float)
        self._node_type = node_type
        self._interp = interp

        if (self._nodes < 0).any() and interp != "lin":
            raise ValueError(
                f"Interpolation scaling {interp!r} only support for positive node values."
            )

        # Set pixel coordinate of first node
        if node_type == "edges":
            self._pix_offset = -0.5
            nbin = len(nodes) - 1
        elif node_type == "center":
            self._pix_offset = 0.0
            nbin = len(nodes)
        else:
            raise ValueError(f"Invalid node type: {node_type!r}")

        self._nbin = nbin
        self._use_center_as_plot_labels = None

    def assert_name(self, required_name):
        """Assert axis name if a specific one is required.

        Parameters
        ----------
        required_name : str
            Required axis name.

        Raises
        ------
        ValueError
            If the axis name differs from ``required_name``.
        """
        if self.name != required_name:
            raise ValueError(
                "Unexpected axis name,"
                f' expected "{required_name}", got: "{self.name}"'
            )

    def is_aligned(self, other, atol=2e-2):
        """Check if other map axis is aligned.

        Two axes are aligned if their center coordinate values map to integers
        on the other axes as well and if the interpolation modes are equivalent.

        Parameters
        ----------
        other : `MapAxis`
            Other map axis.
        atol : float
            Absolute numerical tolerance for the comparison measured in bins.

        Returns
        -------
        aligned : bool
            Whether the axes are aligned
        """
        pix = self.coord_to_pix(other.center)
        pix_other = other.coord_to_pix(self.center)
        pix_all = np.append(pix, pix_other)
        aligned = np.allclose(np.round(pix_all) - pix_all, 0, atol=atol)
        return aligned and self.interp == other.interp

    def is_allclose(self, other, **kwargs):
        """Check if other map axis is all close.

        Parameters
        ----------
        other : `MapAxis`
            Other map axis
        **kwargs : dict
            Keyword arguments forwarded to `~numpy.allclose`

        Returns
        -------
        is_allclose : bool
            Whether other axis is allclose. Always ``False`` if ``other`` is
            not an instance of the same class.
        """
        # Objects of another type are never considered close; returning a
        # bool keeps the documented contract (an exception *instance* would
        # be truthy).
        if not isinstance(other, self.__class__):
            return False

        if self.edges.shape != other.edges.shape:
            return False

        if not self.unit.is_equivalent(other.unit):
            return False

        return (
            np.allclose(self.edges, other.edges, **kwargs)
            and self._node_type == other._node_type
            and self._interp == other._interp
            and self.name.upper() == other.name.upper()
        )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        return self.is_allclose(other, rtol=1e-6, atol=1e-6)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # NOTE(review): identity-based hash while __eq__ is value based, so two
        # equal axes may hash differently (e.g. as dict keys); kept as is.
        return id(self)

    @lazyproperty
    def _transform(self):
        """Interpolate coordinates to pixel"""
        return AxisCoordInterpolator(edges=self._nodes, interp=self.interp)

    @property
    def is_energy_axis(self):
        """Whether the axis is named 'energy' or 'energy_true'."""
        return self.name in ["energy", "energy_true"]

    @property
    def interp(self):
        """Interpolation scale of the axis."""
        return self._interp

    @property
    def name(self):
        """Name of the axis."""
        return self._name

    @lazyproperty
    def edges(self):
        """Return array of bin edges."""
        pix = np.arange(self.nbin + 1, dtype=float) - 0.5
        return u.Quantity(self.pix_to_coord(pix), self._unit, copy=False)

    @property
    def edges_min(self):
        """Return array of bin edges min values."""
        return self.edges[:-1]

    @property
    def edges_max(self):
        """Return array of bin edges max values."""
        return self.edges[1:]

    @property
    def bounds(self):
        """Bounds of the axis (~astropy.units.Quantity)"""
        idx = [0, -1]
        if self.node_type == "edges":
            return self.edges[idx]
        else:
            return self.center[idx]

    @property
    def as_plot_xerr(self):
        """Return tuple of xerr to be used with plt.errorbar()"""
        return (
            self.center - self.edges_min,
            self.edges_max - self.center,
        )

    @property
    def use_center_as_plot_labels(self):
        """Use center as plot labels"""
        if self._use_center_as_plot_labels is not None:
            return self._use_center_as_plot_labels

        return self.node_type == "center"

    @use_center_as_plot_labels.setter
    def use_center_as_plot_labels(self, value):
        """Use center as plot labels"""
        self._use_center_as_plot_labels = bool(value)

    @property
    def as_plot_labels(self):
        """Return list of axis plot labels"""
        if self.use_center_as_plot_labels:
            labels = [f"{val:.2e}" for val in self.center]
        else:
            labels = [
                f"{val_min:.2e} - {val_max:.2e}"
                for val_min, val_max in self.iter_by_edges
            ]

        return labels

    @property
    def as_plot_edges(self):
        """Plot edges"""
        return self.edges

    @property
    def as_plot_center(self):
        """Plot center"""
        return self.center

    @property
    def as_plot_scale(self):
        """Plot axis scale"""
        mpl_scale = {"lin": "linear", "sqrt": "linear", "log": "log"}

        return mpl_scale[self.interp]

    def to_node_type(self, node_type):
        """Return MapAxis copy changing its node type to node_type.

        Parameters
        ----------
        node_type : str 'edges' or 'center'
            the target node type

        Returns
        -------
        axis : `~gammapy.maps.MapAxis`
            the new MapAxis
        """
        if node_type == self.node_type:
            return self
        else:
            if node_type == "center":
                nodes = self.center
            else:
                nodes = self.edges
            return self.__class__(
                nodes=nodes,
                interp=self.interp,
                name=self.name,
                node_type=node_type,
                unit=self.unit,
            )

    def rename(self, new_name):
        """Rename the axis.

        Parameters
        ----------
        new_name : str
            The new name for the axis.

        Returns
        -------
        axis : `~gammapy.maps.MapAxis`
            Renamed MapAxis
        """
        return self.copy(name=new_name)

    @property
    def iter_by_edges(self):
        """Iterate by intervals defined by the edges"""
        for value_min, value_max in zip(self.edges[:-1], self.edges[1:]):
            yield (value_min, value_max)

    @lazyproperty
    def center(self):
        """Return array of bin centers."""
        pix = np.arange(self.nbin, dtype=float)
        return u.Quantity(self.pix_to_coord(pix), self._unit, copy=False)

    @lazyproperty
    def bin_width(self):
        """Array of bin widths."""
        return np.diff(self.edges)

    @property
    def nbin(self):
        """Return number of bins."""
        return self._nbin

    @property
    def nbin_per_decade(self):
        """Return number of bins per decade (log-spaced axes only).

        Raises
        ------
        ValueError
            If the axis does not use 'log' interpolation.
        """
        if self.interp != "log":
            raise ValueError("Bins per decade can only be computed for log-spaced axes")

        if self.node_type == "edges":
            values = self.edges
        else:
            values = self.center

        ndecades = np.log10(values.max() / values.min())
        return (self._nbin / ndecades).value

    @property
    def node_type(self):
        """Return node type ('center' or 'edges')."""
        return self._node_type

    @property
    def unit(self):
        """Return coordinate axis unit."""
        return self._unit

    @classmethod
    def from_bounds(cls, lo_bnd, hi_bnd, nbin, **kwargs):
        """Generate an axis object from a lower/upper bound and number of bins.

        If node_type = 'edges' then bounds correspond to the
        lower and upper bound of the first and last bin. If node_type
        = 'center' then bounds correspond to the centers of the first
        and last bin.

        Parameters
        ----------
        lo_bnd : float
            Lower bound of first axis bin.
        hi_bnd : float
            Upper bound of last axis bin.
        nbin : int
            Number of bins.
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates. Default: 'lin'.
        node_type : {'edges', 'center'}
            Node type. Default: 'edges'.
        **kwargs : dict
            Further keyword arguments forwarded to `MapAxis`.
        """
        nbin = int(nbin)
        interp = kwargs.setdefault("interp", "lin")
        node_type = kwargs.setdefault("node_type", "edges")

        if node_type == "edges":
            nnode = nbin + 1
        elif node_type == "center":
            nnode = nbin
        else:
            raise ValueError(f"Invalid node type: {node_type!r}")

        if interp == "lin":
            nodes = np.linspace(lo_bnd, hi_bnd, nnode)
        elif interp == "log":
            nodes = np.exp(np.linspace(np.log(lo_bnd), np.log(hi_bnd), nnode))
        elif interp == "sqrt":
            nodes = np.linspace(lo_bnd**0.5, hi_bnd**0.5, nnode) ** 2.0
        else:
            raise ValueError(f"Invalid interp: {interp}")

        return cls(nodes, **kwargs)

    @classmethod
    def from_energy_edges(cls, energy_edges, unit=None, name=None, interp="log"):
        """Make an energy axis from adjacent edges.

        Parameters
        ----------
        energy_edges : `~astropy.units.Quantity`, float
            Energy edges
        unit : `~astropy.units.Unit`
            Energy unit
        name : str
            Name of the energy axis, either 'energy' or 'energy_true'
        interp: str
            interpolation mode. Default is 'log'.

        Returns
        -------
        axis : `MapAxis`
            Axis with name "energy" and interp "log".
        """
        energy_edges = u.Quantity(energy_edges, unit)

        if not energy_edges.unit.is_equivalent("TeV"):
            raise ValueError(
                f"Please provide a valid energy unit, got {energy_edges.unit} instead."
            )

        if name is None:
            name = "energy"

        if name not in ["energy", "energy_true"]:
            raise ValueError("Energy axis can only be named 'energy' or 'energy_true'")

        return cls.from_edges(energy_edges, unit=unit, interp=interp, name=name)

    @classmethod
    def from_energy_bounds(
        cls,
        energy_min,
        energy_max,
        nbin,
        unit=None,
        per_decade=False,
        name=None,
        node_type="edges",
    ):
        """Make an energy axis.

        Used frequently also to make energy grids, by making
        the axis, and then using ``axis.center`` or ``axis.edges``.

        Parameters
        ----------
        energy_min, energy_max : `~astropy.units.Quantity`, float
            Energy range
        nbin : int
            Number of bins
        unit : `~astropy.units.Unit`
            Energy unit. Defaults to the unit of ``energy_max``.
        per_decade : bool
            Whether `nbin` is given per decade.
        name : str
            Name of the energy axis, either 'energy' or 'energy_true'
        node_type : {'edges', 'center'}
            Node type of the axis. Default: 'edges'.

        Returns
        -------
        axis : `MapAxis`
            Axis with name "energy" and interp "log".
        """
        energy_min = u.Quantity(energy_min, unit)
        energy_max = u.Quantity(energy_max, unit)

        # Validate before converting, so a non-energy unit gives this message
        # rather than a unit conversion error.
        if not energy_max.unit.is_equivalent("TeV"):
            raise ValueError(
                f"Please provide a valid energy unit, got {energy_max.unit} instead."
            )

        if unit is None:
            unit = energy_max.unit

        energy_min = energy_min.to(unit)

        if per_decade:
            nbin = np.ceil(np.log10(energy_max / energy_min).value * nbin)

        if name is None:
            name = "energy"

        if name not in ["energy", "energy_true"]:
            raise ValueError("Energy axis can only be named 'energy' or 'energy_true'")

        return cls.from_bounds(
            energy_min.value,
            energy_max.value,
            nbin=nbin,
            unit=unit,
            interp="log",
            name=name,
            node_type=node_type,
        )

    @classmethod
    def from_nodes(cls, nodes, **kwargs):
        """Generate an axis object from a sequence of nodes (bin centers).

        This will create a sequence of bins with edges half-way
        between the node values. This method should be used to
        construct an axis where the bin center should lie at a
        specific value (e.g. a map of a continuous function).

        Parameters
        ----------
        nodes : `~numpy.ndarray`
            Axis nodes (bin center).
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates. Default: 'lin'.
        """
        if len(nodes) < 1:
            raise ValueError("Nodes array must have at least one element.")

        return cls(nodes, node_type="center", **kwargs)

    @classmethod
    def from_edges(cls, edges, **kwargs):
        """Generate an axis object from a sequence of bin edges.

        This method should be used to construct an axis where the bin
        edges should lie at specific values (e.g. a histogram). The
        number of bins will be one less than the number of edges.

        Parameters
        ----------
        edges : `~numpy.ndarray`
            Axis bin edges.
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates. Default: 'lin'.
        """
        if len(edges) < 2:
            raise ValueError("Edges array must have at least two elements.")

        return cls(edges, node_type="edges", **kwargs)

    def append(self, axis):
        """Append another map axis to this axis.

        Name, interp type and node type must agree between the axes. If the node
        type is "edges", the edges must be contiguous and non-overlapping.

        Parameters
        ----------
        axis : `MapAxis`
            Axis to append.

        Returns
        -------
        axis : `MapAxis`
            Appended axis

        Raises
        ------
        ValueError
            If node type, name or interp differ, or if "edges" axes are not
            contiguous.
        """
        if self.node_type != axis.node_type:
            raise ValueError(
                f"Node type must agree, got {self.node_type} and {axis.node_type}"
            )

        if self.name != axis.name:
            raise ValueError(f"Names must agree, got {self.name} and {axis.name} ")

        if self.interp != axis.interp:
            raise ValueError(
                f"Interp type must agree, got {self.interp} and {axis.interp}"
            )

        if self.node_type == "edges":
            # The first edge of ``axis`` is dropped below, so it has to coincide
            # with the last edge of this axis for the merged binning to be valid.
            if not np.allclose(self.edges[-1], axis.edges[0]):
                raise ValueError(
                    "Edges must be contiguous, got last edge "
                    f"{self.edges[-1]} and first edge {axis.edges[0]}"
                )
            edges = np.append(self.edges, axis.edges[1:])
            return self.from_edges(edges=edges, interp=self.interp, name=self.name)
        else:
            nodes = np.append(self.center, axis.center)
            return self.from_nodes(nodes=nodes, interp=self.interp, name=self.name)

    def pad(self, pad_width):
        """Pad axis by a given number of pixels.

        Parameters
        ----------
        pad_width : int or tuple of int
            A single int pads in both direction of the axis, a tuple specifies,
            which number of bins to pad at the low and high edge of the axis.

        Returns
        -------
        axis : `MapAxis`
            Padded axis
        """
        if isinstance(pad_width, tuple):
            pad_low, pad_high = pad_width
        else:
            pad_low, pad_high = pad_width, pad_width

        if self.node_type == "edges":
            pix = np.arange(-pad_low, self.nbin + pad_high + 1) - 0.5
            edges = self.pix_to_coord(pix)
            return self.from_edges(edges=edges, interp=self.interp, name=self.name)
        else:
            pix = np.arange(-pad_low, self.nbin + pad_high)
            nodes = self.pix_to_coord(pix)
            return self.from_nodes(nodes=nodes, interp=self.interp, name=self.name)

    @classmethod
    def from_stack(cls, axes):
        """Create a map axis by merging a list of other map axes.

        If the node type is "edges" the bin edges in the provided axes must be
        contiguous and non-overlapping.

        Parameters
        ----------
        axes : list of `MapAxis`
            List of map axis to merge.

        Returns
        -------
        axis : `MapAxis`
            Merged axis
        """
        ax_stacked = axes[0]

        for ax in axes[1:]:
            ax_stacked = ax_stacked.append(ax)

        return ax_stacked

    def pix_to_coord(self, pix):
        """Transform from pixel to axis coordinates.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.

        Returns
        -------
        coord : `~numpy.ndarray`
            Array of axis coordinate values.
        """
        pix = pix - self._pix_offset
        values = self._transform.pix_to_coord(pix=pix)
        return u.Quantity(values, unit=self.unit, copy=False)

    def pix_to_idx(self, pix, clip=False):
        """Convert pix to idx.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Pixel coordinates.
        clip : bool
            Choose whether to clip indices to the valid range of the
            axis. If false then indices for coordinates outside
            the axis range will be set -1.

        Returns
        -------
        idx : `~numpy.ndarray`
            Pixel indices.
        """
        if clip:
            idx = np.clip(pix, 0, self.nbin - 1)
        else:
            condition = (pix < 0) | (pix >= self.nbin)
            idx = np.where(condition, -1, pix)

        return idx

    def coord_to_pix(self, coord):
        """Transform from axis to pixel coordinates.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis coordinate values.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.
        """
        coord = u.Quantity(coord, self.unit, copy=False).value
        pix = self._transform.coord_to_pix(coord=coord)
        return np.array(pix + self._pix_offset, ndmin=1)

    def coord_to_idx(self, coord, clip=False):
        """Transform from axis coordinate to bin index.

        Bins are treated as half-open intervals ``[lo, hi)``.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis coordinate values.
        clip : bool
            Choose whether to clip the index to the valid range of the
            axis. If false then indices for values outside the axis
            range (including values equal to the upper edge) will be
            set to ``INVALID_INDEX.int``.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.
        """
        coord = u.Quantity(coord, self.unit, copy=False, ndmin=1).value
        edges = self.edges.value
        idx = np.digitize(coord, edges) - 1

        if clip:
            idx = np.clip(idx, 0, self.nbin - 1)
        else:
            # np.digitize maps a value equal to the last edge to index nbin,
            # which is out of range, so the upper edge is flagged as well.
            with np.errstate(invalid="ignore"):
                idx[coord >= edges[-1]] = INVALID_INDEX.int

        idx[~np.isfinite(coord)] = INVALID_INDEX.int

        return idx

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~MapAxis`
            Sliced axis object.
        """
        center = self.center[idx].value
        idx = self.coord_to_idx(center)
        # For edge nodes we need to keep N+1 nodes
        if self._node_type == "edges":
            idx = tuple(list(idx) + [1 + idx[-1]])

        nodes = self._nodes[(idx,)]
        return MapAxis(
            nodes,
            interp=self._interp,
            name=self._name,
            node_type=self._node_type,
            unit=self._unit,
        )

    def squash(self):
        """Create a new axis object by squashing the axis into one bin.

        Returns
        -------
        axis : `~MapAxis`
            Sliced axis object.
        """
        # TODO: Decide on handling node_type=center
        # See https://github.com/gammapy/gammapy/issues/1952
        return MapAxis.from_bounds(
            lo_bnd=self.edges[0].value,
            hi_bnd=self.edges[-1].value,
            nbin=1,
            interp=self._interp,
            name=self._name,
            unit=self._unit,
        )

    def __repr__(self):
        str_ = self.__class__.__name__
        str_ += "\n\n"
        fmt = "\t{:<10s} : {:<10s}\n"
        str_ += fmt.format("name", self.name)
        str_ += fmt.format("unit", "{!r}".format(str(self.unit)))
        str_ += fmt.format("nbins", str(self.nbin))
        str_ += fmt.format("node type", self.node_type)
        vals = self.edges if self.node_type == "edges" else self.center
        str_ += fmt.format(f"{self.node_type} min", "{:.1e}".format(vals.min()))
        str_ += fmt.format(f"{self.node_type} max", "{:.1e}".format(vals.max()))
        str_ += fmt.format("interp", self._interp)
        return str_

    def _init_copy(self, **kwargs):
        """Init map axis instance by copying missing init arguments from self."""
        argnames = inspect.getfullargspec(self.__init__).args
        argnames.remove("self")

        # Every init argument ``x`` is stored as attribute ``_x``.
        for arg in argnames:
            value = getattr(self, "_" + arg)
            kwargs.setdefault(arg, copy.deepcopy(value))

        return self.__class__(**kwargs)

    def copy(self, **kwargs):
        """Copy `MapAxis` instance and overwrite given attributes.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to overwrite in the map axis constructor.

        Returns
        -------
        copy : `MapAxis`
            Copied map axis.
        """
        return self._init_copy(**kwargs)

    def round(self, coord, clip=False):
        """Round coord to nearest axis edge.

        Parameters
        ----------
        coord : `~astropy.units.Quantity`
            Coordinates
        clip : bool
            Choose whether to clip indices to the valid range of the axis.

        Returns
        -------
        coord : `~astropy.units.Quantity`
            Rounded coordinates
        """
        edges_pix = self.coord_to_pix(coord)

        if clip:
            edges_pix = np.clip(edges_pix, -0.5, self.nbin - 0.5)

        edges_idx = np.round(edges_pix + 0.5) - 0.5
        return self.pix_to_coord(edges_idx)

    def group_table(self, edges):
        """Compute bin groups table for the map axis, given coarser bin edges.

        Parameters
        ----------
        edges : `~astropy.units.Quantity`
            Group bin edges.

        Returns
        -------
        groups : `~astropy.table.Table`
            Map axis group table.
        """
        # TODO: try to simplify this code
        if self.node_type != "edges":
            raise ValueError("Only edge based map axis can be grouped")

        # Snap the requested edges to the nearest existing edge pixels.
        edges_pix = self.coord_to_pix(edges)
        edges_pix = np.clip(edges_pix, -0.5, self.nbin - 0.5)
        edges_idx = np.round(edges_pix + 0.5) - 0.5
        edges_idx = np.unique(edges_idx)
        edges_ref = self.pix_to_coord(edges_idx)

        groups = Table()
        groups[f"{self.name}_min"] = edges_ref[:-1]
        groups[f"{self.name}_max"] = edges_ref[1:]

        groups["idx_min"] = (edges_idx[:-1] + 0.5).astype(int)
        groups["idx_max"] = (edges_idx[1:] - 0.5).astype(int)

        if len(groups) == 0:
            raise ValueError("No overlap between reference and target edges.")

        groups["bin_type"] = "normal "

        edge_idx_start, edge_ref_start = edges_idx[0], edges_ref[0]
        if edge_idx_start > 0:
            underflow = {
                "bin_type": "underflow",
                "idx_min": 0,
                "idx_max": edge_idx_start,
                f"{self.name}_min": self.pix_to_coord(-0.5),
                f"{self.name}_max": edge_ref_start,
            }
            groups.insert_row(0, vals=underflow)

        edge_idx_end, edge_ref_end = edges_idx[-1], edges_ref[-1]

        if edge_idx_end < (self.nbin - 0.5):
            overflow = {
                "bin_type": "overflow",
                "idx_min": edge_idx_end + 1,
                "idx_max": self.nbin - 1,
                f"{self.name}_min": edge_ref_end,
                f"{self.name}_max": self.pix_to_coord(self.nbin - 0.5),
            }
            groups.add_row(vals=overflow)

        group_idx = Column(np.arange(len(groups)))
        groups.add_column(group_idx, name="group_idx", index=0)
        return groups

    def upsample(self, factor):
        """Upsample map axis by a given factor.

        When up-sampling for each node specified in the axis, the corresponding
        number of sub-nodes are introduced and preserving the initial nodes. For
        node type "edges" this results in nbin * factor new bins. For node type
        "center" this results in (nbin - 1) * factor + 1 new bins.

        Parameters
        ----------
        factor : int
            Upsampling factor.

        Returns
        -------
        axis : `MapAxis`
            Upsampled map axis.
        """
        if self.node_type == "edges":
            pix = self.coord_to_pix(self.edges)
            nbin = int(self.nbin * factor) + 1
            pix_new = np.linspace(pix.min(), pix.max(), nbin)
            edges = self.pix_to_coord(pix_new)
            return self.from_edges(edges, name=self.name, interp=self.interp)
        else:
            pix = self.coord_to_pix(self.center)
            nbin = int((self.nbin - 1) * factor) + 1
            pix_new = np.linspace(pix.min(), pix.max(), nbin)
            nodes = self.pix_to_coord(pix_new)
            return self.from_nodes(nodes, name=self.name, interp=self.interp)

    def downsample(self, factor):
        """Downsample map axis by a given factor.

        When down-sampling each n-th (given by the factor) bin is selected from
        the axis while preserving the axis limits. For node type "edges" this
        requires nbin to be dividable by the factor, for node type "center" this
        requires nbin - 1 to be dividable by the factor.

        Parameters
        ----------
        factor : int
            Downsampling factor.

        Returns
        -------
        axis : `MapAxis`
            Downsampled map axis.
        """
        if self.node_type == "edges":
            nbin = self.nbin / factor

            if np.mod(nbin, 1) > 0:
                raise ValueError(
                    f"Number of {self.name} bins is not divisible by {factor}"
                )

            edges = self.edges[::factor]
            return self.from_edges(edges, name=self.name, interp=self.interp)
        else:
            nbin = (self.nbin - 1) / factor

            if np.mod(nbin, 1) > 0:
                raise ValueError(
                    f"Number of {self.name} bins - 1 is not divisible by {factor}"
                )

            nodes = self.center[::factor]
            return self.from_nodes(nodes, name=self.name, interp=self.interp)

    def to_header(self, format="ogip", idx=0):
        """Create FITS header.

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "gadf", "fgst-ccube", "fgst-template"}
            Format specification
        idx : int
            Column index of the axis.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.
        """
        header = fits.Header()

        if format in ["ogip", "ogip-sherpa"]:
            header["EXTNAME"] = "EBOUNDS", "Name of this binary table extension"
            header["TELESCOP"] = "DUMMY", "Mission/satellite name"
            header["INSTRUME"] = "DUMMY", "Instrument/detector"
            header["FILTER"] = "None", "Filter information"
            header["CHANTYPE"] = "PHA", "Type of channels (PHA, PI etc)"
            header["DETCHANS"] = self.nbin, "Total number of detector PHA channels"
            header["HDUCLASS"] = "OGIP", "Organisation devising file format"
            header["HDUCLAS1"] = "RESPONSE", "File relates to response of instrument"
            header["HDUCLAS2"] = "EBOUNDS", "This is an EBOUNDS extension"
            header["HDUVERS"] = "1.2.0", "Version of file format"
        elif format in ["gadf", "fgst-ccube", "fgst-template"]:
            key = f"AXCOLS{idx}"
            name = self.name.upper()

            if self.name == "energy" and self.node_type == "edges":
                header[key] = "E_MIN,E_MAX"
            elif self.name == "energy" and self.node_type == "center":
                header[key] = "ENERGY"
            elif self.node_type == "edges":
                header[key] = f"{name}_MIN,{name}_MAX"
            elif self.node_type == "center":
                header[key] = name
            else:
                raise ValueError(f"Invalid node type {self.node_type!r}")

            key_interp = f"INTERP{idx}"
            header[key_interp] = self.interp
        else:
            raise ValueError(f"Unknown format {format}")

        return header

    def to_table(self, format="ogip"):
        """Convert `~astropy.units.Quantity` to OGIP ``EBOUNDS`` extension.

        See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html#tth_sEc3.2  # noqa: E501

        The 'ogip-sherpa' format is equivalent to 'ogip' but uses keV energy units.

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa", "gadf-sed", "gadf-dl3", "gtpsf"}
            Format specification

        Returns
        -------
        table : `~astropy.table.Table`
            Table HDU
        """
        table = Table()
        edges = self.edges

        if format in ["ogip", "ogip-sherpa"]:
            self.assert_name("energy")

            if format == "ogip-sherpa":
                edges = edges.to("keV")

            table["CHANNEL"] = np.arange(self.nbin, dtype=np.int16)
            table["E_MIN"] = edges[:-1]
            table["E_MAX"] = edges[1:]
        elif format in ["ogip-arf", "ogip-arf-sherpa"]:
            self.assert_name("energy_true")

            if format == "ogip-arf-sherpa":
                edges = edges.to("keV")

            table["ENERG_LO"] = edges[:-1]
            table["ENERG_HI"] = edges[1:]
        elif format == "gadf-sed":
            if self.is_energy_axis:
                table["e_ref"] = self.center
                table["e_min"] = self.edges_min
                table["e_max"] = self.edges_max
        elif format == "gadf-dl3":
            from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

            if self.name == "energy":
                column_prefix = "ENERG"
            else:
                for column_prefix, spec in IRF_DL3_AXES_SPECIFICATION.items():
                    if spec["name"] == self.name:
                        break
                else:
                    # Without a match the loop variable would silently hold
                    # the last key, producing wrongly named columns.
                    raise ValueError(
                        f"No 'gadf-dl3' column specification for axis {self.name!r}"
                    )

            if self.node_type == "edges":
                edges_lo, edges_hi = edges[:-1], edges[1:]
            else:
                edges_lo, edges_hi = self.center, self.center

            table[f"{column_prefix}_LO"] = edges_lo[np.newaxis]
            table[f"{column_prefix}_HI"] = edges_hi[np.newaxis]
        elif format == "gtpsf":
            if self.name == "energy_true":
                table["Energy"] = self.center.to("MeV")
            elif self.name == "rad":
                table["Theta"] = self.center.to("deg")
            else:
                raise ValueError(
                    "Can only convert true energy or rad axis to"
                    f"'gtpsf' format, got {self.name}"
                )
        else:
            raise ValueError(f"{format} is not a valid format")

        return table

    def to_table_hdu(self, format="ogip"):
        """Convert `~astropy.units.Quantity` to OGIP ``EBOUNDS`` extension.

        See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html#tth_sEc3.2  # noqa: E501

        The 'ogip-sherpa' format is equivalent to 'ogip' but uses keV energy units.

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "gtpsf"}
            Format specification

        Returns
        -------
        hdu : `~astropy.io.fits.BinTableHDU`
            Table HDU
        """
        table = self.to_table(format=format)

        if format == "gtpsf":
            name = "THETA"
        else:
            name = None

        hdu = fits.BinTableHDU(table, name=name)

        if format in ["ogip", "ogip-sherpa"]:
            hdu.header.update(self.to_header(format=format))

        return hdu

    @classmethod
    def from_table(cls, table, format="ogip", idx=0, column_prefix=""):
        """Instantiate MapAxis from table HDU.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Table
        format : {"ogip", "ogip-arf", "fgst-ccube", "fgst-template", "fgst-bexpcube", "gadf", "gadf-dl3", "gtpsf", "gadf-sed-energy", "gadf-sed-norm", "gadf-sed-counts", "profile"}
            Format specification
        idx : int
            Column index of the axis.
        column_prefix : str
            Column name prefix of the axis, used for creating the axis.

        Returns
        -------
        axis : `MapAxis`
            Map Axis
        """
        if format in ["ogip", "fgst-ccube"]:
            energy_min = table["E_MIN"].quantity
            energy_max = table["E_MAX"].quantity
            energy_edges = (
                np.append(energy_min.value, energy_max.value[-1]) * energy_min.unit
            )
            axis = cls.from_edges(energy_edges, name="energy", interp="log")
        elif format == "ogip-arf":
            energy_min = table["ENERG_LO"].quantity
            energy_max = table["ENERG_HI"].quantity
            energy_edges = (
                np.append(energy_min.value, energy_max.value[-1]) * energy_min.unit
            )
            axis = cls.from_edges(energy_edges, name="energy_true", interp="log")
        elif format in ["fgst-template", "fgst-bexpcube"]:
            allowed_names = ["Energy", "ENERGY", "energy"]
            for colname in table.colnames:
                if colname in allowed_names:
                    tag = colname
                    break
            else:
                raise ValueError(
                    f"No energy column found, expected one of {allowed_names}"
                )

            nodes = table[tag].data
            axis = cls.from_nodes(
                nodes=nodes, name="energy_true", unit="MeV", interp="log"
            )
        elif format == "gadf":
            axcols = table.meta.get("AXCOLS{}".format(idx + 1))
            colnames = axcols.split(",")
            node_type = "edges" if len(colnames) == 2 else "center"

            # TODO: check why this extra case is needed
            if colnames[0] == "E_MIN":
                name = "energy"
            else:
                name = colnames[0].replace("_MIN", "").lower()
                # this is need for backward compatibility
                if name == "theta":
                    name = "rad"

            interp = table.meta.get("INTERP{}".format(idx + 1), "lin")

            if node_type == "center":
                nodes = np.unique(table[colnames[0]].quantity)
            else:
                edges_min = np.unique(table[colnames[0]].quantity)
                edges_max = np.unique(table[colnames[1]].quantity)
                nodes = edges_from_lo_hi(edges_min, edges_max)

            axis = MapAxis(nodes=nodes, node_type=node_type, interp=interp, name=name)
        elif format == "gadf-dl3":
            from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

            spec = IRF_DL3_AXES_SPECIFICATION[column_prefix]
            name, interp = spec["name"], spec["interp"]

            # background models are stored in reconstructed energy
            hduclass = table.meta.get("HDUCLAS2")
            if hduclass in {"BKG", "RAD_MAX"} and column_prefix == "ENERG":
                name = "energy"

            edges_lo = table[f"{column_prefix}_LO"].quantity[0]
            edges_hi = table[f"{column_prefix}_HI"].quantity[0]

            # Identical lo/hi columns encode bin centers rather than edges.
            if np.allclose(edges_hi, edges_lo):
                axis = MapAxis.from_nodes(edges_hi, interp=interp, name=name)
            else:
                edges = edges_from_lo_hi(edges_lo, edges_hi)
                axis = MapAxis.from_edges(edges, interp=interp, name=name)
        elif format == "gtpsf":
            try:
                energy = table["Energy"].data * u.MeV
                axis = MapAxis.from_nodes(energy, name="energy_true", interp="log")
            except KeyError:
                rad = table["Theta"].data * u.deg
                axis = MapAxis.from_nodes(rad, name="rad")
        elif format == "gadf-sed-energy":
            if "e_min" in table.colnames and "e_max" in table.colnames:
                e_min = flat_if_equal(table["e_min"].quantity)
                e_max = flat_if_equal(table["e_max"].quantity)
                edges = edges_from_lo_hi(e_min, e_max)
                axis = MapAxis.from_energy_edges(edges)
            elif "e_ref" in table.colnames:
                e_ref = flat_if_equal(table["e_ref"].quantity)
                axis = MapAxis.from_nodes(e_ref, name="energy", interp="log")
            else:
                raise ValueError(
                    "Either 'e_ref', 'e_min' or 'e_max' column " "names are required"
                )
        elif format == "gadf-sed-norm":
            # TODO: guess interp here
            nodes = flat_if_equal(table["norm_scan"][0])
            axis = MapAxis.from_nodes(nodes, name="norm")
        elif format == "gadf-sed-counts":
            if "datasets" in table.colnames:
                labels = np.unique(table["datasets"])
                axis = LabelMapAxis(labels=labels, name="dataset")
            else:
                shape = table["counts"].shape
                edges = np.arange(shape[-1] + 1) - 0.5
                axis = MapAxis.from_edges(edges, name="dataset")
        elif format == "profile":
            if "datasets" in table.colnames:
                labels = np.unique(table["datasets"])
                axis = LabelMapAxis(labels=labels, name="dataset")
            else:
                x_ref = table["x_ref"].quantity
                axis = MapAxis.from_nodes(x_ref, name="projected-distance")
        else:
            raise ValueError(f"Format '{format}' not supported")

        return axis

    @classmethod
    def from_table_hdu(cls, hdu, format="ogip", idx=0):
        """Instantiate MapAxis from table HDU.

        Parameters
        ----------
        hdu : `~astropy.io.fits.BinTableHDU`
            Table HDU
        format : {"ogip", "ogip-arf", "fgst-ccube", "fgst-template"}
            Format specification
        idx : int
            Column index of the axis.

        Returns
        -------
        axis : `MapAxis`
            Map Axis
        """
        table = Table.read(hdu)
        return cls.from_table(table, format=format, idx=idx)
class MapAxes(Sequence):
    """MapAxis container class.

    Axes can be looked up by position (int), by name (str, compared
    case-insensitively) or sliced (which returns a new `MapAxes`).

    Parameters
    ----------
    axes : iterable of `MapAxis`
        Map axis objects. Axis names must be unique.
    n_spatial_axes : int, optional
        Number of spatial axes. When set, `iter_with_reshape` yields shapes in
        reversed (data) order, padded with this many trailing length-one
        dimensions.
    """

    def __init__(self, axes, n_spatial_axes=None):
        # Materialize once: a generator would otherwise be exhausted by the
        # uniqueness check and leave the container empty.
        axes = list(axes)

        unique_names = []

        for ax in axes:
            if ax.name in unique_names:
                raise (
                    ValueError(f"Axis names must be unique, got: '{ax.name}' twice.")
                )
            unique_names.append(ax.name)

        self._axes = axes
        self._n_spatial_axes = n_spatial_axes

    def __len__(self):
        """Number of axes in the container."""
        return len(self._axes)

    def __getitem__(self, idx):
        """Get an axis by position or name, or a sub-selection of axes.

        Parameters
        ----------
        idx : int, str or slice
            Axis position, axis name (compared case-insensitively) or slice.

        Returns
        -------
        axis : `MapAxis` or `MapAxes`
            A single axis for int / str keys, a new `MapAxes` for a slice.

        Raises
        ------
        IndexError
            If an integer position is out of range.
        KeyError
            If no axis with the given name exists.
        TypeError
            For any other key type.
        """
        if isinstance(idx, (int, np.integer)):
            # The IndexError raised past the end is what terminates
            # ``Sequence.__iter__``.
            return self._axes[idx]
        elif isinstance(idx, str):
            for ax in self._axes:
                if ax.name.upper() == idx.upper():
                    return ax
            raise KeyError(f"No axes: {idx!r}")
        elif isinstance(idx, slice):
            return self.__class__(self._axes[idx])
        else:
            raise TypeError(f"Invalid type: {type(idx)!r}")

    @property
    def primary_axis(self):
        """Primary extra axis, defined as the longest one.

        Returns
        -------
        axis : `MapAxis`
            Map axis with the largest number of bins (first one on ties).
        """
        # get longest axis
        idx = np.argmax(self.shape)
        return self[int(idx)]

    @property
    def is_flat(self):
        """Whether every axis has exactly one bin."""
        shape = np.array(self.shape)
        return np.all(shape == 1)

    @property
    def is_unidimensional(self):
        """Whether at most one axis has more than one bin."""
        shape = np.array(self.shape)
        non_zero = np.count_nonzero(shape > 1)
        return self.is_flat or non_zero == 1

    @property
    def reverse(self):
        """New `MapAxes` with the axes in reversed order."""
        return MapAxes(self[::-1])

    @property
    def iter_with_reshape(self):
        """Iterate over ``(shape, axis)`` pairs.

        ``shape`` has ``-1`` at the axis position and ``1`` elsewhere, so that
        per-axis values can be reshaped for broadcasting. If
        ``n_spatial_axes`` is set, the shape is reversed and padded with that
        many trailing ones.
        """
        for idx, axis in enumerate(self):
            # Extract values for each axis, default: nodes
            shape = [1] * len(self)
            shape[idx] = -1
            if self._n_spatial_axes:
                shape = shape[::-1] + [1] * self._n_spatial_axes
            yield tuple(shape), axis

    def get_coord(self, mode="center", axis_name=None):
        """Get axes coordinates.

        Parameters
        ----------
        mode : {"center", "edges"}
            Coordinate center or edges.
        axis_name : str
            Axis name for which mode='edges' applies.

        Returns
        -------
        coords : dict of `~astropy.units.Quantity`
            Map coordinates, each reshaped for broadcasting
            (see `iter_with_reshape`).
        """
        coords = {}

        for shape, axis in self.iter_with_reshape:
            if mode == "edges" and axis.name == axis_name:
                coord = axis.edges
            else:
                coord = axis.center
            coords[axis.name] = coord.reshape(shape)

        return coords

    @property
    def shape(self):
        """Shape of the axes (number of bins of each axis)."""
        return tuple([ax.nbin for ax in self])

    @property
    def names(self):
        """Names of the axes."""
        return [ax.name for ax in self]

    def index(self, axis_name):
        """Get the position of the axis named ``axis_name`` in the list."""
        return self.names.index(axis_name)

    def index_data(self, axis_name):
        """Get data index of the axes.

        Parameters
        ----------
        axis_name : str
            Name of the axis.

        Returns
        -------
        idx : int
            Data index, i.e. the position counted from the end of the list.
        """
        idx = self.names.index(axis_name)
        return len(self) - idx - 1

    def upsample(self, factor, axis_name):
        """Upsample axis by a given factor.

        Parameters
        ----------
        factor : int
            Upsampling factor.
        axis_name : str
            Axis to upsample.

        Returns
        -------
        axes : `MapAxes`
            Map axes (all axes are copies).
        """
        axes = []

        for ax in self:
            if ax.name == axis_name:
                ax = ax.upsample(factor=factor)

            axes.append(ax.copy())

        return self.__class__(axes=axes)

    def replace(self, axis):
        """Replace the axis that has the same name as ``axis``.

        Parameters
        ----------
        axis : `MapAxis`
            Map axis.

        Returns
        -------
        axes : MapAxes
            Map axes.
        """
        axes = []

        for ax in self:
            if ax.name == axis.name:
                ax = axis

            axes.append(ax)

        return self.__class__(axes=axes)

    def resample(self, axis):
        """Resample axis binning.

        This method groups the existing bins into a new binning.

        Parameters
        ----------
        axis : `MapAxis`
            New map axis.

        Returns
        -------
        axes : `MapAxes`
            Axes object with resampled axis.
        """
        axis_self = self[axis.name]
        groups = axis_self.group_table(axis.edges)

        # Keep only normal bins.
        # NOTE(review): this literal must match the (possibly padded)
        # "bin_type" values produced by ``group_table`` - confirm the padding.
        groups = groups[groups["bin_type"] == "normal "]

        edges = edges_from_lo_hi(
            groups[axis.name + "_min"].quantity,
            groups[axis.name + "_max"].quantity,
        )

        axis_resampled = MapAxis.from_edges(
            edges=edges, interp=axis.interp, name=axis.name
        )

        axes = []
        for ax in self:
            if ax.name == axis.name:
                axes.append(axis_resampled)
            else:
                axes.append(ax.copy())

        return self.__class__(axes=axes)

    def downsample(self, factor, axis_name):
        """Downsample axis by a given factor.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        axis_name : str
            Axis to downsample.

        Returns
        -------
        axes : `MapAxes`
            Map axes (all axes are copies).
        """
        axes = []

        for ax in self:
            if ax.name == axis_name:
                ax = ax.downsample(factor=factor)

            axes.append(ax.copy())

        return self.__class__(axes=axes)

    def squash(self, axis_name):
        """Squash axis.

        Parameters
        ----------
        axis_name : str
            Axis to squash.

        Returns
        -------
        axes : `MapAxes`
            Axes with squashed axis.
        """
        axes = []

        for ax in self:
            if ax.name == axis_name:
                ax = ax.squash()
            axes.append(ax.copy())

        return self.__class__(axes=axes)

    def pad(self, axis_name, pad_width):
        """Pad axes.

        Parameters
        ----------
        axis_name : str
            Name of the axis to pad.
        pad_width : int or tuple of int
            Pad width.

        Returns
        -------
        axes : `MapAxes`
            Axes with padded axis.
        """
        axes = []

        for ax in self:
            if ax.name == axis_name:
                ax = ax.pad(pad_width=pad_width)
            axes.append(ax)

        return self.__class__(axes=axes)

    def drop(self, axis_name):
        """Drop an axis.

        Parameters
        ----------
        axis_name : str
            Name of the axis to remove.

        Returns
        -------
        axes : `MapAxes`
            Copies of all remaining axes.
        """
        axes = []
        for ax in self:
            if ax.name == axis_name:
                continue
            axes.append(ax.copy())

        return self.__class__(axes=axes)

    def coord_to_idx(self, coord, clip=True):
        """Transform from axis to pixel indices.

        Parameters
        ----------
        coord : dict of `~numpy.ndarray` or `MapCoord`
            Array of axis coordinate values, keyed by axis name.
        clip : bool
            Forwarded to each axis' ``coord_to_idx``.

        Returns
        -------
        pix : tuple of `~numpy.ndarray`
            Array of pixel indices values.
        """
        return tuple([ax.coord_to_idx(coord[ax.name], clip=clip) for ax in self])

    def coord_to_pix(self, coord):
        """Transform from axis to pixel coordinates.

        Parameters
        ----------
        coord : dict of `~numpy.ndarray`
            Array of axis coordinate values, keyed by axis name.

        Returns
        -------
        pix : tuple of `~numpy.ndarray`
            Array of pixel coordinate values.
        """
        return tuple([ax.coord_to_pix(coord[ax.name]) for ax in self])

    def pix_to_coord(self, pix):
        """Convert pixel coordinates to map coordinates.

        Parameters
        ----------
        pix : tuple
            Tuple of pixel coordinates, one entry per axis.

        Returns
        -------
        coords : tuple
            Tuple of map coordinates.
        """
        return tuple([ax.pix_to_coord(p) for ax, p in zip(self, pix)])

    def pix_to_idx(self, pix, clip=False):
        """Convert pix to idx.

        Parameters
        ----------
        pix : tuple of `~numpy.ndarray`
            Pixel coordinates.
        clip : bool
            Choose whether to clip indices to the valid range of the
            axis. If false then indices for coordinates outside
            the axis range will be set -1.

        Returns
        -------
        idx : tuple `~numpy.ndarray`
            Pixel indices.
        """
        idx = []

        for pix_array, ax in zip(pix, self):
            idx.append(ax.pix_to_idx(pix_array, clip=clip))

        return tuple(idx)

    def slice_by_idx(self, slices):
        """Create new axes by slicing the non-spatial axes.

        Parameters
        ----------
        slices : dict
            Dict of axes names and integers or `slice` object pairs. Contains
            one element for each non-spatial dimension. For integer indexing the
            corresponding axes is dropped from the map. Axes not specified in the
            dict are kept unchanged.

        Returns
        -------
        axes : `MapAxes`
            Sliced axes.
        """
        axes = []
        for ax in self:
            ax_slice = slices.get(ax.name, slice(None))

            # in the case where isinstance(ax_slice, int) the axes is dropped
            if isinstance(ax_slice, slice):
                ax_sliced = ax.slice(ax_slice)
                axes.append(ax_sliced.copy())

        return self.__class__(axes=axes)

    def to_header(self, format="gadf"):
        """Convert axes to FITS header.

        Parameters
        ----------
        format : {"gadf"}
            Header format.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            FITS header, merged from each axis header (axis index starts at 1).
        """
        header = fits.Header()

        for idx, ax in enumerate(self, start=1):
            header_ax = ax.to_header(format=format, idx=idx)
            header.update(header_ax)

        return header

    def to_table(self, format="gadf"):
        """Convert axes to table.

        Parameters
        ----------
        format : {"gadf", "gadf-dl3", "fgst-ccube", "fgst-template", "ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa"}  # noqa E501
            Format to use.

        Returns
        -------
        table : `~astropy.table.Table`
            Table with axis data.

        Raises
        ------
        ValueError
            If the format is not supported.
        """
        if format == "gadf-dl3":
            tables = []

            for ax in self:
                tables.append(ax.to_table(format=format))

            table = hstack(tables)
        elif format in ["gadf", "fgst-ccube", "fgst-template"]:
            table = Table()
            table["CHANNEL"] = np.arange(np.prod(self.shape))

            axes_ctr = np.meshgrid(*[ax.center for ax in self])
            axes_min = np.meshgrid(*[ax.edges_min for ax in self])
            axes_max = np.meshgrid(*[ax.edges_max for ax in self])

            for idx, ax in enumerate(self):
                name = ax.name.upper()

                if name == "ENERGY":
                    colnames = ["ENERGY", "E_MIN", "E_MAX"]
                else:
                    colnames = [name, name + "_MIN", name + "_MAX"]

                for colname, v in zip(colnames, [axes_ctr, axes_min, axes_max]):
                    # do not store edges for label axis
                    if ax.node_type == "label" and colname != name:
                        continue

                    table[colname] = np.ravel(v[idx])

                if isinstance(ax, TimeMapAxis):
                    ref_dict = time_ref_to_dict(ax.reference_time)
                    table.meta.update(ref_dict)

        elif format in ["ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa"]:
            # The list previously contained "ogip" twice, so the documented
            # "ogip-arf-sherpa" format was rejected here.
            energy_axis = self["energy"]
            table = energy_axis.to_table(format=format)
        else:
            raise ValueError(f"Unsupported format: '{format}'")

        return table

    def to_table_hdu(self, format="gadf", hdu_bands=None):
        """Make FITS table columns for map axes.

        Parameters
        ----------
        format : {"gadf", "fgst-ccube", "fgst-template", "ogip", "ogip-sherpa"}
            Format to use. ``None`` is treated as "gadf".
        hdu_bands : str
            Name of the bands HDU to use (only used for "gadf"; defaults
            to "BANDS").

        Returns
        -------
        hdu : `~astropy.io.fits.BinTableHDU`
            Bin table HDU.

        Raises
        ------
        ValueError
            If the format is unknown.
        """
        # ``None`` is accepted as an alias for "gadf"; normalize it so that
        # ``to_table`` / ``to_header`` below receive a valid format.
        if format is None:
            format = "gadf"

        # FIXME: Check whether convention is compatible with
        #  dimensionality of geometry and simplify!!!
        if format in ["fgst-ccube", "ogip", "ogip-sherpa"]:
            hdu_bands = "EBOUNDS"
        elif format == "fgst-template":
            hdu_bands = "ENERGIES"
        elif format == "gadf":
            if hdu_bands is None:
                hdu_bands = "BANDS"
        else:
            raise ValueError(f"Unknown format {format}")

        table = self.to_table(format=format)
        header = self.to_header(format=format)
        return fits.BinTableHDU(table, name=hdu_bands, header=header)

    @classmethod
    def from_table(cls, table, format="gadf"):
        """Create MapAxes from table.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU.
        format : {"gadf", "gadf-dl3", "gadf-sed", "fgst-ccube", "fgst-template", "fgst-bexpcube", "ogip", "ogip-arf", "lightcurve", "profile"}  # noqa E501
            Format to use.

        Returns
        -------
        axes : `MapAxes`
            Map axes object.

        Raises
        ------
        ValueError
            If the format is not supported.
        """
        from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

        axes = []

        # Formats that support only one energy axis
        if format in [
            "fgst-ccube",
            "fgst-template",
            "fgst-bexpcube",
            "ogip",
            "ogip-arf",
        ]:
            axes.append(MapAxis.from_table(table, format=format))
        elif format == "gadf":
            # Read axes until the first missing AXCOLS<n> keyword. There is no
            # fixed upper limit, matching ``to_header`` which writes one key
            # per axis.
            idx = 0
            while True:
                axcols = table.meta.get("AXCOLS{}".format(idx + 1))
                if axcols is None:
                    break

                # TODO: what is good way to check whether it is a given axis type?
                try:
                    axis = LabelMapAxis.from_table(table, format=format, idx=idx)
                except (KeyError, TypeError):
                    try:
                        axis = TimeMapAxis.from_table(table, format=format, idx=idx)
                    except (KeyError, ValueError):
                        axis = MapAxis.from_table(table, format=format, idx=idx)

                axes.append(axis)
                idx += 1
        elif format == "gadf-dl3":
            for column_prefix in IRF_DL3_AXES_SPECIFICATION:
                try:
                    axis = MapAxis.from_table(
                        table, format=format, column_prefix=column_prefix
                    )
                except KeyError:
                    continue

                axes.append(axis)
        elif format == "gadf-sed":
            for axis_format in ["gadf-sed-norm", "gadf-sed-energy", "gadf-sed-counts"]:
                try:
                    axis = MapAxis.from_table(table=table, format=axis_format)
                except KeyError:
                    continue
                axes.append(axis)
        elif format == "lightcurve":
            axes.extend(cls.from_table(table=table, format="gadf-sed"))
            axes.append(TimeMapAxis.from_table(table, format="lightcurve"))
        elif format == "profile":
            axes.extend(cls.from_table(table=table, format="gadf-sed"))
            axes.append(MapAxis.from_table(table, format="profile"))
        else:
            raise ValueError(f"Unsupported format: '{format}'")

        return cls(axes)

    @classmethod
    def from_default(cls, axes, n_spatial_axes=None):
        """Make a sequence of `~MapAxis` objects.

        Plain `~numpy.ndarray` entries are wrapped into `MapAxis`, and axes
        with an empty name are named ``axis<idx>``.

        Parameters
        ----------
        axes : iterable of `MapAxis` or `~numpy.ndarray`, or None
            Input axes; ``None`` gives an empty container.
        n_spatial_axes : int, optional
            Number of spatial axes, stored on the result.

        Returns
        -------
        axes : `MapAxes`
            Map axes object.
        """
        if axes is None:
            return cls([])

        axes_out = []
        for idx, ax in enumerate(axes):
            if isinstance(ax, np.ndarray):
                ax = MapAxis(ax)

            if ax.name == "":
                ax._name = f"axis{idx}"

            axes_out.append(ax)

        return cls(axes_out, n_spatial_axes=n_spatial_axes)

    def assert_names(self, required_names):
        """Assert required axis names and order.

        Parameters
        ----------
        required_names : list of str
            Required axis names, in order.

        Raises
        ------
        ValueError
            If the number, names or order of axes differ.
        """
        message = (
            "Incorrect axis order or names. Expected axis "
            f"order: {required_names}, got: {self.names}."
        )

        if not len(self) == len(required_names):
            raise ValueError(message)

        try:
            for ax, required_name in zip(self, required_names):
                ax.assert_name(required_name)

        except ValueError:
            raise ValueError(message)

    def rename_axes(self, names, new_names):
        """Rename the axes.

        Parameters
        ----------
        names : list or str
            Names of the axes.
        new_names : list or str
            New names of the axes (list must be of same length than `names`).

        Returns
        -------
        axes : `MapAxes`
            Renamed Map axes object (the original axes are not modified).

        Raises
        ------
        ValueError
            If ``names`` and ``new_names`` have different lengths.
        """
        if isinstance(names, str):
            names = [names]

        if isinstance(new_names, str):
            new_names = [new_names]

        # zip() would otherwise silently ignore the surplus entries.
        if len(names) != len(new_names):
            raise ValueError(
                f"names and new_names must have the same length, got "
                f"{len(names)} and {len(new_names)}"
            )

        axes = self.copy()

        for name, new_name in zip(names, new_names):
            axes[name]._name = new_name

        return axes

    def is_allclose(self, other, **kwargs):
        """Check if other map axes are all close.

        Parameters
        ----------
        other : `MapAxes`
            Other map axes.
        **kwargs : dict
            Keyword arguments forwarded to `~MapAxis.is_allclose`.

        Returns
        -------
        is_allclose : bool
            Whether other axes are all close (False if the number of axes
            differs).

        Raises
        ------
        TypeError
            If ``other`` is not a `MapAxes`.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        # zip() stops at the shorter input, so compare lengths explicitly.
        if len(self) != len(other):
            return False

        return np.all([ax0.is_allclose(ax1, **kwargs) for ax0, ax1 in zip(other, self)])

    def copy(self):
        """Init map axes instance by copying each axis."""
        return self.__class__([_.copy() for _ in self])
class TimeMapAxis:
    """Class representing a time axis.

    Provides methods for transforming to/from axis and pixel coordinates.
    A time axis can represent non-contiguous sequences of non-overlapping
    time intervals.

    Time intervals must be provided in increasing order.

    Parameters
    ----------
    edges_min : `~astropy.units.Quantity`
        Array of lower edge time values, as time deltas w.r.t. the reference
        time.
    edges_max : `~astropy.units.Quantity`
        Array of upper edge time values, as time deltas w.r.t. the reference
        time.
    reference_time : `~astropy.time.Time`
        Reference time to use.
    name : str
        Axis name.
    interp : str
        Interpolation method used to transform between axis and pixel
        coordinates. For now only 'lin' is supported.
    """

    node_type = "intervals"
    time_format = "iso"

    def __init__(self, edges_min, edges_max, reference_time, name="time", interp="lin"):
        self._name = name

        edges_min = u.Quantity(edges_min, ndmin=1)
        edges_max = u.Quantity(edges_max, ndmin=1)

        if not edges_min.unit.is_equivalent("s"):
            raise ValueError(
                f"Time edges min must have a valid time unit, got {edges_min.unit}"
            )

        if not edges_max.unit.is_equivalent("s"):
            raise ValueError(
                f"Time edges max must have a valid time unit, got {edges_max.unit}"
            )

        if not edges_min.shape == edges_max.shape:
            raise ValueError(
                "Edges min and edges max must have the same shape,"
                f" got {edges_min.shape} and {edges_max.shape}."
            )

        if not np.all(edges_max > edges_min):
            raise ValueError("Edges max must all be larger than edge min")

        if not np.all(edges_min == np.sort(edges_min)):
            raise ValueError("Time edges min values must be sorted")

        if not np.all(edges_max == np.sort(edges_max)):
            raise ValueError("Time edges max values must be sorted")

        if interp != "lin":
            raise NotImplementedError(
                f"Non-linear scaling scheme are not supported yet, got {interp}"
            )

        self._edges_min = edges_min
        self._edges_max = edges_max
        self._reference_time = Time(reference_time)
        self._pix_offset = -0.5
        self._interp = interp

        # Each interval must start at or after the end of the previous one.
        delta = edges_min[1:] - edges_max[:-1]
        if np.any(delta < 0 * u.s):
            raise ValueError("Time intervals must not overlap.")

    @property
    def is_contiguous(self):
        """Whether each interval starts exactly where the previous one ends."""
        return np.all(self.edges_min[1:] == self.edges_max[:-1])

    def to_contiguous(self):
        """Make the time axis contiguous.

        Gaps between intervals become additional bins.

        Returns
        -------
        axis : `TimeMapAxis`
            Contiguous time axis.
        """
        edges = np.unique(np.stack([self.edges_min, self.edges_max]))
        return self.__class__(
            edges_min=edges[:-1],
            edges_max=edges[1:],
            reference_time=self.reference_time,
            name=self.name,
            interp=self.interp,
        )

    @property
    def unit(self):
        """Axes unit."""
        return self.edges_max.unit

    @property
    def interp(self):
        """Interpolation scheme."""
        return self._interp

    @property
    def reference_time(self):
        """Return reference time used for the axis."""
        return self._reference_time

    @property
    def name(self):
        """Return axis name."""
        return self._name

    @property
    def nbin(self):
        """Return number of bins in the axis."""
        return len(self.edges_min.flatten())

    @property
    def edges_min(self):
        """Return array of bin edges min values."""
        return self._edges_min

    @property
    def edges_max(self):
        """Return array of bin edges max values."""
        return self._edges_max

    @property
    def edges(self):
        """Return array of bin edges values.

        Raises
        ------
        ValueError
            If the axis is not contiguous.
        """
        if not self.is_contiguous:
            raise ValueError("Time axis is not contiguous")

        return edges_from_lo_hi(self.edges_min, self.edges_max)

    @property
    def bounds(self):
        """Bounds of the axis (~astropy.units.Quantity)."""
        return self.edges_min[0], self.edges_max[-1]

    @property
    def time_bounds(self):
        """Bounds of the axis (`~astropy.time.Time`)."""
        t_min, t_max = self.bounds
        return t_min + self.reference_time, t_max + self.reference_time

    @property
    def time_min(self):
        """Return axis lower edges as Time objects."""
        return self._edges_min + self.reference_time

    @property
    def time_max(self):
        """Return axis upper edges as Time objects."""
        return self._edges_max + self.reference_time

    @property
    def time_delta(self):
        """Return axis time bin width (`~astropy.time.TimeDelta`)."""
        return self.time_max - self.time_min

    @property
    def time_mid(self):
        """Return time bin center (`~astropy.time.Time`)."""
        return self.time_min + 0.5 * self.time_delta

    @property
    def time_edges(self):
        """Time edges (requires a contiguous axis)."""
        return self.reference_time + self.edges

    @property
    def as_plot_xerr(self):
        """Plot x error (lower, upper) according to `time_format`.

        Raises
        ------
        ValueError
            If `time_format` is neither "iso" nor "mjd".
        """
        xn, xp = self.time_mid - self.time_min, self.time_max - self.time_mid

        if self.time_format == "iso":
            x_errn = xn.to_datetime()
            x_errp = xp.to_datetime()
        elif self.time_format == "mjd":
            x_errn = xn.to("day")
            x_errp = xp.to("day")
        else:
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return x_errn, x_errp

    @property
    def as_plot_labels(self):
        """Plot labels, one "<t_min> - <t_max>" string per interval."""
        labels = []

        for t_min, t_max in self.iter_by_edges:
            label = f"{getattr(t_min, self.time_format)} - {getattr(t_max, self.time_format)}"
            labels.append(label)

        return labels

    @property
    def as_plot_edges(self):
        """Plot edges according to `time_format`.

        Raises
        ------
        ValueError
            If `time_format` is neither "iso" nor "mjd".
        """
        if self.time_format == "iso":
            edges = self.time_edges.to_datetime()
        elif self.time_format == "mjd":
            edges = self.time_edges.mjd * u.day
        else:
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return edges

    @property
    def as_plot_center(self):
        """Plot center according to `time_format`.

        Raises
        ------
        ValueError
            If `time_format` is neither "iso" nor "mjd".
        """
        if self.time_format == "iso":
            center = self.time_mid.datetime
        elif self.time_format == "mjd":
            center = self.time_mid.mjd * u.day
        else:
            # Previously fell through to an UnboundLocalError.
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return center

    def assert_name(self, required_name):
        """Assert axis name if a specific one is required.

        Parameters
        ----------
        required_name : str
            Required axis name.

        Raises
        ------
        ValueError
            If the axis name differs.
        """
        if self.name != required_name:
            raise ValueError(
                "Unexpected axis name,"
                f' expected "{required_name}", got: "{self.name}"'
            )

    def is_allclose(self, other, **kwargs):
        """Check if other map axis is all close.

        Parameters
        ----------
        other : `TimeMapAxis`
            Other map axis.
        **kwargs : dict
            Keyword arguments forwarded to `~numpy.allclose`.

        Returns
        -------
        is_allclose : bool
            Whether other axis is allclose.

        Raises
        ------
        TypeError
            If ``other`` is not a `TimeMapAxis`.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        if self._edges_min.shape != other._edges_min.shape:
            return False

        # This will test equality at microsec level.
        delta_min = self.time_min - other.time_min
        delta_max = self.time_max - other.time_max

        return (
            np.allclose(delta_min.to_value("s"), 0.0, **kwargs)
            and np.allclose(delta_max.to_value("s"), 0.0, **kwargs)
            and self._interp == other._interp
            and self.name.upper() == other.name.upper()
        )

    @property
    def iter_by_edges(self):
        """Iterate over ``(time_min, time_max)`` pairs of each interval."""
        for time_min, time_max in zip(self.time_min, self.time_max):
            yield (time_min, time_max)

    def coord_to_idx(self, coord, **kwargs):
        """Transform from axis time coordinate to bin index.

        A time is assigned to bin ``i`` if ``time_min[i] < time <= time_max[i]``.
        Indices of time values falling outside time bins are set to
        ``INVALID_INDEX.int``.

        Parameters
        ----------
        coord : `~astropy.time.Time` or `~astropy.units.Quantity`
            Array of axis coordinate values. A quantity is interpreted
            relative to the reference time.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.
        """
        if isinstance(coord, u.Quantity):
            coord = self.reference_time + coord

        time = Time(coord[..., np.newaxis])
        delta_plus = (time - self.time_min).value > 0.0
        delta_minus = (time - self.time_max).value <= 0.0
        mask = np.logical_and(delta_plus, delta_minus)

        idx = np.asanyarray(np.argmax(mask, axis=-1))
        idx[~np.any(mask, axis=-1)] = INVALID_INDEX.int
        return idx

    def coord_to_pix(self, coord, **kwargs):
        """Transform from time to coordinate to pixel position.

        Within a bin the pixel is interpolated between the bin edges; bin
        edges map to ``idx - 0.5`` and ``idx + 0.5``. Pixels of time values
        falling outside time bins are set to ``INVALID_INDEX.float`` (before
        the -0.5 pixel offset is applied).

        Parameters
        ----------
        coord : `~astropy.time.Time` or `~astropy.units.Quantity`
            Array of axis coordinate values. A quantity is interpreted
            relative to the reference time.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel positions.
        """
        if isinstance(coord, u.Quantity):
            coord = self.reference_time + coord

        idx = np.atleast_1d(self.coord_to_idx(coord))

        valid_pix = idx != INVALID_INDEX.int
        pix = np.atleast_1d(idx).astype("float")

        # TODO: is there the equivalent of np.atleast1d for astropy.time.Time?
        if coord.shape == ():
            coord = coord.reshape((1,))

        relative_time = coord[valid_pix] - self.reference_time

        scale = interpolation_scale(self._interp)
        valid_idx = idx[valid_pix]
        s_min = scale(self._edges_min[valid_idx])
        s_max = scale(self._edges_max[valid_idx])
        s_coord = scale(relative_time.to(self._edges_min.unit))

        pix[valid_pix] += (s_coord - s_min) / (s_max - s_min)
        pix[~valid_pix] = INVALID_INDEX.float
        return pix - 0.5

    def _init_copy(self, **kwargs):
        """Init map axis instance by copying missing init arguments from self."""
        argnames = inspect.getfullargspec(self.__init__).args
        argnames.remove("self")

        for arg in argnames:
            # Each constructor argument is stored as a "_<arg>" attribute.
            value = getattr(self, "_" + arg)
            kwargs.setdefault(arg, copy.deepcopy(value))

        return self.__class__(**kwargs)

    def copy(self, **kwargs):
        """Copy `TimeMapAxis` instance and overwrite given attributes.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to overwrite in the map axis constructor.

        Returns
        -------
        copy : `TimeMapAxis`
            Copied map axis.
        """
        return self._init_copy(**kwargs)

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~TimeMapAxis`
            Sliced axis object.
        """
        return TimeMapAxis(
            self._edges_min[idx].copy(),
            self._edges_max[idx].copy(),
            self.reference_time,
            interp=self._interp,
            name=self.name,
        )

    def squash(self):
        """Create a new axis object by squashing the axis into one bin.

        The single bin spans from the first lower edge to the last upper edge.

        Returns
        -------
        axis : `~TimeMapAxis`
            Squashed axis object.
        """
        return TimeMapAxis(
            self._edges_min[0],
            self._edges_max[-1],
            self.reference_time,
            interp=self._interp,
            name=self._name,
        )

    # TODO: if we are to allow log or sqrt bins the reference time should always
    #  be strictly lower than all times
    #  Should we define a mechanism to ensure this is always correct?
    @classmethod
    def from_time_edges(cls, time_min, time_max, unit="d", interp="lin", name="time"):
        """Create TimeMapAxis from the time interval edges defined as `~astropy.time.Time`.

        The reference time is defined as the lower edge of the first interval.

        Parameters
        ----------
        time_min : `~astropy.time.Time`
            Array of lower edge times.
        time_max : `~astropy.time.Time`
            Array of upper edge times.
        unit : `~astropy.units.Unit` or str
            The unit to convert the edges to. Default is 'd' (day).
        interp : str
            Interpolation method used to transform between axis and pixel
            coordinates. Only 'lin' is currently accepted by the constructor.
        name : str
            Axis name.

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.
        """
        unit = u.Unit(unit)
        reference_time = time_min[0]
        edges_min = time_min - reference_time
        edges_max = time_max - reference_time

        return cls(
            edges_min.to(unit),
            edges_max.to(unit),
            reference_time,
            interp=interp,
            name=name,
        )

    # TODO: how configurable should that be? column names?
    @classmethod
    def from_table(cls, table, format="gadf", idx=0):
        """Create time map axis from table.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU.
        format : {"gadf", "fermi-fgl", "lightcurve"}
            Format to use.
        idx : int
            Axis index, used for the "gadf" ``AXCOLS<idx + 1>`` keyword.

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.

        Raises
        ------
        ValueError
            If the format is not supported.
        """
        if format == "gadf":
            axcols = table.meta.get("AXCOLS{}".format(idx + 1))
            colnames = axcols.split(",")
            name = colnames[0].replace("_MIN", "").lower()
            reference_time = time_ref_from_dict(table.meta)
            edges_min = np.unique(table[colnames[0]].quantity)
            edges_max = np.unique(table[colnames[1]].quantity)
        elif format == "fermi-fgl":
            meta = table.meta.copy()
            meta["MJDREFF"] = str(meta["MJDREFF"]).replace("D-4", "e-4")
            reference_time = time_ref_from_dict(meta=meta)
            name = "time"
            edges_min = table["Hist_Start"][:-1]
            edges_max = table["Hist_Start"][1:]
        elif format == "lightcurve":
            # TODO: is this a good format? It just supports mjd...
            name = "time"
            scale = table.meta.get("TIMESYS", "utc")
            time_min = Time(table["time_min"].data, format="mjd", scale=scale)
            time_max = Time(table["time_max"].data, format="mjd", scale=scale)
            reference_time = Time("2001-01-01T00:00:00")
            reference_time.format = "mjd"
            edges_min = (time_min - reference_time).to("s")
            edges_max = (time_max - reference_time).to("s")
        else:
            raise ValueError(f"Not a supported format: {format}")

        return cls(
            edges_min=edges_min,
            edges_max=edges_max,
            reference_time=reference_time,
            name=name,
        )

    @classmethod
    def from_gti(cls, gti, name="time"):
        """Create a time axis from an input GTI.

        Parameters
        ----------
        gti : `GTI`
            GTI table.
        name : str
            Axis name.

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis, with edges in seconds relative to ``gti.time_ref``.
        """
        tmin = gti.time_start - gti.time_ref
        tmax = gti.time_stop - gti.time_ref

        return cls(
            edges_min=tmin.to("s"),
            edges_max=tmax.to("s"),
            reference_time=gti.time_ref,
            name=name,
        )

    @classmethod
    def from_time_bounds(cls, time_min, time_max, nbin, unit="d", name="time"):
        """Create linearly spaced time axis from bounds.

        Parameters
        ----------
        time_min : `~astropy.time.Time`
            Lower bound.
        time_max : `~astropy.time.Time`
            Upper bound.
        nbin : int
            Number of bins.
        unit : `~astropy.units.Unit` or str
            The unit to convert the edges to. Default is 'd' (day).
        name : str
            Name of the axis.

        Returns
        -------
        axis : `TimeMapAxis`
            Contiguous time map axis.
        """
        delta = time_max - time_min
        time_edges = time_min + delta * np.linspace(0, 1, nbin + 1)
        return cls.from_time_edges(
            time_min=time_edges[:-1],
            time_max=time_edges[1:],
            interp="lin",
            unit=unit,
            name=name,
        )

    def to_header(self, format="gadf", idx=0):
        """Create FITS header.

        Parameters
        ----------
        format : {"gadf"}
            Format specification.
        idx : int
            Column index of the axis.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If the format is unknown.
        """
        header = fits.Header()

        if format == "gadf":
            key = f"AXCOLS{idx}"
            name = self.name.upper()
            header[key] = f"{name}_MIN,{name}_MAX"
            key_interp = f"INTERP{idx}"
            header[key_interp] = self.interp

            ref_dict = time_ref_to_dict(self.reference_time)
            header.update(ref_dict)
        else:
            raise ValueError(f"Unknown format {format}")

        return header
class LabelMapAxis:
    """Map axis using labels.

    Labels are stored in sorted order (as returned by `numpy.unique`), so
    bin ``i`` corresponds to the ``i``-th label in sorted order, not in the
    input order.

    Parameters
    ----------
    labels : list of str
        Labels to be used for the axis nodes. Must be unique.
    name : str
        Name of the axis.

    Raises
    ------
    ValueError
        If the labels are not unique.
    """

    node_type = "label"

    def __init__(self, labels, name=""):
        unique_labels = np.unique(labels)

        if not len(unique_labels) == len(labels):
            raise ValueError("Node labels must be unique")

        self._labels = unique_labels
        self._name = name

    @property
    def unit(self):
        """Unit (dimensionless)."""
        return u.Unit("")

    @property
    def name(self):
        """Name of the axis."""
        return self._name

    def assert_name(self, required_name):
        """Assert axis name if a specific one is required.

        Parameters
        ----------
        required_name : str
            Required axis name.

        Raises
        ------
        ValueError
            If the axis name differs.
        """
        if self.name != required_name:
            raise ValueError(
                "Unexpected axis name,"
                f' expected "{required_name}", got: "{self.name}"'
            )

    @property
    def nbin(self):
        """Number of bins."""
        return len(self._labels)

    def pix_to_coord(self, pix):
        """Transform from pixel to axis coordinates.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values (rounded to the nearest bin).

        Returns
        -------
        coord : `~numpy.ndarray`
            Array of axis label values.
        """
        idx = np.round(pix).astype(int)
        return self._labels[idx]

    def coord_to_idx(self, coord, **kwargs):
        """Transform labels to indices.

        If the label is not present an error is raised.

        Parameters
        ----------
        coord : str or `~numpy.ndarray` of str
            Array of axis label values.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.

        Raises
        ------
        ValueError
            If any label is not part of the axis.
        """
        # Compare every coordinate against all labels via broadcasting.
        coord = np.array(coord)[..., np.newaxis]
        is_equal = coord == self._labels

        if not np.all(np.any(is_equal, axis=-1)):
            label = coord[~np.any(is_equal, axis=-1)]
            raise ValueError(f"Not a valid label: {label}")

        return np.argmax(is_equal, axis=-1)

    def coord_to_pix(self, coord):
        """Transform from axis labels to pixel coordinates.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis label values.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values (bin indices as floats).
        """
        return self.coord_to_idx(coord).astype("float")

    def pix_to_idx(self, pix, clip=False):
        """Convert pix to idx.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Pixel coordinates.
        clip : bool
            Choose whether to clip indices to the valid range of the
            axis. If false then indices for coordinates outside
            the axis range will be set -1.

        Returns
        -------
        idx : `~numpy.ndarray`
            Pixel indices.
        """
        if clip:
            idx = np.clip(pix, 0, self.nbin - 1)
        else:
            condition = (pix < 0) | (pix >= self.nbin)
            idx = np.where(condition, -1, pix)

        return idx

    @property
    def center(self):
        """Center of the label axis (the labels themselves)."""
        return self._labels

    @property
    def edges(self):
        """Edges of the label axis (not defined, always raises ValueError)."""
        raise ValueError("A LabelMapAxis does not define edges")

    @property
    def edges_min(self):
        """Lower edges of the label axis (the labels themselves)."""
        return self._labels

    @property
    def edges_max(self):
        """Upper edges of the label axis (the labels themselves)."""
        return self._labels

    @property
    def bin_width(self):
        """Bin width is unity."""
        return np.ones(self.nbin)

    @property
    def as_plot_xerr(self):
        """Plot x error: half a bin on each side."""
        return 0.5 * np.ones(self.nbin)

    @property
    def as_plot_labels(self):
        """Plot labels (list of str)."""
        return self._labels.tolist()

    @property
    def as_plot_center(self):
        """Plot centers: the integer bin positions."""
        return np.arange(self.nbin)

    @property
    def as_plot_edges(self):
        """Plot edges: half-integer positions around each bin."""
        return np.arange(self.nbin + 1) - 0.5

    def to_header(self, format="gadf", idx=0):
        """Create FITS header.

        Parameters
        ----------
        format : {"gadf"}
            Format specification.
        idx : int
            Column index of the axis.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If the format is unknown.
        """
        header = fits.Header()

        if format == "gadf":
            key = f"AXCOLS{idx}"
            header[key] = self.name.upper()
        else:
            raise ValueError(f"Unknown format {format}")

        return header

    # TODO: how configurable should that be? column names?
    @classmethod
    def from_table(cls, table, format="gadf", idx=0):
        """Create label map axis from table.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU.
        format : {"gadf"}
            Format to use.
        idx : int
            Axis index, used for the ``AXCOLS<idx + 1>`` keyword.

        Returns
        -------
        axis : `LabelMapAxis`
            Label map axis.

        Raises
        ------
        TypeError
            If the label column is not a string column.
        ValueError
            If the format is not supported.
        """
        if format == "gadf":
            colname = table.meta.get("AXCOLS{}".format(idx + 1))
            column = table[colname]
            if not np.issubdtype(column.dtype, np.str_):
                raise TypeError(f"Not a valid dtype for label axis: '{column.dtype}'")
            labels = np.unique(column.data)
        else:
            raise ValueError(f"Not a supported format: {format}")

        return cls(labels=labels, name=colname.lower())

    def is_allclose(self, other, **kwargs):
        """Check if other map axis is all close.

        Parameters
        ----------
        other : `LabelMapAxis`
            Other map axis.

        Returns
        -------
        is_allclose : bool
            Whether names (case-insensitive) and labels are equal.

        Raises
        ------
        TypeError
            If ``other`` is not a `LabelMapAxis`.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        # Elementwise ``==`` on label arrays of different length does not
        # broadcast (recent NumPy raises), so compare the sizes first.
        if self.nbin != other.nbin:
            return False

        name_equal = self.name.upper() == other.name.upper()
        labels_equal = np.all(self.center == other.center)
        return name_equal & labels_equal

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        return self.is_allclose(other=other)

    def __ne__(self, other):
        return not self.__eq__(other)

    # TODO: could create sub-labels here using dashes like "label-1-a", etc.
    def upsample(self, *args, **kwargs):
        """Upsample axis (not supported, always raises NotImplementedError)."""
        raise NotImplementedError("Upsampling a LabelMapAxis is not supported")

    # TODO: could merge labels here like "label-1-label2", etc.
    def downsample(self, *args, **kwargs):
        """Downsample axis (not supported, always raises NotImplementedError)."""
        raise NotImplementedError("Downsampling a LabelMapAxis is not supported")

    # TODO: could merge labels here like "label-1-label2", etc.
    def resample(self, *args, **kwargs):
        """Resample axis (not supported, always raises NotImplementedError)."""
        raise NotImplementedError("Resampling a LabelMapAxis is not supported")

    # TODO: could create new labels here like "label-10-a"
    def pad(self, *args, **kwargs):
        """Pad axis (not supported, always raises NotImplementedError)."""
        raise NotImplementedError("Padding a LabelMapAxis is not supported")

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~LabelMapAxis`
            Sliced axis object.
        """
        return self.__class__(
            labels=self._labels[idx],
            name=self.name,
        )