# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Utilities to create scripts and command-line tools."""
import ast
import codecs
import operator
import os.path
import functools
import types
import warnings
import numpy as np
from base64 import urlsafe_b64encode
from pathlib import Path
from uuid import uuid4
import yaml
from gammapy.utils.check import add_checksum, verify_checksum
__all__ = [
    "from_yaml",
    "get_images_paths",
    "make_path",
    "read_yaml",
    "recursive_merge_dicts",
    "to_yaml",
    "write_yaml",
    "logic_parser",
]

# ``docs`` directory located two levels above the directory containing this
# module (the path is resolved for this file but not normalized afterwards).
PATH_DOCS = Path(__file__).resolve().parent.joinpath("..", "..", "docs")

# Path fragments: an image whose path string contains any of these is skipped
# by `get_images_paths`.
SKIP = [
    "_static",
    "_build",
    "_checkpoints",
    "docs/user-guide/model-gallery/",
]

# Default keyword arguments for the YAML dump calls in this module.
YAML_FORMAT = {
    "sort_keys": False,
    "indent": 4,
    "width": 80,
    "default_flow_style": False,
}
[docs]
def get_images_paths(folder=PATH_DOCS):
    """Generator yielding the resolved path of each image used in notebooks.

    Every entry located directly inside a directory named ``images``, at any
    depth below ``folder``, is considered. Entries whose path string contains
    one of the fragments in ``SKIP`` are ignored.

    Parameters
    ----------
    folder : str or `~pathlib.Path`, optional
        Folder where to search. Default is ``PATH_DOCS``.

    Yields
    ------
    path : `~pathlib.Path`
        Resolved path of an image entry.
    """
    for candidate in Path(folder).rglob("images/*"):
        as_text = str(candidate)
        if any(fragment in as_text for fragment in SKIP):
            continue
        yield candidate.resolve()
[docs]
def from_yaml(text, sort_keys=False, checksum=False):
    """Parse a YAML string.

    Parameters
    ----------
    text : str
        YAML string.
    sort_keys : bool, optional
        Whether to sort keys when re-serializing the data for checksum
        verification. Only used if ``checksum`` is True. Default is False.
    checksum : bool, optional
        Whether to perform checksum verification. Default is False.

    Returns
    -------
    data : dict
        YAML content as a dictionary, with any top-level ``checksum`` entry
        removed. An empty document yields an empty dictionary.

    Warns
    -----
    UserWarning
        If ``checksum`` is True and the verification fails.
    """
    data = yaml.safe_load(text)

    # An empty document loads as None; treat it as an empty mapping so that a
    # dict is always returned and the ``pop`` below does not fail.
    if data is None:
        data = {}

    # The checksum entry is always removed so it never leaks into the returned
    # data, even when verification is not requested.
    checksum_str = data.pop("checksum", None)

    if checksum:
        yaml_format = YAML_FORMAT.copy()
        yaml_format["sort_keys"] = sort_keys
        # Re-serialize the remaining content so it can be compared with the
        # stored checksum.
        yaml_str = yaml.dump(data, **yaml_format)
        if not verify_checksum(yaml_str, checksum_str):
            warnings.warn("Checksum verification failed.", UserWarning)
    return data
[docs]
def read_yaml(filename, logger=None, checksum=False, sort_keys=False):
    """Read YAML file.

    Parameters
    ----------
    filename : `~pathlib.Path` or str
        Filename. Environment variables are expanded.
    logger : `~logging.Logger`, optional
        Logger. Default is None.
    checksum : bool, optional
        Whether to perform checksum verification. Default is False.
    sort_keys : bool, optional
        Whether to sort keys when re-serializing the data for checksum
        verification. It should match the ``sort_keys`` value used when the
        checksum was written. Default is False.

    Returns
    -------
    data : dict
        YAML file content as a dictionary.

    Raises
    ------
    ValueError
        If ``filename`` is None.
    """
    if filename is None:
        raise ValueError("The filename is not defined.")

    path = make_path(filename)
    if logger is not None:
        logger.info(f"Reading {path}")

    # Decode explicitly as UTF-8 rather than with the platform-dependent
    # locale encoding.
    text = path.read_text(encoding="utf-8")
    return from_yaml(text, sort_keys=sort_keys, checksum=checksum)
