# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Priors for Gammapy."""
import logging
import numpy as np
import astropy.units as u
from gammapy.modeling import PriorParameter, PriorParameters
from .core import ModelBase

__all__ = ["GaussianPrior", "UniformPrior", "Prior"]

log = logging.getLogger(__name__)


def _build_priorparameters_from_dict(data, default_parameters):
    """Build a `~gammapy.modeling.PriorParameters` object from input data and defaults.

    Every default prior parameter is converted to a dictionary and, if an
    entry with the same ``"name"`` exists in ``data``, updated with that
    entry. Defaults without a matching entry are kept unchanged and a warning
    is logged.

    Parameters
    ----------
    data : list of dict
        Serialised prior parameters. Each entry must contain a ``"name"`` key.
        If several entries share a name, the first one is used.
    default_parameters : iterable
        Default prior parameters; each item must provide ``to_dict()``
        returning at least the keys ``"name"``, ``"value"`` and ``"unit"``.

    Returns
    -------
    parameters : `~gammapy.modeling.PriorParameters`
        Result of ``PriorParameters.from_dict`` on the merged dictionaries,
        in the order of ``default_parameters``.
    """
    input_names = [entry["name"] for entry in data]

    par_data = []
    for par in default_parameters:
        par_dict = par.to_dict()
        try:
            # Only the lookup is guarded: a ValueError here means "name not
            # found", whereas errors from a malformed entry must propagate.
            index = input_names.index(par_dict["name"])
        except ValueError:
            log.warning(
                f"PriorParameter '{par_dict['name']}' not defined in YAML file."
                f" Using default value: {par_dict['value']} {par_dict['unit']}"
            )
        else:
            par_dict.update(data[index])
        par_data.append(par_dict)

    return PriorParameters.from_dict(par_data)
