# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Utilities for testing."""
import os
import sys
from numpy.testing import assert_allclose
import astropy
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.time import Time
from astropy.utils.introspection import minversion
import matplotlib.pyplot as plt
# Public names of this module.
__all__ = [
"assert_quantity_allclose",
"assert_skycoord_allclose",
"assert_time_allclose",
"Checker",
"mpl_plot_check",
"requires_data",
"requires_dependency",
]
# NOTE(review): despite the ``LT`` ("less than") in its name, this flag is
# truthy when the installed astropy is *at least* 5.3.dev, because
# ``minversion`` tests for a minimum version. `modify_unit_order_astropy_5_3`
# relies on exactly that truthiness, so renaming or negating the flag must be
# done together with that function.
ASTROPY_LT_5_3 = minversion(astropy, "5.3.dev")
# Cache for `requires_dependency`: maps a dependency name to the ``skip_it``
# flag computed on first use (True when importing the dependency failed).
_requires_dependency_cache = {}
def requires_dependency(name):
    """Decorator to declare required dependencies for tests.

    The import attempt for each dependency is done only once; the outcome
    is remembered in the module-level ``_requires_dependency_cache``.

    Parameters
    ----------
    name : str
        Name of the module that the decorated test needs.

    Returns
    -------
    marker : pytest ``skipif`` mark
        Mark that skips the test when the module cannot be imported.

    Examples
    --------
    ::

        from gammapy.utils.testing import requires_dependency

        @requires_dependency('scipy')
        def test_using_scipy():
            import scipy
            ...
    """
    import pytest

    # Cached values are always booleans, so ``None`` means "not probed yet".
    missing = _requires_dependency_cache.get(name)
    if missing is None:
        try:
            __import__(name)
        except ImportError:
            missing = True
        else:
            missing = False
        _requires_dependency_cache[name] = missing

    return pytest.mark.skipif(missing, reason=f"Missing dependency: {name}")
def has_data(name):
    """Check whether the environment variable for a data set is defined.

    Parameters
    ----------
    name : str
        Data set name: one of ``"gammapy-extra"``, ``"gammapy-data"``,
        ``"gamma-cat"`` or ``"fermi-lat"``.

    Returns
    -------
    available : bool
        True if the matching environment variable is present in
        ``os.environ``.

    Raises
    ------
    ValueError
        If ``name`` is not one of the known data set names.
    """
    # (data set name, environment variable that points to it)
    known_data_sets = (
        ("gammapy-extra", "GAMMAPY_EXTRA"),
        ("gammapy-data", "GAMMAPY_DATA"),
        ("gamma-cat", "GAMMA_CAT"),
        ("fermi-lat", "GAMMAPY_FERMI_LAT_DATA"),
    )
    for data_set, variable in known_data_sets:
        if name == data_set:
            return variable in os.environ

    raise ValueError(f"Invalid name: {name}")
def requires_data(name="gammapy-data"):
    """Decorator to declare required data for tests.

    Parameters
    ----------
    name : str, optional
        Data set name, as understood by `has_data`.
        Default is "gammapy-data".

    Returns
    -------
    marker : pytest ``skipif`` mark
        Mark that skips the test when the data set is not available.

    Raises
    ------
    TypeError
        If ``name`` is not a string, which is what happens when the
        decorator is applied without parentheses.

    Examples
    --------
    ::

        from gammapy.utils.testing import requires_data

        @requires_data()
        def test_using_data_files():
            filename = "$GAMMAPY_DATA/..."
            ...
    """
    import pytest

    if isinstance(name, str):
        data_missing = not has_data(name)
        return pytest.mark.skipif(data_missing, reason=f"Missing data: {name}")

    raise TypeError(
        "You must call @requires_data with a name (str). "
        "Usually this: @requires_data()"
    )
