# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Interpolation utilities."""
import html
from itertools import compress
import numpy as np
import scipy.interpolate
from astropy import units as u

__all__ = [
    "interpolate_profile",
    "interpolation_scale",
    "ScaledRegularGridInterpolator",
]

# Integer order associated with each interpolation method name.
# NOTE(review): not referenced in the visible part of this file; presumably
# a spline order consumed elsewhere - confirm before changing.
INTERPOLATION_ORDER = {None: 0, "nearest": 0, "linear": 1, "quadratic": 2, "cubic": 3}
class ScaledRegularGridInterpolator:
    """Thin wrapper around `scipy.interpolate.RegularGridInterpolator`.

    The values are scaled before the interpolation and back-scaled after the
    interpolation.

    Dimensions of length 1 are ignored in the interpolation of the data.

    Parameters
    ----------
    points : tuple of `~numpy.ndarray` or `~astropy.units.Quantity`
        Tuple of points passed to `RegularGridInterpolator`.
    values : `~numpy.ndarray`
        Values passed to `RegularGridInterpolator`.
    points_scale : tuple of str
        Interpolation scale used for the points. Default is None, which
        uses "lin" for every entry of ``points``.
    values_scale : {'lin', 'log', 'sqrt'}
        Interpolation scaling applied to values. If the values vary over many
        magnitudes a 'log' scaling is recommended.
    extrapolate : bool
        If True, ``bounds_error=False`` and ``fill_value=None`` are used as
        defaults for the keyword arguments passed to `RegularGridInterpolator`
        (values given explicitly in ``**kwargs`` take precedence).
        Default is True.
    axis : int or None
        Axis along which to interpolate. If given, `scipy.interpolate.interp1d`
        is used on the first (scaled) points array instead of
        `RegularGridInterpolator`, and ``**kwargs`` are not forwarded.
    method : {"linear", "nearest"}
        Default interpolation method. Can be overwritten when calling the
        `ScaledRegularGridInterpolator`.
    **kwargs : dict
        Keyword arguments passed to `RegularGridInterpolator`.

    Raises
    ------
    ValueError
        If every entry of ``points`` has length 1 and ``method`` is not
        explicitly set to "nearest".
    """

    def __init__(
        self,
        points,
        values,
        points_scale=None,
        values_scale="lin",
        extrapolate=True,
        axis=None,
        **kwargs,
    ):
        if points_scale is None:
            points_scale = ["lin"] * len(points)

        self.scale_points = [interpolation_scale(scale) for scale in points_scale]
        self.scale = interpolation_scale(values_scale)
        self.axis = axis

        # Axes holding a single point cannot be interpolated along; they are
        # masked out of the points (see ``_scale_points``).
        self._include_dimensions = [len(p) > 1 for p in points]

        values_scaled = self.scale(values)
        points_scaled = self._scale_points(points=points)

        # NOTE(review): both defaults are applied only when ``extrapolate`` is
        # set; the collapsed source is ambiguous about the extent of this
        # ``if`` body - confirm against upstream.
        if extrapolate:
            kwargs.setdefault("bounds_error", False)
            kwargs.setdefault("fill_value", None)

        method = kwargs.get("method", None)

        if not np.any(self._include_dimensions):
            if method != "nearest":
                raise ValueError(
                    "Interpolating scalar values requires using "
                    "method='nearest' explicitly."
                )

        if np.any(self._include_dimensions):
            # Drop the length-1 axes so the values match the compressed points.
            values_scaled = np.squeeze(values_scaled)

        if axis is None:
            self._interpolate = scipy.interpolate.RegularGridInterpolator(
                points=points_scaled, values=values_scaled, **kwargs
            )
        else:
            self._interpolate = scipy.interpolate.interp1d(
                points_scaled[0], values_scaled, axis=axis
            )

    def _repr_html_(self):
        """Return an HTML representation, falling back to the escaped `str`."""
        try:
            return self.to_html()
        except AttributeError:
            return f"<pre>{html.escape(str(self))}</pre>"

    def _scale_points(self, points):
        """Apply the per-axis point scaling and drop the length-1 axes.

        Returns a tuple of scaled point arrays. If at least one axis has more
        than one point, only those axes are kept; otherwise all are returned.
        """
        points_scaled = [scale(p) for p, scale in zip(points, self.scale_points)]

        if np.any(self._include_dimensions):
            points_scaled = compress(points_scaled, self._include_dimensions)

        return tuple(points_scaled)

    def __call__(self, points, method=None, clip=True, **kwargs):
        """Interpolate data points.

        Parameters
        ----------
        points : tuple of `~numpy.ndarray` or `~astropy.units.Quantity`
            Tuple of coordinate arrays of the form (x_1, x_2, x_3, ...). Arrays are
            broadcast internally.
        method : {None, "linear", "nearest"}
            Linear or nearest neighbour interpolation. Default is None, which
            is `method` defined on init. Not used if ``axis`` was set on init.
        clip : bool
            Clip values at zero after interpolation.
        **kwargs : dict
            Keyword arguments passed on to the underlying interpolator call.
            Not used if ``axis`` was set on init.

        Returns
        -------
        values : `~numpy.ndarray`
            Interpolated values, back-scaled with the inverse of the value
            scaling and clipped to ``[0, inf)`` if ``clip`` is True.
        """
        points = self._scale_points(points=points)

        if self.axis is None:
            points = np.broadcast_arrays(*points)
            # The grid interpolator expects an array of shape (n_points, n_dim).
            points_interp = np.stack([_.flat for _ in points]).T
            values = self._interpolate(points_interp, method, **kwargs)
            values = self.scale.inverse(values.reshape(points[0].shape))
        else:
            values = self._interpolate(points[0])
            values = self.scale.inverse(values)

        if clip:
            values = np.clip(values, 0, np.inf)

        return values
