# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Interpolation utilities."""
from itertools import compress

import numpy as np
import scipy.interpolate
from astropy import units as u

__all__ = [
    "interpolate_profile",
    "interpolation_scale",
    "ScaledRegularGridInterpolator",
]

# Integer order associated with each interpolation method name;
# ``None`` is treated the same as "nearest" (order 0).
INTERPOLATION_ORDER = {
    None: 0,
    "nearest": 0,
    "linear": 1,
    "quadratic": 2,
    "cubic": 3,
}
class ScaledRegularGridInterpolator:
    """Thin wrapper around `scipy.interpolate.RegularGridInterpolator`.

    The values are scaled before the interpolation and back-scaled after the
    interpolation.

    Dimensions of length 1 are ignored in the interpolation of the data.

    Parameters
    ----------
    points : tuple of `~numpy.ndarray` or `~astropy.units.Quantity`
        Tuple of points passed to `RegularGridInterpolator`.
    values : `~numpy.ndarray`
        Values passed to `RegularGridInterpolator`.
    points_scale : tuple of str
        Interpolation scale used for the points. Defaults to "lin" for
        every axis.
    values_scale : {'lin', 'log', 'sqrt'}
        Interpolation scaling applied to values. If the values vary over many
        magnitudes a 'log' scaling is recommended.
    extrapolate : bool
        If True, ``bounds_error`` defaults to False and ``fill_value`` to None
        (extrapolate), unless given explicitly in ``kwargs``. Honoured both by
        `RegularGridInterpolator` and, when ``axis`` is set, by
        `scipy.interpolate.interp1d`.
    axis : int or None
        Axis along which to interpolate. If set, `scipy.interpolate.interp1d`
        (linear) is used on the first included points array instead of
        `RegularGridInterpolator`.
    **kwargs : dict
        Keyword arguments passed to `RegularGridInterpolator`. In particular
        ``method`` ({"linear", "nearest"}) sets the default interpolation
        method, which can be overwritten when calling the
        `ScaledRegularGridInterpolator` (only when ``axis`` is None).
    """

    def __init__(
        self,
        points,
        values,
        points_scale=None,
        values_scale="lin",
        extrapolate=True,
        axis=None,
        **kwargs,
    ):
        if points_scale is None:
            points_scale = ["lin"] * len(points)

        self.scale_points = [interpolation_scale(scale) for scale in points_scale]
        self.scale = interpolation_scale(values_scale)
        self.axis = axis
        # Axes with a single grid point are dropped from the interpolation.
        self._include_dimensions = [len(p) > 1 for p in points]

        values_scaled = self.scale(values)
        points_scaled = self._scale_points(points=points)

        if extrapolate:
            kwargs.setdefault("bounds_error", False)
            kwargs.setdefault("fill_value", None)

        method = kwargs.get("method", None)

        if not np.any(self._include_dimensions):
            # Every axis has length 1: only nearest-neighbour lookup is defined.
            if method != "nearest":
                raise ValueError(
                    "Interpolating scalar values requires using "
                    "method='nearest' explicitely."
                )
        else:
            # Drop the length-1 axes so values match the compressed points.
            values_scaled = np.squeeze(values_scaled)

        if axis is None:
            self._interpolate = scipy.interpolate.RegularGridInterpolator(
                points=points_scaled, values=values_scaled, **kwargs
            )
        else:
            self._interpolate = self._make_interp1d(
                points_scaled[0], values_scaled, axis, kwargs
            )

    @staticmethod
    def _make_interp1d(x, values, axis, kwargs):
        """Build a linear `interp1d` honouring the grid-interpolator bounds options.

        `RegularGridInterpolator` uses ``fill_value=None`` to mean "extrapolate";
        `interp1d` spells that as the string ``"extrapolate"`` and refuses to
        combine it with ``bounds_error=True``.
        """
        bounds_error = kwargs.get("bounds_error")
        fill_value = kwargs.get("fill_value", np.nan)

        if fill_value is None:
            fill_value = np.nan if bounds_error else "extrapolate"

        return scipy.interpolate.interp1d(
            x, values, axis=axis, bounds_error=bounds_error, fill_value=fill_value
        )

    def _scale_points(self, points):
        """Scale each points array and drop the length-1 axes (if any axis remains)."""
        points_scaled = [scale(p) for p, scale in zip(points, self.scale_points)]

        if np.any(self._include_dimensions):
            points_scaled = compress(points_scaled, self._include_dimensions)

        return tuple(points_scaled)

    def __call__(self, points, method=None, clip=True, **kwargs):
        """Interpolate data points.

        Parameters
        ----------
        points : tuple of `~numpy.ndarray` or `~astropy.units.Quantity`
            Tuple of coordinate arrays of the form (x_1, x_2, x_3, ...).
            Arrays are broadcasted internally.
        method : {None, "linear", "nearest"}
            Linear or nearest neighbour interpolation. None will choose the
            default defined on init. Ignored when ``axis`` was set on init.
        clip : bool
            Clip values at zero after interpolation.

        Returns
        -------
        values : `~numpy.ndarray` or `~astropy.units.Quantity`
            Back-scaled interpolated values.
        """
        points = self._scale_points(points=points)

        if self.axis is None:
            points = np.broadcast_arrays(*points)
            points_interp = np.stack([p.flat for p in points]).T
            values = self._interpolate(points_interp, method, **kwargs)
            values = self.scale.inverse(values.reshape(points[0].shape))
        else:
            values = self._interpolate(points[0])
            values = self.scale.inverse(values)

        if clip:
            values = np.clip(values, 0, np.inf)

        return values