[docs]
def to_yaml(dictionary, sort_keys=False):
    """Serialize a dictionary to a YAML string.

    The YAML produced by ``CreatorMetaData().to_yaml()`` is appended after the
    serialized dictionary.

    Parameters
    ----------
    dictionary : dict
        Python dictionary.
    sort_keys : bool, optional
        Whether to sort keys. Default is False.

    Returns
    -------
    text : str
        YAML text.
    """
    from gammapy.utils.metadata import CreatorMetaData

    dump_kwargs = {**YAML_FORMAT, "sort_keys": sort_keys}
    body = yaml.safe_dump(dictionary, **dump_kwargs)
    footer = CreatorMetaData().to_yaml()
    return body + footer
[docs]
def write_yaml(
    text, filename, logger=None, sort_keys=False, checksum=False, overwrite=False
):
    """Write YAML file.

    Parameters
    ----------
    text : str
        YAML string.
    filename : `~pathlib.Path` or str
        Filename. Environment variables are expanded.
    logger : `~logging.Logger`, optional
        Logger. Default is None.
    sort_keys : bool, optional
        Whether to sort keys when computing the checksum. Only used if
        ``checksum`` is True. Default is False.
    checksum : bool, optional
        Whether to add checksum keyword. Default is False.
    overwrite : bool, optional
        Overwrite existing file. Default is False.

    Raises
    ------
    ValueError
        If ``filename`` is None.
    IOError
        If the file already exists and ``overwrite`` is False.
    """
    if filename is None:
        raise ValueError("The filename is not defined.")

    path = make_path(filename)
    if path.exists() and not overwrite:
        raise IOError(f"File exists already: {path}")

    if checksum:
        text = add_checksum(text, sort_keys=sort_keys)

    # Create every missing parent directory, not only the immediate one.
    path.parent.mkdir(parents=True, exist_ok=True)

    if logger is not None:
        logger.info(f"Writing {path}")
    # Encode explicitly as UTF-8 rather than with the locale encoding.
    path.write_text(text, encoding="utf-8")
def make_name(name=None):
    """Make a dataset name.

    If ``name`` is None, a random 8-character name is drawn from the URL-safe
    base64 encoding of a random UUID. Draws are repeated until the name does
    not start with an underscore. Otherwise ``name`` is returned unchanged
    after a type check.

    Parameters
    ----------
    name : str, optional
        Name to use. Default is None, which generates a random name.

    Returns
    -------
    name : str
        Dataset name.

    Raises
    ------
    ValueError
        If ``name`` is not a string.
    """
    if name is None:
        while True:
            raw_bytes = codecs.decode(uuid4().hex, "hex")
            name = urlsafe_b64encode(raw_bytes).decode()[:8]
            if not name.startswith("_"):
                break

    if isinstance(name, str):
        return name
    raise ValueError(
        "Name argument must be a string, "
        f"got '{name}', which is of type '{type(name)}'"
    )
[docs]
def make_path(path):
    """Expand environment variables on `~pathlib.Path` construction.

    Parameters
    ----------
    path : str, `pathlib.Path` or None
        Path to expand.

    Returns
    -------
    path : `~pathlib.Path` or None
        Path with environment variables expanded, or None if ``path`` is None.
    """
    # TODO: raise error or warning if environment variables that don't resolve are used
    # e.g. "spam/$DAMN/ham" where `$DAMN` is not defined
    # Otherwise this can result in cryptic errors later on
    if path is None:
        return None
    expanded = os.path.expandvars(path)
    return Path(expanded)
[docs]
def recursive_merge_dicts(a, b):
    """Recursively merge two dictionaries.

    Entries in 'b' override entries in 'a'. The built-in update function cannot be
    used for hierarchical dicts, see:
    http://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356#3233356

    When a key maps to a dictionary in both inputs, the two sub-dictionaries are
    merged recursively. Otherwise the value from 'b' replaces the one from 'a',
    including when only one of the two values is a dictionary. Neither input
    dictionary is modified; values that are not merged are shared, not copied.

    Parameters
    ----------
    a : dict
        Dictionary to be merged.
    b : dict
        Dictionary to be merged.

    Returns
    -------
    c : dict
        Merged dictionary.

    Examples
    --------
    >>> from gammapy.utils.scripts import recursive_merge_dicts
    >>> a = dict(a=42, b=dict(c=43, e=44))
    >>> b = dict(d=99, b=dict(c=50, g=98))
    >>> c = recursive_merge_dicts(a, b)
    >>> print(c)
    {'a': 42, 'b': {'c': 50, 'e': 44, 'g': 98}, 'd': 99}
    """
    c = a.copy()
    for key, value in b.items():
        # Recurse only when both sides are dicts; recursing into a non-dict
        # value from 'b' would fail on ``.items()``.
        if isinstance(c.get(key), dict) and isinstance(value, dict):
            c[key] = recursive_merge_dicts(c[key], value)
        else:
            c[key] = value
    return c
