# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Utilities for testing."""
import os
import sys
from numpy.testing import assert_allclose
import astropy
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.time import Time
from astropy.utils.introspection import minversion
import matplotlib.pyplot as plt
from .compat import COPY_IF_NEEDED
# Public API of this module.
__all__ = [
    "assert_quantity_allclose",
    "assert_skycoord_allclose",
    "assert_time_allclose",
    "Checker",
    "mpl_plot_check",
    "requires_data",
    "requires_dependency",
]
# NOTE(review): `minversion(astropy, "5.3.dev")` is True when the installed
# astropy is *at least* 5.3.dev, so despite its name this flag is True for
# astropy >= 5.3 (which is exactly how `modify_unit_order_astropy_5_3` uses
# it). Kept under this name for backward compatibility -- check for external
# users before renaming.
ASTROPY_LT_5_3 = minversion(astropy, "5.3.dev")
# Cache for `requires_dependency`: maps a module name to True when importing
# it raised ImportError (tests needing it are skipped), False otherwise.
_requires_dependency_cache = {}
[docs]
def requires_dependency(name):
    """Mark a test to be skipped when an optional dependency cannot be imported.

    The import is attempted at most once per ``name``; the outcome is
    remembered in the module-level ``_requires_dependency_cache`` dict.

    Parameters
    ----------
    name : str
        Name of the module to import, e.g. ``"scipy"``.

    Returns
    -------
    marker : `pytest.MarkDecorator`
        A ``pytest.mark.skipif`` marker that skips the test when importing
        ``name`` raised `ImportError`.

    Examples
    --------
    ::

        from gammapy.utils.testing import requires_dependency

        @requires_dependency('scipy')
        def test_using_scipy():
            import scipy
            ...
    """
    import pytest

    try:
        missing = _requires_dependency_cache[name]
    except KeyError:
        try:
            __import__(name)
        except ImportError:
            missing = True
        else:
            missing = False
        _requires_dependency_cache[name] = missing

    return pytest.mark.skipif(missing, reason=f"Missing dependency: {name}")
def has_data(name):
    """Tell whether a named data set is available.

    Availability is decided solely by the presence of the data set's
    environment variable; its value is not inspected.

    Parameters
    ----------
    name : {"gammapy-extra", "gammapy-data", "gamma-cat", "fermi-lat"}
        Data set identifier.

    Returns
    -------
    available : bool
        True if the matching environment variable is set.

    Raises
    ------
    ValueError
        If ``name`` is not one of the known identifiers.
    """
    env_var_by_name = (
        ("gammapy-extra", "GAMMAPY_EXTRA"),
        ("gammapy-data", "GAMMAPY_DATA"),
        ("gamma-cat", "GAMMA_CAT"),
        ("fermi-lat", "GAMMAPY_FERMI_LAT_DATA"),
    )
    for data_name, env_var in env_var_by_name:
        if name == data_name:
            return env_var in os.environ
    raise ValueError(f"Invalid name: {name}")
[docs]
def requires_data(name="gammapy-data"):
    """Mark a test to be skipped when a data set is not available.

    Parameters
    ----------
    name : str, optional
        Data set identifier understood by `has_data`.
        Default is ``"gammapy-data"``.

    Returns
    -------
    marker : `pytest.MarkDecorator`
        A ``pytest.mark.skipif`` marker that skips the test when
        ``has_data(name)`` is false.

    Raises
    ------
    TypeError
        If ``name`` is not a string -- typically because the decorator was
        written as ``@requires_data`` instead of ``@requires_data()``.

    Examples
    --------
    ::

        from gammapy.utils.testing import requires_data

        @requires_data()
        def test_using_data_files():
            filename = "$GAMMAPY_DATA/..."
            ...
    """
    import pytest

    if isinstance(name, str):
        data_available = has_data(name)
        return pytest.mark.skipif(not data_available, reason=f"Missing data: {name}")

    raise TypeError(
        "You must call @requires_data with a name (str). "
        "Usually this:  @requires_data()"
    )
def run_cli(cli, args, exit_code=0):
    """Run Click command line tool.

    Thin wrapper around `click.testing.CliRunner`. When the command's exit
    code differs from ``exit_code``, the expected and actual exit codes and
    the captured output are written to stderr. No exception is raised for
    a mismatch; the result is returned either way.

    Parameters
    ----------
    cli : click.Command
        Click command.
    args : list of str
        Argument list.
    exit_code : int, optional
        Expected exit code of the command. Default is 0.

    Returns
    -------
    result : `click.testing.Result`
        Result.
    """
    from click.testing import CliRunner

    # catch_exceptions=False: exceptions raised inside the command propagate
    # to the caller instead of being stored on the result object.
    result = CliRunner().invoke(cli, args, catch_exceptions=False)

    if result.exit_code != exit_code:
        # Report both codes so a failing test log is self-explanatory.
        sys.stderr.write(
            f"Exit code mismatch! Expected {exit_code}, got {result.exit_code}.\n"
        )
        sys.stderr.write("Output:\n")
        sys.stderr.write(result.output)
        if not result.output.endswith("\n"):
            # Keep subsequent stderr output on its own line.
            sys.stderr.write("\n")
    return result