def interpolation_scale(scale="lin"):
    """Return the interpolation scale object corresponding to ``scale``.

    Parameters
    ----------
    scale : {"lin", "linear", "log", "sqrt", "stat-profile"} or `InterpolationScale`
        Name of the interpolation scaling. An `InterpolationScale` instance is
        passed through unchanged.

    Returns
    -------
    scale : `InterpolationScale`
        A new scale instance for a name, or the given instance.

    Raises
    ------
    ValueError
        If ``scale`` is neither a known name nor an `InterpolationScale`.
    """
    named_scales = (
        (("lin", "linear"), LinearScale),
        (("log",), LogScale),
        (("sqrt",), SqrtScale),
        (("stat-profile",), StatProfileScale),
    )

    for aliases, scale_class in named_scales:
        if scale in aliases:
            return scale_class()

    if not isinstance(scale, InterpolationScale):
        raise ValueError(f"Not a valid value scaling mode: '{scale}'.")

    return scale
# "Copy only if needed" flag for ``u.Quantity(..., copy=...)``. NumPy 2 changed
# ``copy=False`` to mean "never copy" (raising when a copy is unavoidable, e.g.
# for Python scalars or lists) and uses ``copy=None`` for "copy if needed";
# on NumPy 1, ``copy=False`` already has the "if needed" meaning.
_COPY_IF_NEEDED = None if int(np.__version__.split(".")[0]) >= 2 else False


class InterpolationScale:
    """Interpolation scale base class.

    Subclasses implement ``_scale`` (forward transform) and ``_inverse``
    (back-transform). The first `~astropy.units.Quantity` passed to
    ``__call__`` fixes the unit (stored as ``_unit``); later inputs are
    converted to that unit before scaling, and ``inverse`` re-attaches it.
    """

    def __call__(self, values):
        """Strip (or convert) units from ``values`` and apply the forward transform."""
        if hasattr(self, "_unit"):
            values = u.Quantity(values, copy=_COPY_IF_NEEDED).to_value(self._unit)
        else:
            if isinstance(values, u.Quantity):
                self._unit = values.unit
                values = values.value

        return self._scale(values)

    def inverse(self, values):
        """Apply the back-transform and re-attach the stored unit, if any."""
        values = self._inverse(values)

        if hasattr(self, "_unit"):
            return u.Quantity(values, self._unit, copy=_COPY_IF_NEEDED)
        else:
            return values


class LogScale(InterpolationScale):
    """Logarithmic scaling.

    Values are clipped to ``tiny`` before taking the log, and back-transformed
    values within ``tiny`` of ``tiny`` are mapped to exactly zero, so zeros
    (and negative inputs) round-trip to 0.
    """

    tiny = np.finfo(np.float32).tiny

    def _scale(self, values):
        values = np.clip(values, self.tiny, np.inf)
        return np.log(values)

    @classmethod
    def _inverse(cls, values):
        output = np.exp(values)
        return np.where(abs(output) - cls.tiny <= cls.tiny, 0, output)


class SqrtScale(InterpolationScale):
    """Signed square-root scaling: ``sign(v) * sqrt(|v|)``.

    The inverse squares the values and therefore does NOT restore the sign.
    This appears deliberate: `interpolate_profile` feeds sign-flipped profiles
    through this scale and expects the unsigned profile back.
    """

    @staticmethod
    def _scale(values):
        sign = np.sign(values)
        return sign * np.sqrt(sign * values)

    @classmethod
    def _inverse(cls, values):
        return np.power(values, 2)


class StatProfileScale(InterpolationScale):
    """Statistic-profile scaling.

    Values are first multiplied by the sign of their gradient along ``axis``
    (so e.g. a non-negative profile that falls and then rises becomes
    monotonic), then transformed with ``sign(v) * sqrt(|v|)``. The inverse
    squares the values, dropping the sign again.

    Parameters
    ----------
    axis : int
        Axis along which the gradient is computed.
    """

    def __init__(self, axis=0):
        self.axis = axis

    def _scale(self, values):
        values = np.sign(np.gradient(values, axis=self.axis)) * values
        sign = np.sign(values)
        return sign * np.sqrt(sign * values)

    @classmethod
    def _inverse(cls, values):
        return np.power(values, 2)


class LinearScale(InterpolationScale):
    """Linear scaling (identity transform)."""

    @staticmethod
    def _scale(values):
        return values

    @classmethod
    def _inverse(cls, values):
        return values
def interpolate_profile(x, y, interp_scale="sqrt"):
    """Helper function to interpolate one-dimensional profiles.

    Parameters
    ----------
    x : `~numpy.ndarray`
        Array of x values
    y : `~numpy.ndarray`
        Array of y values
    interp_scale : {"sqrt", "lin"}
        Interpolation scale applied to the profile. If the profile is
        of parabolic shape, a "sqrt" scaling is recommended. In other cases or
        for fine sampled profiles a "lin" can also be used. Any value accepted
        by `interpolation_scale` is allowed.

    Returns
    -------
    interp : `ScaledRegularGridInterpolator`
        Interpolator
    """
    scale = interpolation_scale(interp_scale)

    if isinstance(scale, (LinearScale, LogScale)):
        # These inverses do not discard the sign: a sign-flipped falling branch
        # would come back negative (or log-clipped) and then be clipped to zero
        # by the interpolator, so interpolate the profile as-is.
        values = y
    else:
        # Flip the sign of the falling branch (negative gradient) so that e.g. a
        # parabolic profile becomes monotonic; the sqrt-type scales square the
        # values in their inverse, which undoes the flip.
        values = np.sign(np.gradient(y)) * y

    return ScaledRegularGridInterpolator(
        points=(x,), values=values, values_scale=scale
    )