def run_cli(cli, args, exit_code=0):
    """Run Click command line tool.

    Thin wrapper around `click.testing.CliRunner` that writes the command
    output to stderr when the exit code differs from the expected one.
    Exceptions raised by the command are not caught by the runner.

    Parameters
    ----------
    cli : click.Command
        Click command.
    args : list of str
        Argument list.
    exit_code : int, optional
        Expected exit code of the command. Default is 0.

    Returns
    -------
    result : `click.testing.Result`
        Result.
    """
    from click.testing import CliRunner

    runner = CliRunner()
    result = runner.invoke(cli, args, catch_exceptions=False)

    if result.exit_code == exit_code:
        return result

    # Mismatch: dump what the command printed to help debugging the test.
    write = sys.stderr.write
    write("Exit code mismatch!\n")
    write("Output:\n")
    write(result.output)
    return result
def assert_skycoord_allclose(actual, desired):
    """Assert all-close for `astropy.coordinates.SkyCoord` objects.

    Longitude and latitude of the underlying data are compared in degrees.

    - Frames can be different, aren't checked at the moment.

    Parameters
    ----------
    actual, desired : `~astropy.coordinates.SkyCoord`
        Coordinates to compare.
    """
    assert isinstance(actual, SkyCoord)
    assert isinstance(desired, SkyCoord)

    actual_data, desired_data = actual.data, desired.data
    for component in ("lon", "lat"):
        assert_allclose(
            getattr(actual_data, component).deg,
            getattr(desired_data, component).deg,
        )
def assert_time_allclose(actual, desired, atol=1e-3):
    """Assert all-close for `astropy.time.Time` objects.

    Scale and format have to be equal; the time difference is compared
    against zero with an absolute tolerance only.

    Parameters
    ----------
    actual, desired : `~astropy.time.Time`
        Times to compare.
    atol : float, optional
        Absolute tolerance in seconds. Default is 1e-3.
    """
    assert isinstance(actual, Time)
    assert isinstance(desired, Time)

    for attribute in ("scale", "format"):
        assert getattr(actual, attribute) == getattr(desired, attribute)

    offset_seconds = (actual - desired).sec
    assert_allclose(offset_seconds, 0, rtol=0, atol=atol)
def assert_quantity_allclose(actual, desired, rtol=1.0e-7, atol=None, **kwargs):
    """Assert all-close for `~astropy.units.Quantity` objects.

    Parameters
    ----------
    actual, desired : `~astropy.units.Quantity` or array-like
        Values to compare.
    rtol : float, optional
        Relative tolerance, has to be dimensionless. Default is 1e-7.
    atol : `~astropy.units.Quantity` or float, optional
        Absolute tolerance. Default is None, which means zero.
    **kwargs : dict
        Forwarded to `numpy.testing.assert_allclose`.

    Notes
    -----
    The intention is to require that ``unit`` is identical, not just that
    quantities are allclose taking different units into account.
    We prefer this kind of assert for testing, since units
    should only change on purpose, so this tests more behaviour.

    That strict unit check is not enforced yet (see the TODO below):
    at the moment ``desired`` is converted to the unit of ``actual``
    before the values are compared.
    """
    # TODO: change this later to explicitly check units are the same!
    # assert actual.unit == desired.unit
    actual_value, desired_value, rtol_value, atol_value = (
        _unquantify_allclose_arguments(actual, desired, rtol, atol)
    )
    assert_allclose(actual_value, desired_value, rtol_value, atol_value, **kwargs)
def _unquantify_allclose_arguments(actual, desired, rtol, atol):
    """Convert the arguments of an all-close check to plain values.

    ``desired`` and ``atol`` are expressed in the unit of ``actual`` and
    ``rtol`` is made dimensionless, then the bare values are returned.

    Returns
    -------
    values : tuple
        ``(actual, desired, rtol, atol)`` with the units stripped.

    Raises
    ------
    `~astropy.units.UnitsError`
        If ``desired`` or ``atol`` cannot be converted to the unit of
        ``actual``, or if ``rtol`` cannot be made dimensionless.
    """
    actual = u.Quantity(actual, subok=True, copy=False)

    desired = u.Quantity(desired, subok=True, copy=False)
    try:
        desired = desired.to(actual.unit)
    except u.UnitsError:
        raise u.UnitsError(
            f"Units for 'desired' ({desired.unit}) and 'actual' ({actual.unit}) "
            "are not convertible"
        )

    if atol is not None:
        atol = u.Quantity(atol, subok=True, copy=False)
        try:
            atol = atol.to(actual.unit)
        except u.UnitsError:
            raise u.UnitsError(
                f"Units for 'atol' ({atol.unit}) and 'actual' ({actual.unit}) "
                "are not convertible"
            )
    else:
        # by default, we assume an absolute tolerance of 0
        atol = u.Quantity(0)

    rtol = u.Quantity(rtol, subok=True, copy=False)
    try:
        rtol = rtol.to(u.dimensionless_unscaled)
    except Exception:
        raise u.UnitsError("`rtol` should be dimensionless")

    return actual.value, desired.value, rtol.value, atol.value