def interpolation_scale(scale="lin"):
    """Interpolation scaling.

    Parameters
    ----------
    scale : {"lin", "linear", "log", "sqrt", "stat-profile"} or `InterpolationScale`
        Choose interpolation scaling. An `InterpolationScale` instance is
        returned unchanged. Default is "lin".

    Returns
    -------
    scale : `InterpolationScale`
        New instance of the scale class selected by name, or the object that
        was passed in if it already is an `InterpolationScale`.

    Raises
    ------
    ValueError
        If ``scale`` is neither a known name nor an `InterpolationScale`.
    """
    if scale in ["lin", "linear"]:
        return LinearScale()
    elif scale == "log":
        return LogScale()
    elif scale == "sqrt":
        return SqrtScale()
    elif scale == "stat-profile":
        return StatProfileScale()
    elif isinstance(scale, InterpolationScale):
        return scale
    else:
        raise ValueError(f"Not a valid value scaling mode: '{scale}'.")
def interpolate_profile(x, y, interp_scale="sqrt"):
    """Helper function to interpolate one-dimensional profiles.

    Parameters
    ----------
    x : `~numpy.ndarray`
        Array of x values.
    y : `~numpy.ndarray`
        Array of y values.
    interp_scale : {"sqrt", "lin"}
        Interpolation scale applied to the profile. If the profile is
        of parabolic shape, a "sqrt" scaling is recommended. In other cases or
        for fine sampled profiles a "lin" can also be used.
        Default is "sqrt".

    Returns
    -------
    interp : `interp1d`
        Interpolator.

    Raises
    ------
    KeyError
        If ``interp_scale`` is not one of the supported names.
    """
    # "sqrt" selects a quadratic spline, "lin" a piecewise-linear interpolation.
    method_dict = {"sqrt": "quadratic", "lin": "linear"}
    return scipy.interpolate.interp1d(x, y, kind=method_dict[interp_scale])