class Prior(ModelBase):
    """Prior base class.

    Subclasses declare their parameters as `~gammapy.modeling.PriorParameter`
    class attributes; these are collected into ``default_parameters`` when the
    subclass is created. Subclasses must also provide an ``evaluate`` method
    accepting the value followed by one keyword argument per parameter name.

    Parameters
    ----------
    **kwargs : dict
        One optional entry per prior parameter name, given either as a
        `~gammapy.modeling.PriorParameter` or as a value accepted by
        `~astropy.units.Quantity`, plus an optional ``weight``.
        Other keys are ignored.
    """

    _unit = ""

    def __init__(self, **kwargs):
        # Copy default parameters from the class to the instance
        default_parameters = self.default_parameters.copy()

        for par in default_parameters:
            value = kwargs.get(par.name, par)
            if not isinstance(value, PriorParameter):
                par.quantity = u.Quantity(value)
            else:
                par = value

            setattr(self, par.name, par)

        weight = kwargs.get("weight")
        # A missing or ``None`` weight falls back to 1.
        self._weight = 1 if weight is None else weight

    @property
    def parameters(self):
        """Prior parameters as a `~gammapy.modeling.PriorParameters` object."""
        return PriorParameters(
            [getattr(self, name) for name in self.default_parameters.names]
        )

    def __init_subclass__(cls, **kwargs):
        # Add priorparameters list on the model sub-class (not instances)
        cls.default_parameters = PriorParameters(
            [_ for _ in cls.__dict__.values() if isinstance(_, PriorParameter)]
        )

    @property
    def weight(self):
        """Weight multiplied to the prior when evaluated."""
        return self._weight

    @weight.setter
    def weight(self, value):
        self._weight = value

    def __call__(self, value):
        """Evaluate the prior for ``value`` and scale the result by ``weight``.

        Parameters
        ----------
        value : object with a ``value`` attribute
            Quantity-like object; only ``value.value`` is passed to
            ``evaluate``.

        Returns
        -------
        result
            ``self.weight * self.evaluate(value.value, **parameter_values)``.
        """
        # assuming the same unit as the PriorParameter here
        kwargs = {par.name: par.value for par in self.parameters}
        return self.weight * self.evaluate(value.value, **kwargs)

    def to_dict(self, full_output=False):
        """Create dictionary for YAML serialisation.

        Parameters
        ----------
        full_output : bool, optional
            If False, the ``"min"``, ``"max"`` and ``"error"`` entries of a
            parameter are dropped when they equal the default (two NaN values
            count as equal). Default is False.

        Returns
        -------
        data : dict
            Dictionary with the keys ``"type"``, ``"parameters"`` and
            ``"weight"``, nested under ``self.type`` unless that is None.
        """
        tag = self.tag[0] if isinstance(self.tag, list) else self.tag
        params = self.parameters.to_dict()

        if not full_output:
            for par, par_default in zip(params, self.default_parameters):
                init = par_default.to_dict()
                for item in ["min", "max", "error"]:
                    default = init[item]
                    # NaN never compares equal to itself, hence the explicit
                    # check.
                    if par[item] == default or (
                        np.isnan(par[item]) and np.isnan(default)
                    ):
                        del par[item]

        data = {"type": tag, "parameters": params, "weight": self.weight}

        if self.type is None:
            return data
        return {self.type: data}

    @classmethod
    def from_dict(cls, data, **kwargs):
        """Create a prior from a dictionary.

        Parameters
        ----------
        data : dict
            Dictionary with the keys ``"type"`` and ``"parameters"`` and an
            optional ``"weight"``. It may be nested under a leading
            ``"prior"`` key.
        **kwargs : dict
            Extra keyword arguments forwarded to ``cls.from_parameters``.
            A ``"weight"`` entry in ``data`` takes precedence over a
            ``weight`` given here.

        Returns
        -------
        prior
            Result of ``cls.from_parameters``.

        Raises
        ------
        ValueError
            If ``data["type"]`` is not in ``cls.tag``.
        """
        # The default avoids a bare StopIteration on an empty mapping.
        key0 = next(iter(data), None)
        if key0 == "prior":
            data = data[key0]

        if data["type"] not in cls.tag:
            raise ValueError(
                f"Invalid model type {data['type']} for class {cls.__name__}"
            )

        priorparameters = _build_priorparameters_from_dict(
            data["parameters"], cls.default_parameters
        )

        # Keep caller-supplied keyword arguments instead of discarding them.
        init_kwargs = dict(kwargs)
        if "weight" in data:
            init_kwargs["weight"] = data["weight"]
        return cls.from_parameters(priorparameters, **init_kwargs)
class GaussianPrior(Prior):
    """Gaussian prior in one dimension.

    Parameters
    ----------
    mu : float
        Mean of the Gaussian distribution. Default is 0.
    sigma : float
        Standard deviation of the Gaussian distribution. Default is 1.
    """

    tag = ["GaussianPrior"]
    _type = "prior"
    mu = PriorParameter(name="mu", value=0)
    sigma = PriorParameter(name="sigma", value=1)

    @staticmethod
    def evaluate(value, mu, sigma):
        """Return the squared deviation of ``value`` from ``mu`` in units of ``sigma``."""
        pull = (value - mu) / sigma
        return pull**2
class UniformPrior(Prior):
    """Uniform Prior.

    Returns 0 if the parameter value is in (min, max).
    1, if otherwise.

    Parameters
    ----------
    min : float
        Minimum value. Default is -inf.
    max : float
        Maximum value. Default is inf.
    """

    tag = ["UniformPrior"]
    _type = "prior"
    min = PriorParameter(name="min", value=-np.inf, unit="")
    max = PriorParameter(name="max", value=np.inf, unit="")

    @staticmethod
    def evaluate(value, min, max):
        """Evaluate the uniform prior.

        Parameters
        ----------
        value : float
            Value to test.
        min : float
            Lower bound (exclusive).
        max : float
            Upper bound (exclusive).

        Returns
        -------
        result : float
            0.0 if ``min < value < max``, 1.0 otherwise (including when
            ``value`` equals a bound or is NaN).
        """
        # The names ``min`` and ``max`` shadow the builtins but must stay:
        # they are the parameter names passed in by keyword.
        return 0.0 if min < value < max else 1.0