def requires_module(module_name):
    """
    Decorator that conditionally enables a method or property based on the availability of a module.

    If the specified module can be imported, the decorated object is returned as-is.
    If the import fails with ImportError:

    - For functions and methods: replaces them with a function that raises
      ImportError when called, keeping the original metadata.
    - For ``staticmethod`` / ``classmethod`` objects: same as above, re-wrapped
      in the same descriptor type.
    - For properties: replaces the property with one that raises ImportError
      when accessed, keeping the original docstring.

    Parameters
    ----------
    module_name : str
        The name of the module to check for.

    Returns
    -------
    function or property
        The original object if the module is available, otherwise a fallback.

    Raises
    ------
    TypeError
        If the module is missing and the decorated object is not a function,
        method, static/class method or property.
    """

    def _unavailable(func):
        # Stand-in that keeps func's metadata but raises ImportError on call.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            raise_import_error(module_name)

        return wrapper

    def decorator(obj):
        try:
            __import__(module_name)
        except ImportError:
            pass
        else:
            return obj  # Module is available

        if isinstance(obj, property):
            # Keep the docstring so the property still documents itself.
            return property(
                lambda self: raise_import_error(module_name, is_property=True),
                doc=obj.__doc__,
            )
        if isinstance(obj, (staticmethod, classmethod)):
            # Wrap the underlying function and restore the descriptor type.
            return type(obj)(_unavailable(obj.__func__))
        if isinstance(obj, (types.FunctionType, types.MethodType)):
            return _unavailable(obj)
        raise TypeError(
            "requires_module can only be used on methods or properties."
        )

    return decorator
def raise_import_error(module_name, is_property=False):
    """
    Raise an ImportError explaining that a required module is missing.

    Parameters
    ----------
    module_name : str
        The name of the required module.
    is_property : bool
        If True the message refers to a "property", otherwise to a "method".

    Raises
    ------
    ImportError
        Always.
    """
    if is_property:
        kind = "property"
    else:
        kind = "method"
    message = f"The '{module_name}' module is required to use this {kind}."
    raise ImportError(message)
