# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Utilities to create scripts and command-line tools."""
import ast
import codecs
import operator
import os.path
import functools
import types
import warnings
import numpy as np
from base64 import urlsafe_b64encode
from pathlib import Path
from uuid import uuid4
import yaml
from gammapy.utils.check import add_checksum, verify_checksum
__all__ = [
"from_yaml",
"get_images_paths",
"make_path",
"read_yaml",
"recursive_merge_dicts",
"to_yaml",
"write_yaml",
]
PATH_DOCS = Path(__file__).resolve().parent / ".." / ".." / "docs"
SKIP = ["_static", "_build", "_checkpoints", "docs/user-guide/model-gallery/"]
YAML_FORMAT = dict(sort_keys=False, indent=4, width=80, default_flow_style=False)
[docs]
def get_images_paths(folder=PATH_DOCS):
"""Generator yields a Path for each image used in notebook.
Parameters
----------
folder : str
Folder where to search.
"""
for i in Path(folder).rglob("images/*"):
if not any(s in str(i) for s in SKIP):
yield i.resolve()
[docs]
def from_yaml(text, sort_keys=False, checksum=False):
"""Read YAML file.
Parameters
----------
text : str
yaml str
sort_keys : bool, optional
Whether to sort keys. Default is False.
checksum : bool
Whether to perform checksum verification. Default is False.
Returns
-------
data : dict
YAML file content as a dictionary.
"""
data = yaml.safe_load(text)
checksum_str = data.pop("checksum", None)
if checksum:
yaml_format = YAML_FORMAT.copy()
yaml_format["sort_keys"] = sort_keys
yaml_str = yaml.dump(data, **yaml_format)
if not verify_checksum(yaml_str, checksum_str):
warnings.warn("Checksum verification failed.", UserWarning)
return data
[docs]
def read_yaml(filename, logger=None, checksum=False):
"""Read YAML file.
Parameters
----------
filename : `~pathlib.Path`
Filename.
logger : `~logging.Logger`
Logger.
checksum : bool
Whether to perform checksum verification. Default is False.
Returns
-------
data : dict
YAML file content as a dictionary.
"""
path = make_path(filename)
if logger is not None:
logger.info(f"Reading {path}")
text = path.read_text()
return from_yaml(text, checksum=checksum)
[docs]
def to_yaml(dictionary, sort_keys=False):
"""Dictionary to yaml file.
Parameters
----------
dictionary : dict
Python dictionary.
sort_keys : bool, optional
Whether to sort keys. Default is False.
"""
from gammapy.utils.metadata import CreatorMetaData
yaml_format = YAML_FORMAT.copy()
yaml_format["sort_keys"] = sort_keys
text = yaml.safe_dump(dictionary, **yaml_format)
creation = CreatorMetaData()
return text + creation.to_yaml()
[docs]
def write_yaml(
text, filename, logger=None, sort_keys=False, checksum=False, overwrite=False
):
"""Write YAML file.
Parameters
----------
text : str
yaml str
filename : `~pathlib.Path`
Filename.
logger : `~logging.Logger`, optional
Logger. Default is None.
sort_keys : bool, optional
Whether to sort keys. Default is True.
checksum : bool, optional
Whether to add checksum keyword. Default is False.
overwrite : bool, optional
Overwrite existing file. Default is False.
"""
if checksum:
text = add_checksum(text, sort_keys=sort_keys)
path = make_path(filename)
path.parent.mkdir(exist_ok=True)
if path.exists() and not overwrite:
raise IOError(f"File exists already: {path}")
if logger is not None:
logger.info(f"Writing {path}")
path.write_text(text)
def make_name(name=None):
"""Make a dataset name."""
if name is None:
name = urlsafe_b64encode(codecs.decode(uuid4().hex, "hex")).decode()[:8]
while name[0] == "_":
name = urlsafe_b64encode(codecs.decode(uuid4().hex, "hex")).decode()[:8]
if not isinstance(name, str):
raise ValueError(
"Name argument must be a string, "
f"got '{name}', which is of type '{type(name)}'"
)
return name
[docs]
def make_path(path):
"""Expand environment variables on `~pathlib.Path` construction.
Parameters
----------
path : str, `pathlib.Path`
Path to expand.
"""
# TODO: raise error or warning if environment variables that don't resolve are used
# e.g. "spam/$DAMN/ham" where `$DAMN` is not defined
# Otherwise this can result in cryptic errors later on
if path is None:
return None
else:
return Path(os.path.expandvars(path))
[docs]
def recursive_merge_dicts(a, b):
"""Recursively merge two dictionaries.
Entries in 'b' override entries in 'a'. The built-in update function cannot be
used for hierarchical dicts, see:
http://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356#3233356
Parameters
----------
a : dict
Dictionary to be merged.
b : dict
Dictionary to be merged.
Returns
-------
c : dict
Merged dictionary.
Examples
--------
>>> from gammapy.utils.scripts import recursive_merge_dicts
>>> a = dict(a=42, b=dict(c=43, e=44))
>>> b = dict(d=99, b=dict(c=50, g=98))
>>> c = recursive_merge_dicts(a, b)
>>> print(c)
{'a': 42, 'b': {'c': 50, 'e': 44, 'g': 98}, 'd': 99}
"""
c = a.copy()
for k, v in b.items():
if k in c and isinstance(c[k], dict):
c[k] = recursive_merge_dicts(c[k], v)
else:
c[k] = v
return c
def requires_module(module_name):
"""
Decorator that conditionally enables a method or property based on the availability of a module.
If the specified module is available, the decorated method or property is returned as-is.
If the module is not available:
- For methods: replaces the method with one that raises ImportError when called.
- For properties: replaces the property with one that raises ImportError when accessed.
Parameters
----------
module_name : str
The name of the module to check for.
Returns
-------
function or property
The original object if the module is available, otherwise a fallback.
"""
def decorator(obj):
try:
__import__(module_name)
return obj # Module is available
except ImportError:
if isinstance(obj, property):
return property(
lambda self: raise_import_error(module_name, is_property=True)
)
elif isinstance(obj, (types.FunctionType, types.MethodType)):
@functools.wraps(obj)
def wrapper(*args, **kwargs):
raise_import_error(module_name)
return wrapper
else:
raise TypeError(
"requires_module can only be used on methods or properties."
)
return decorator
def raise_import_error(module_name, is_property=False):
"""
Raises an ImportError with a descriptive message about a missing module.
Parameters
----------
module_name : str
The name of the required module.
is_property : bool
Whether the error is for a property (affects the error message).
"""
kind = "property" if is_property else "method"
raise ImportError(f"The '{module_name}' module is required to use this {kind}.")
# Mapping of AST operators to NumPy functions
_OPERATORS = {
ast.And: np.logical_and,
ast.Or: np.logical_or,
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
ast.In: lambda a, b: np.isin(a, b),
ast.NotIn: lambda a, b: ~np.isin(a, b),
}
def logic_parser(table, expression):
"""
Parse and apply a logical expression to filter rows from an Astropy Table.
This function evaluates a logical expression on each row of the input table
and returns a new table containing only the rows that satisfy the expression.
The expression can reference any column in the table by name and supports
logical operators (`and`, `or`), comparison operators (`<`, `>`, `==`, `!=`, `in`),
lists, and constants.
Parameters
----------
table : `~astropy.table.Table`
The input table to filter.
expression : str
The logical expression to evaluate on each row. The expression can reference
any column in the table by name.
Returns
-------
table : `~astropy.table.Table`
A table view containing only the rows that satisfy the expression. If no rows
match the condition, an empty table with the same column names and data types
as the input table is returned.
Examples
--------
Given a table with columns 'OBS_ID' and 'EVENT_TYPE':
>>> from astropy.table import Table
>>> data = {'OBS_ID': [1, 2, 3, 4], 'EVENT_TYPE': ['1', '3', '4', '2']}
>>> table = Table(data)
>>> expression = '(OBS_ID < 3) and (OBS_ID > 1) and ((EVENT_TYPE in ["3", "4"]) or (EVENT_TYPE == "3"))'
>>> filtered_table = logic_parser(table, expression)
>>> print(filtered_table)
OBS_ID EVENT_TYPE
------ ----------
2 3
"""
def eval_node(node):
if isinstance(node, ast.BoolOp):
op_func = _OPERATORS[type(node.op)]
values = [eval_node(v) for v in node.values]
result = values[0]
for v in values[1:]:
result = op_func(result, v)
return result
elif isinstance(node, ast.Compare):
left = eval_node(node.left)
for op, comparator in zip(node.ops, node.comparators):
right = eval_node(comparator)
op_func = _OPERATORS[type(op)]
left = op_func(left, right)
return left
elif isinstance(node, ast.Name):
if node.id not in table.colnames:
raise KeyError(
f"Column '{node.id}' not found in the table. Available columns: {table.colnames}"
)
return table[node.id]
elif isinstance(node, ast.Constant):
return node.value
elif isinstance(node, ast.List):
return [eval_node(elt) for elt in node.elts]
else:
raise ValueError(f"Unsupported expression type: {type(node)}")
expr_ast = ast.parse(expression, mode="eval")
mask = eval_node(expr_ast.body)
return table[mask]