[docs]
def assert_skycoord_allclose(actual, desired):
    """Assert that two `~astropy.coordinates.SkyCoord` objects are all-close.

    The longitude and latitude of the underlying coordinate data are
    compared in degrees using the default tolerances of
    `numpy.testing.assert_allclose`.

    Notes
    -----
    Frames can be different; they aren't checked at the moment.
    """
    for coord in (actual, desired):
        assert isinstance(coord, SkyCoord)

    actual_data, desired_data = actual.data, desired.data
    assert_allclose(actual_data.lon.deg, desired_data.lon.deg)
    assert_allclose(actual_data.lat.deg, desired_data.lat.deg)
[docs]
def assert_time_allclose(actual, desired, atol=1e-3):
    """Assert that two `~astropy.time.Time` objects are all-close.

    Both arguments must be `~astropy.time.Time` instances sharing the same
    ``scale`` and ``format``, and their difference must be zero within
    ``atol`` seconds.

    Parameters
    ----------
    actual, desired : `~astropy.time.Time`
        Times to compare.
    atol : float, optional
        Absolute tolerance in seconds. Default is 1e-3.
    """
    assert isinstance(actual, Time)
    assert isinstance(desired, Time)
    for attribute in ("scale", "format"):
        assert getattr(actual, attribute) == getattr(desired, attribute)

    offset_sec = (actual - desired).sec
    assert_allclose(offset_sec, 0, rtol=0, atol=atol)
[docs]
def assert_quantity_allclose(actual, desired, rtol=1.0e-7, atol=None, **kwargs):
    """Assert all-close for `~astropy.units.Quantity` objects.

    ``desired`` (and ``atol``, when given) are converted to the unit of
    ``actual``; the plain values are then compared with
    `numpy.testing.assert_allclose`.

    Parameters
    ----------
    actual, desired : `~astropy.units.Quantity` or array-like
        Values to compare.
    rtol : float or `~astropy.units.Quantity`, optional
        Relative tolerance; must be dimensionless. Default is 1e-7.
    atol : float or `~astropy.units.Quantity`, optional
        Absolute tolerance in units convertible to those of ``actual``.
        Default is None, meaning zero.
    **kwargs
        Forwarded to `numpy.testing.assert_allclose`.

    Raises
    ------
    astropy.units.UnitsError
        If ``desired`` or ``atol`` is not convertible to the unit of
        ``actual``, or if ``rtol`` is not dimensionless.

    Notes
    -----
    The intent is to require that ``unit`` is identical, not merely that
    quantities are all-close after unit conversion: units should only change
    on purpose, so a strict check tests more behaviour. That requirement is
    not enforced yet -- units are currently converted.
    """
    # TODO: explicitly require ``actual.unit == desired.unit`` here.
    actual_value, desired_value, rtol_value, atol_value = (
        _unquantify_allclose_arguments(actual, desired, rtol, atol)
    )
    assert_allclose(actual_value, desired_value, rtol_value, atol_value, **kwargs)
def _unquantify_allclose_arguments(actual, desired, rtol, atol):
    """Convert the arguments of an all-close check to unit-less values.

    ``desired`` and ``atol`` are expressed in the unit of ``actual`` and
    ``rtol`` in dimensionless units, so the returned values can be passed
    straight to `numpy.testing.assert_allclose`.

    Parameters
    ----------
    actual, desired : `~astropy.units.Quantity` or array-like
        Values to compare.
    rtol : float or `~astropy.units.Quantity`
        Relative tolerance; must be dimensionless.
    atol : float, `~astropy.units.Quantity` or None
        Absolute tolerance; ``None`` means zero.

    Returns
    -------
    actual, desired, rtol, atol
        The ``.value`` of each converted quantity.

    Raises
    ------
    astropy.units.UnitsError
        If ``desired`` or ``atol`` is not convertible to the unit of
        ``actual``, or if ``rtol`` cannot be converted to dimensionless.
        The original conversion error is chained as ``__cause__``.
    """
    actual = u.Quantity(actual, subok=True, copy=COPY_IF_NEEDED)
    desired = u.Quantity(desired, subok=True, copy=COPY_IF_NEEDED)
    try:
        desired = desired.to(actual.unit)
    except u.UnitsError as exc:
        raise u.UnitsError(
            f"Units for 'desired' ({desired.unit}) and 'actual' ({actual.unit}) "
            "are not convertible"
        ) from exc

    if atol is None:
        # by default, we assume an absolute tolerance of 0
        atol = u.Quantity(0)
    else:
        atol = u.Quantity(atol, subok=True, copy=COPY_IF_NEEDED)
        try:
            atol = atol.to(actual.unit)
        except u.UnitsError as exc:
            raise u.UnitsError(
                f"Units for 'atol' ({atol.unit}) and 'actual' ({actual.unit}) "
                "are not convertible"
            ) from exc

    rtol = u.Quantity(rtol, subok=True, copy=COPY_IF_NEEDED)
    try:
        rtol = rtol.to(u.dimensionless_unscaled)
    except Exception as exc:
        # Deliberately broad: any failure to make rtol dimensionless is
        # reported to callers as a UnitsError.
        raise u.UnitsError("`rtol` should be dimensionless") from exc

    return actual.value, desired.value, rtol.value, atol.value