# Mapping of AST operators to NumPy functions
_OPERATORS = {
ast.And: np.logical_and,
ast.Or: np.logical_or,
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
ast.In: lambda a, b: np.isin(a, b),
ast.NotIn: lambda a, b: ~np.isin(a, b),
}
[docs]
def logic_parser(table, expression):
    """
    Parse and apply a logical expression to filter rows from an Astropy Table.

    This function evaluates a logical expression on each row of the input table
    and returns a new table containing only the rows that satisfy the expression.
    The expression can reference any column in the table by name and supports:

    - logical operators (``and``, ``or``, ``not``);
    - comparison operators (``<``, ``<=``, ``>``, ``>=``, ``==``, ``!=``,
      ``in``, ``not in``);
    - unary ``+`` and ``-``, e.g. negative constants such as ``OBS_ID > -1``;
    - lists, tuples and sets, and constants;
    - chained comparisons (e.g. ``1 < OBS_ID < 3``).

    Chained comparisons follow standard Python semantics and are equivalent to
    combining individual comparisons with ``and``. For example,
    ``1 < OBS_ID < 3`` is equivalent to
    ``(OBS_ID > 1) and (OBS_ID < 3)``.

    Parameters
    ----------
    table : `~astropy.table.Table`
        The input table to filter.
    expression : str
        The logical expression to evaluate on each row. The expression can reference
        any column in the table by name.

    Returns
    -------
    table : `~astropy.table.Table`
        A table containing only the rows that satisfy the expression. If no rows
        match the condition, an empty table with the same column names and data types
        as the input table is returned. An expression that references no column
        (e.g. ``1 < 2``) selects either all rows or none.

    Raises
    ------
    KeyError
        If the expression references a column that is not in the table.
    ValueError
        If the expression contains an unsupported syntax element or operator.
    SyntaxError
        If ``expression`` is not a valid Python expression.

    Examples
    --------
    Given a table with columns 'OBS_ID' and 'EVENT_TYPE':

    >>> from astropy.table import Table
    >>> data = {'OBS_ID': [1, 2, 3, 4], 'EVENT_TYPE': ['1', '3', '4', '2']}
    >>> table = Table(data)

    Standard logical expression:

    >>> expression = '(OBS_ID < 3) and (OBS_ID > 1)'
    >>> logic_parser(table, expression)
    <Table length=1>
    OBS_ID EVENT_TYPE
    int64     str1
    ------ ----------
         2          3

    Using chained comparisons:

    >>> expression = '1 < OBS_ID < 3'
    >>> logic_parser(table, expression)
    <Table length=1>
    OBS_ID EVENT_TYPE
    int64     str1
    ------ ----------
         2          3

    Combining chained comparisons with other conditions:

    >>> expression = '(1 < OBS_ID < 4) and (EVENT_TYPE in ["3", "4"])'
    >>> logic_parser(table, expression)
    <Table length=2>
    OBS_ID EVENT_TYPE
    int64     str1
    ------ ----------
         2          3
         3          4
    """
    # Unary operators: ``not`` acts element-wise; ``-`` / ``+`` are needed
    # because Python parses negative literals such as ``-1`` as UnaryOp nodes.
    unary_operators = {
        ast.Not: np.logical_not,
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
    }

    def operator_for(mapping, op):
        # Look up the callable for an operator node, with a readable error
        # instead of a KeyError on an AST class.
        try:
            return mapping[type(op)]
        except KeyError:
            raise ValueError(f"Unsupported operator: {type(op).__name__}") from None

    def handle_boolop(node):
        op_func = operator_for(_OPERATORS, node.op)
        values = [eval_node(v) for v in node.values]
        # Left fold, e.g. ``a and b and c`` -> op(op(a, b), c).
        return functools.reduce(op_func, values)

    def handle_compare(node):
        # A chained comparison ``a < b < c`` is split into pairwise
        # comparisons that are combined with a logical and.
        left = eval_node(node.left)
        parts = []
        for op, comparator in zip(node.ops, node.comparators):
            right = eval_node(comparator)
            parts.append(operator_for(_OPERATORS, op)(left, right))
            left = right
        return functools.reduce(np.logical_and, parts)

    def handle_unaryop(node):
        op_func = operator_for(unary_operators, node.op)
        return op_func(eval_node(node.operand))

    def handle_name(node):
        if node.id not in table.colnames:
            raise KeyError(
                f"Column '{node.id}' not found in the table. "
                f"Available columns: {table.colnames}"
            )
        return table[node.id]

    def handle_constant(node):
        return node.value

    def handle_sequence(node):
        # Lists, tuples and sets all become plain lists, which np.isin handles
        # element-wise (a Python set would not be treated as a sequence).
        return [eval_node(elt) for elt in node.elts]

    handlers = {
        ast.BoolOp: handle_boolop,
        ast.Compare: handle_compare,
        ast.UnaryOp: handle_unaryop,
        ast.Name: handle_name,
        ast.Constant: handle_constant,
        ast.List: handle_sequence,
        ast.Tuple: handle_sequence,
        ast.Set: handle_sequence,
    }

    def eval_node(node):
        handler = handlers.get(type(node))
        if handler is None:
            raise ValueError(f"Unsupported expression type: {type(node)}")
        return handler(node)

    expr_ast = ast.parse(expression, mode="eval")
    mask = eval_node(expr_ast.body)

    # An expression that references no column (e.g. "1 < 2") evaluates to a
    # scalar; broadcast it to one value per row so it selects all rows or none
    # instead of being passed to the table as a scalar index.
    if np.ndim(mask) == 0:
        mask = np.full(len(table), bool(mask))
    return table[mask]
def method_wrapper(func):
    """
    Wrap a function for use as a method while preserving its metadata.

    The returned callable can be assigned as a method on other classes (or
    instances). When invoked it calls ``func`` with the receiving object as
    the first positional argument, followed by all other arguments unchanged.
    The wrapper copies the original function's metadata (e.g. docstring,
    name, module, annotations) and sets ``__wrapped__``, so the method
    appears in introspection and documentation like the original.

    Parameters
    ----------
    func : callable
        Function whose first parameter receives the object.

    Returns
    -------
    wrapper : callable
        The metadata-preserving wrapper.
    """

    @functools.wraps(func)
    def forward(self, *args, **kwargs):
        # Pass the receiving object through as func's first argument.
        return func(self, *args, **kwargs)

    return forward