def mpl_plot_check():
    """Matplotlib plotting test context manager.

    Create a new figure on __enter__ and calls savefig for the
    current figure in __exit__. This will trigger a render of the
    Figure, which can sometimes raise errors if there is a problem.

    This is writing to an in-memory byte buffer, i.e. is faster
    than writing to disk.

    The current figure is closed on exit even if rendering fails, so a
    broken plot does not leave an open figure behind for later tests.

    Returns
    -------
    context : context manager
        Object to be used in a ``with`` statement; ``__enter__`` returns
        None, so there is nothing to bind with ``as``.
    """
    from io import BytesIO

    class MPLPlotCheck:
        def __enter__(self):
            plt.figure()

        def __exit__(self, exc_type, exc_value, exc_traceback):
            # Returns None, so an exception raised inside the ``with`` body
            # still propagates after the figure has been rendered and closed.
            try:
                plt.savefig(BytesIO(), format="png")
            finally:
                # Always release the figure, also when savefig raised.
                plt.close()

    return MPLPlotCheck()
class Checker:
    """Base class for checker classes in Gammapy.

    `run` reads ``self.CHECKS``, which is not defined here: it is used as a
    mapping from check name to the name of the method that implements
    the check.
    """

    def run(self, checks="all"):
        """Run the selected checks and yield whatever they yield.

        Parameters
        ----------
        checks : "all" or iterable of str, optional
            Names of the checks to run. Default is "all", which runs every
            check listed in ``CHECKS``.

        Yields
        ------
        record
            Items produced by the individual check methods.

        Raises
        ------
        ValueError
            If a requested check is not a key of ``CHECKS``. As this is a
            generator, the error is raised when iteration starts.
        """
        registry = self.CHECKS
        if checks == "all":
            checks = registry.keys()

        unknown_checks = sorted({name for name in checks if name not in registry})
        if unknown_checks:
            raise ValueError(f"Unknown checks: {unknown_checks!r}")

        for name in checks:
            check_method = getattr(self, registry[name])
            yield from check_method()
# Substring replacements applied by `modify_unit_order_astropy_5_3`:
# each key is a unit string as written in the expected test output and each
# value is the same unit with its factors in a different order.
UNIT_REPLACEMENTS_ASTROPY_5_3 = {
"cm2 s TeV": "TeV s cm2",
"1 / (cm2 s)": "1 / (s cm2)",
"erg / (cm2 s)": "erg / (s cm2)",
}
def modify_unit_order_astropy_5_3(expected_str):
    """Modify unit order for tests with astropy >= 5.3.

    Parameters
    ----------
    expected_str : str
        Expected string, with units written as in the keys of
        ``UNIT_REPLACEMENTS_ASTROPY_5_3``.

    Returns
    -------
    expected_str : str
        The input with every replacement applied, or the unchanged input
        when the version flag is falsy.
    """
    # NOTE(review): ``ASTROPY_LT_5_3`` is set from
    # ``minversion(astropy, "5.3.dev")`` in this module, so despite its name
    # it is truthy for astropy >= 5.3 - which is the case handled here.
    if not ASTROPY_LT_5_3:
        return expected_str

    for old_order, new_order in UNIT_REPLACEMENTS_ASTROPY_5_3.items():
        expected_str = expected_str.replace(old_order, new_order)
    return expected_str