[docs]
def mpl_plot_check():
    """Matplotlib plotting test context manager.

    Creates a new figure on ``__enter__``. On ``__exit__`` the current
    figure is saved as PNG to an in-memory byte buffer, which triggers a
    render of the figure and can surface errors in the plotting code; the
    current figure is then closed. Writing to memory is faster than writing
    to disk.

    If the ``with`` body raises, the render step is skipped so that the
    original exception is not masked by a secondary rendering error, and
    the exception propagates unchanged. The current figure is closed in
    every case, even if rendering itself fails.

    Returns
    -------
    context : object
        Context manager to use as ``with mpl_plot_check(): ...``.
    """
    from io import BytesIO

    class MPLPlotCheck:
        def __enter__(self):
            plt.figure()

        def __exit__(self, exc_type, exc_value, traceback):
            try:
                if exc_type is None:
                    # Saving forces a full draw, exposing rendering errors.
                    plt.savefig(BytesIO(), format="png")
            finally:
                # Always release the figure so figures don't pile up
                # across tests when rendering fails.
                plt.close()
            # Returning None lets any exception from the body propagate.

    return MPLPlotCheck()
[docs]
class Checker:
    """Base class for checker classes in Gammapy.

    Subclasses define a ``CHECKS`` mapping from a check name to the name of
    a method on the subclass. Each such method must return an iterable
    (e.g. be a generator) of check results.
    """

    def run(self, checks="all"):
        """Run the selected checks and yield their results.

        Parameters
        ----------
        checks : "all", str or iterable of str, optional
            Check names (keys of ``CHECKS``) to run, in order. ``"all"`` runs
            every check in ``CHECKS`` order. Any other single string is
            treated as one check name. Any iterable, including a one-shot
            iterator, is accepted. Default is "all".

        Yields
        ------
        result
            Whatever the individual check methods yield.

        Raises
        ------
        ValueError
            If a requested name is not a key of ``CHECKS``. Since this is a
            generator, the error is raised on first iteration, not at call
            time.
        """
        if isinstance(checks, str):
            # A bare name would otherwise be iterated character by character.
            checks = list(self.CHECKS) if checks == "all" else [checks]
        else:
            # Materialize once: validation and the run loop both iterate,
            # which would silently exhaust a one-shot iterator.
            checks = list(checks)

        unknown_checks = sorted(set(checks).difference(self.CHECKS))
        if unknown_checks:
            raise ValueError(f"Unknown checks: {unknown_checks!r}")

        for check in checks:
            method = getattr(self, self.CHECKS[check])
            yield from method()
 
# Substring replacements applied, in insertion order, by
# `modify_unit_order_astropy_5_3` to rewrite expected unit strings into the
# unit ordering expected from astropy >= 5.3 (old ordering -> new ordering).
UNIT_REPLACEMENTS_ASTROPY_5_3 = {
    "cm2 s TeV": "TeV s cm2",
    "1 / (cm2 s)": "1 / (s cm2)",
    "erg / (cm2 s)": "erg / (s cm2)",
}
def modify_unit_order_astropy_5_3(expected_str):
    """Modify unit order for tests with astropy >= 5.3.

    Each key of ``UNIT_REPLACEMENTS_ASTROPY_5_3`` found in ``expected_str``
    is replaced by its value, following the dict's iteration order. When
    ``ASTROPY_LT_5_3`` is false the string is returned unchanged.

    Parameters
    ----------
    expected_str : str
        Expected string, written with the pre-5.3 unit ordering.

    Returns
    -------
    str
        The possibly rewritten string.
    """
    # NOTE: despite its name, ``ASTROPY_LT_5_3`` is
    # ``minversion(astropy, "5.3.dev")``, i.e. true for astropy >= 5.3 --
    # which is exactly when the rewrite is wanted.
    if not ASTROPY_LT_5_3:
        return expected_str

    modified = expected_str
    for old_order, new_order in UNIT_REPLACEMENTS_ASTROPY_5_3.items():
        modified = modified.replace(old_order, new_order)
    return modified