# Licensed under a 3-clause BSD style license - see LICENSE.rst"""Covariance class."""importnumpyasnpimportscipyimportmatplotlib.pyplotaspltfromgammapy.utils.parallelimportis_ray_initializedfrom.parameterimportParameters__all__=["Covariance"]
class Covariance:
    """Parameter covariance class.

    Parameters
    ----------
    parameters : `~gammapy.modeling.Parameters`
        Parameter list.
    data : `~numpy.ndarray`, optional
        Covariance data array. Default is None, in which case a diagonal
        matrix of the squared parameter errors is used.
    """

    def __init__(self, parameters, data=None):
        self.parameters = parameters
        if data is None:
            data = np.diag([p.error**2 for p in self.parameters])

        self._data = np.asanyarray(data, dtype=float)

    @property
    def shape(self):
        """Covariance shape, ``(npars, npars)``."""
        npars = len(self.parameters)
        return npars, npars

    @property
    def data(self):
        """Covariance data as a `~numpy.ndarray`."""
        return self._data

    @data.setter
    def data(self, value):
        # Coerce to float so the stored dtype matches what ``__init__`` stores.
        value = np.asanyarray(value, dtype=float)

        npars = len(self.parameters)
        shape = (npars, npars)
        if value.shape != shape:
            raise ValueError(
                f"Invalid covariance shape: {value.shape}, expected {shape}"
            )

        self._data = value

    @staticmethod
    def _expand_factor_matrix(matrix, parameters):
        """Expand a free-parameter covariance matrix to all parameters.

        Rows and columns belonging to frozen parameters, or to a repeated
        occurrence of a parameter already listed earlier, are filled with
        zeros. The entries of ``matrix`` fill the remaining positions in
        row-major order.
        """
        npars = len(parameters)
        matrix_expanded = np.zeros((npars, npars))

        # Build the object array once instead of once per parameter.
        pars_array = np.array(parameters)
        mask_frozen = np.array([par.frozen for par in parameters], dtype=bool)

        # Index of the first occurrence of each parameter; any later
        # occurrence is treated as a duplicate.
        first_index = [np.where(pars_array == par)[0][0] for par in parameters]
        mask_duplicate = np.array(
            [first != idx for idx, first in enumerate(first_index)], dtype=bool
        )

        mask = mask_frozen | mask_duplicate
        free_parameters = ~(mask | mask[:, np.newaxis])
        matrix_expanded[free_parameters] = matrix.ravel()
        return matrix_expanded

    @classmethod
    def from_factor_matrix(cls, parameters, matrix):
        """Set covariance from factor covariance matrix.

        Used in the optimizer interface.

        Parameters
        ----------
        parameters : `~gammapy.modeling.Parameters`
            Parameter list.
        matrix : `~numpy.ndarray`
            Covariance of the parameter factors. If it is not square over all
            parameters, it is expanded with zeros for frozen and duplicated
            parameters.

        Returns
        -------
        covariance : `Covariance`
            Covariance with each entry rescaled by the product of the
            corresponding parameter scales.
        """
        npars = len(parameters)
        if matrix.shape != (npars, npars):
            matrix = cls._expand_factor_matrix(matrix, parameters)

        scales = [par.scale for par in parameters]
        scale_matrix = np.outer(scales, scales)
        data = scale_matrix * matrix

        return cls(parameters, data=data)

    @classmethod
    def from_stack(cls, covar_list):
        """Stack sub-covariance matrices from list.

        Parameters
        ----------
        covar_list : list of `Covariance`
            List of sub-covariances.

        Returns
        -------
        covar : `Covariance`
            Stacked covariance.
        """
        parameters = Parameters.from_stack([covar.parameters for covar in covar_list])

        covar = cls(parameters)

        for subcovar in covar_list:
            covar.set_subcovariance(subcovar)

        return covar

    def get_subcovariance(self, parameters):
        """Get sub-covariance matrix.

        Parameters
        ----------
        parameters : `Parameters`
            Sub list of parameters.

        Returns
        -------
        covariance : `Covariance`
            Sub-covariance restricted to ``parameters``.
        """
        idx = [self.parameters.index(par) for par in parameters]
        data = self._data[np.ix_(idx, idx)]
        return self.__class__(parameters=parameters, data=data)

    def set_subcovariance(self, covar):
        """Set sub-covariance matrix.

        If the new block differs from the current values, every row and
        column of the affected parameters is zeroed first. This drops any
        correlations with the remaining parameters before the block is
        written.

        Parameters
        ----------
        covar : `Covariance`
            Sub-covariance.
        """
        if is_ray_initialized():
            # This copy is required to make the covariance setting work with ray
            self._data = self._data.copy()

        idx = [self.parameters.index(par) for par in covar.parameters]
        block = np.ix_(idx, idx)

        if not np.allclose(self._data[block], covar.data):
            self._data[idx, :] = 0
            self._data[:, idx] = 0

        self._data[block] = covar.data

    def plot_correlation(self, ax=None, **kwargs):
        """Plot correlation matrix.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`, optional
            Axis to plot on. Default is None, in which case a new figure sized
            to the number of parameters is created.
        **kwargs : dict
            Keyword arguments passed to `~gammapy.visualization.plot_heatmap`.

        Returns
        -------
        ax : `~matplotlib.axes.Axes`
            Matplotlib axes.
        """
        from gammapy.visualization import annotate_heatmap, plot_heatmap

        # Only open a new figure when the caller did not provide an axis;
        # otherwise an empty stray figure would be left behind.
        if ax is None:
            npars = len(self.parameters)
            figsize = (npars * 0.8, npars * 0.65)
            plt.figure(figsize=figsize)
            ax = plt.gca()

        kwargs.setdefault("cmap", "coolwarm")

        names = self.parameters.names
        im, _ = plot_heatmap(
            data=self.correlation,
            col_labels=names,
            row_labels=names,
            ax=ax,
            vmin=-1,
            vmax=1,
            cbarlabel="Correlation",
            **kwargs,
        )
        annotate_heatmap(im=im)
        return ax

    @property
    def correlation(self):
        r"""Correlation matrix as a `numpy.ndarray`.

        Correlation :math:`C` is related to covariance :math:`\Sigma` via:

        .. math::
            C_{ij} = \frac{ \Sigma_{ij} }{ \sqrt{\Sigma_{ii} \Sigma_{jj}} }

        Entries involving a zero variance (undefined ratio) are set to 0.
        """
        err = np.sqrt(np.diag(self.data))

        with np.errstate(invalid="ignore", divide="ignore"):
            correlation = self.data / np.outer(err, err)

        return np.nan_to_num(correlation)

    @property
    def scipy_mvn(self):
        """Multivariate normal distribution (`scipy.stats.multivariate_normal`).

        Centred on the current parameter values, with this covariance;
        singular covariance matrices are allowed.
        """
        # ``import scipy`` alone may not load ``scipy.stats`` on older SciPy
        # versions, so import the submodule explicitly.
        from scipy.stats import multivariate_normal

        return multivariate_normal(
            self.parameters.value, self.data, allow_singular=True
        )

    def __str__(self):
        return str(self.data)

    def __array__(self, dtype=None, copy=None):
        """Support ``np.asarray(covariance)`` including ``dtype``/``copy``."""
        if copy:
            return np.array(self.data, dtype=dtype, copy=True)
        return np.asarray(self.data, dtype=dtype)