Source code for gammapy.utils.scripts

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Utilities to create scripts and command-line tools."""

import ast
import codecs
import operator
import os.path
import functools

import types
import warnings
import numpy as np
from base64 import urlsafe_b64encode
from pathlib import Path
from uuid import uuid4
import yaml
from gammapy.utils.check import add_checksum, verify_checksum

# Public API of this module: the names exported by a star-import.
__all__ = [
    "from_yaml",
    "get_images_paths",
    "make_path",
    "read_yaml",
    "recursive_merge_dicts",
    "to_yaml",
    "write_yaml",
    "logic_parser",
]

# "docs" folder located two levels above the directory containing this module;
# default search root of `get_images_paths`.
PATH_DOCS = Path(__file__).resolve().parent / ".." / ".." / "docs"
# Path fragments: `get_images_paths` ignores every path containing one of them.
SKIP = ["_static", "_build", "_checkpoints", "docs/user-guide/model-gallery/"]
# Default keyword arguments for the YAML dump calls in `from_yaml` and `to_yaml`
# (both override "sort_keys" on a copy).
YAML_FORMAT = dict(sort_keys=False, indent=4, width=80, default_flow_style=False)


def get_images_paths(folder=PATH_DOCS):
    """Yield the resolved path of each entry found in an ``images`` folder.

    Parameters
    ----------
    folder : str or `~pathlib.Path`
        Folder where to search (recursively). Default is ``PATH_DOCS``.

    Yields
    ------
    path : `~pathlib.Path`
        Resolved path of every match of ``images/*`` below ``folder`` whose
        path does not contain any of the fragments listed in ``SKIP``.
    """
    for path in Path(folder).rglob("images/*"):
        # The SKIP fragments are written with forward slashes, so compare
        # against the POSIX form of the path: ``str(path)`` uses backslashes
        # on Windows and multi-component fragments would never match there.
        posix_path = path.as_posix()
        if not any(fragment in posix_path for fragment in SKIP):
            yield path.resolve()
def from_yaml(text, sort_keys=False, checksum=False):
    """Parse a YAML string.

    Parameters
    ----------
    text : str
        YAML string.
    sort_keys : bool, optional
        Whether to sort keys when the content is re-serialised for the
        checksum verification. Default is False.
    checksum : bool, optional
        Whether to perform checksum verification. Default is False.

    Returns
    -------
    data : dict
        YAML content as a dictionary, with the "checksum" entry removed.
        Content that is not a mapping (for example an empty document, which
        loads as None) is returned as loaded.

    Warns
    -----
    UserWarning
        If ``checksum`` is True and the verification fails.
    """
    data = yaml.safe_load(text)

    # Only a mapping can carry a "checksum" entry. An empty document loads as
    # None and a top-level sequence as a list; calling ``pop`` with a string
    # key on those would raise instead of returning the parsed content.
    checksum_str = data.pop("checksum", None) if isinstance(data, dict) else None

    if checksum:
        # Re-serialise with the module's dump settings so the string being
        # verified does not include the "checksum" entry itself.
        yaml_format = YAML_FORMAT.copy()
        yaml_format["sort_keys"] = sort_keys
        yaml_str = yaml.dump(data, **yaml_format)
        if not verify_checksum(yaml_str, checksum_str):
            warnings.warn("Checksum verification failed.", UserWarning)

    return data
def read_yaml(filename, logger=None, checksum=False):
    """Load the content of a YAML file.

    Parameters
    ----------
    filename : `~pathlib.Path`
        Filename. It is passed through `make_path` before being read.
    logger : `~logging.Logger`, optional
        Logger used to report which file is read. Default is None.
    checksum : bool, optional
        Whether to perform checksum verification. Default is False.

    Returns
    -------
    data : dict
        YAML file content as a dictionary.

    Raises
    ------
    ValueError
        If ``filename`` is None.
    """
    if filename is None:
        raise ValueError("The filename is not defined.")

    path = make_path(filename)

    if logger is not None:
        logger.info(f"Reading {path}")

    # Parsing and the optional checksum verification are done by from_yaml.
    return from_yaml(path.read_text(), checksum=checksum)
def to_yaml(dictionary, sort_keys=False):
    """Serialise a dictionary to a YAML string.

    The YAML produced by a default-constructed
    ``gammapy.utils.metadata.CreatorMetaData`` is appended to the output.

    Parameters
    ----------
    dictionary : dict
        Python dictionary.
    sort_keys : bool, optional
        Whether to sort keys. Default is False.

    Returns
    -------
    text : str
        YAML string of the dictionary followed by the creator metadata.
    """
    # Imported inside the function, as in the original implementation.
    from gammapy.utils.metadata import CreatorMetaData

    # Module-wide dump settings with the caller's key ordering choice.
    dump_options = {**YAML_FORMAT, "sort_keys": sort_keys}
    body = yaml.safe_dump(dictionary, **dump_options)
    return body + CreatorMetaData().to_yaml()
def write_yaml(
    text, filename, logger=None, sort_keys=False, checksum=False, overwrite=False
):
    """Write a YAML string to a file.

    Parameters
    ----------
    text : str
        YAML string.
    filename : `~pathlib.Path`
        Filename. It is passed through `make_path` before being written.
    logger : `~logging.Logger`, optional
        Logger used to report which file is written. Default is None.
    sort_keys : bool, optional
        Whether to sort keys; forwarded to the checksum computation.
        Default is False.
    checksum : bool, optional
        Whether to add checksum keyword. Default is False.
    overwrite : bool, optional
        Overwrite existing file. Default is False.

    Raises
    ------
    ValueError
        If ``filename`` is None.
    IOError
        If the file already exists and ``overwrite`` is False.
    """
    # Validate the destination first, so that an undefined filename is
    # reported before any checksum work is done on the text.
    if filename is None:
        raise ValueError("The filename is not defined.")

    if checksum:
        text = add_checksum(text, sort_keys=sort_keys)

    path = make_path(filename)
    # Only the immediate parent folder is created; missing grandparents
    # still raise, as before.
    path.parent.mkdir(exist_ok=True)

    if path.exists() and not overwrite:
        raise IOError(f"File exists already: {path}")

    if logger is not None:
        logger.info(f"Writing {path}")

    path.write_text(text)
def make_name(name=None):
    """Return a dataset name, generating a random one if none is given.

    Parameters
    ----------
    name : str, optional
        Name to use. If None, a random 8-character name that does not start
        with an underscore is generated. Default is None.

    Returns
    -------
    name : str
        The given or generated name.

    Raises
    ------
    ValueError
        If ``name`` is neither None nor a string.
    """
    if name is not None:
        if isinstance(name, str):
            return name
        raise ValueError(
            "Name argument must be a string, "
            f"got '{name}', which is of type '{type(name)}'"
        )

    # Draw URL-safe base64 encodings of random UUID bytes until the first
    # character is not an underscore.
    while True:
        candidate = urlsafe_b64encode(uuid4().bytes).decode()[:8]
        if candidate[0] != "_":
            return candidate
def make_path(path):
    """Build a `~pathlib.Path` with environment variables expanded.

    Parameters
    ----------
    path : str, `pathlib.Path`
        Path to expand.

    Returns
    -------
    path : `~pathlib.Path` or None
        Expanded path, or None if ``path`` is None.
    """
    # TODO: raise error or warning if environment variables that don't resolve are used
    # e.g. "spam/$DAMN/ham" where `$DAMN` is not defined
    # Otherwise this can result in cryptic errors later on
    return None if path is None else Path(os.path.expandvars(path))
def recursive_merge_dicts(a, b):
    """Recursively merge two dictionaries.

    Entries in 'b' override entries in 'a'. The built-in update function cannot be used
    for hierarchical dicts, see:
    http://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356#3233356

    Nested dictionaries are merged key by key only when the value is a
    dictionary on both sides; in every other case the value from 'b' replaces
    the one from 'a'. Neither input is modified.

    Parameters
    ----------
    a : dict
        Dictionary to be merged.
    b : dict
        Dictionary to be merged.

    Returns
    -------
    c : dict
        Merged dictionary.

    Examples
    --------
    >>> from gammapy.utils.scripts import recursive_merge_dicts
    >>> a = dict(a=42, b=dict(c=43, e=44))
    >>> b = dict(d=99, b=dict(c=50, g=98))
    >>> c = recursive_merge_dicts(a, b)
    >>> print(c)
    {'a': 42, 'b': {'c': 50, 'e': 44, 'g': 98}, 'd': 99}
    """
    merged = a.copy()
    for key, value in b.items():
        current = merged.get(key)
        # Recurse only when both sides are dictionaries. A non-dict value in
        # 'b' must replace a nested dict in 'a' rather than be merged into it
        # (recursing would call ``.items()`` on the non-dict and fail).
        if isinstance(current, dict) and isinstance(value, dict):
            merged[key] = recursive_merge_dicts(current, value)
        else:
            merged[key] = value
    return merged
def requires_module(module_name):
    """
    Decorator enabling a method or property only if a module can be imported.

    When ``module_name`` imports successfully, the decorated object is
    returned unchanged. Otherwise it is replaced by a stand-in:

    - For methods: a function with the original metadata that raises
      ImportError when called.
    - For properties: a property that raises ImportError when accessed.

    Parameters
    ----------
    module_name : str
        The name of the module to check for.

    Returns
    -------
    function or property
        The original object if the module is available, otherwise a fallback.

    Raises
    ------
    TypeError
        If the module is missing and the decorated object is neither a
        function, a method nor a property.
    """

    def decorator(obj):
        try:
            __import__(module_name)
        except ImportError:
            # Module missing: substitute an object that fails lazily, at the
            # time it is used rather than at decoration time.
            if isinstance(obj, property):
                return property(
                    lambda self: raise_import_error(module_name, is_property=True)
                )

            if not isinstance(obj, (types.FunctionType, types.MethodType)):
                raise TypeError(
                    "requires_module can only be used on methods or properties."
                )

            @functools.wraps(obj)
            def wrapper(*args, **kwargs):
                raise_import_error(module_name)

            return wrapper
        else:
            # Module is available
            return obj

    return decorator


def raise_import_error(module_name, is_property=False):
    """
    Raise an ImportError stating that a required module is missing.

    Parameters
    ----------
    module_name : str
        The name of the required module.
    is_property : bool
        Whether the error is for a property (affects the error message).

    Raises
    ------
    ImportError
        Always.
    """
    if is_property:
        kind = "property"
    else:
        kind = "method"
    raise ImportError(f"The '{module_name}' module is required to use this {kind}.")


# Mapping of AST operators to NumPy functions
_OPERATORS = {
    # Boolean operators
    ast.And: np.logical_and,
    ast.Or: np.logical_or,
    # Comparison operators
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
    # Membership tests
    ast.In: lambda a, b: np.isin(a, b),
    ast.NotIn: lambda a, b: ~np.isin(a, b),
}
def logic_parser(table, expression):
    """
    Parse and apply a logical expression to filter rows from an Astropy Table.

    The expression is parsed into a syntax tree and evaluated against the
    columns of the input table; the result is used as a mask to select the
    rows that satisfy the expression. The expression can reference any column
    in the table by name and supports logical operators (``and``, ``or``),
    comparison operators (``<``, ``<=``, ``>``, ``>=``, ``==``, ``!=``,
    ``in``), lists, constants, and chained comparisons (e.g.
    ``1 < OBS_ID < 3``).

    Chained comparisons follow standard Python semantics and are equivalent to
    combining individual comparisons with ``and``. For example,
    ``1 < OBS_ID < 3`` is equivalent to ``(OBS_ID > 1) and (OBS_ID < 3)``.

    Parameters
    ----------
    table : `~astropy.table.Table`
        The input table to filter.
    expression : str
        The logical expression to evaluate on each row. The expression can
        reference any column in the table by name.

    Returns
    -------
    table : `~astropy.table.Table`
        A table view containing only the rows that satisfy the expression.
        If no rows match the condition, an empty table with the same column
        names and data types as the input table is returned.

    Raises
    ------
    KeyError
        If the expression references a name that is not a column of the table.
    ValueError
        If the expression contains an unsupported type of syntax node.

    Examples
    --------
    Given a table with columns 'OBS_ID' and 'EVENT_TYPE':

    >>> from astropy.table import Table
    >>> data = {'OBS_ID': [1, 2, 3, 4], 'EVENT_TYPE': ['1', '3', '4', '2']}
    >>> table = Table(data)

    Standard logical expression:

    >>> expression = '(OBS_ID < 3) and (OBS_ID > 1)'
    >>> logic_parser(table, expression)
    <Table length=1>
    OBS_ID EVENT_TYPE
    int64     str1
    ------ ----------
         2          3

    Using chained comparisons:

    >>> expression = '1 < OBS_ID < 3'
    >>> logic_parser(table, expression)
    <Table length=1>
    OBS_ID EVENT_TYPE
    int64     str1
    ------ ----------
         2          3

    Combining chained comparisons with other conditions:

    >>> expression = '(1 < OBS_ID < 4) and (EVENT_TYPE in ["3", "4"])'
    >>> logic_parser(table, expression)
    <Table length=2>
    OBS_ID EVENT_TYPE
    int64     str1
    ------ ----------
         2          3
         3          4
    """

    def eval_boolop(node):
        # ``a and b and c``: evaluate every operand, then fold left to right.
        combine = _OPERATORS[type(node.op)]
        operands = [evaluate(child) for child in node.values]
        return functools.reduce(combine, operands)

    def eval_compare(node):
        # ``a < b < c``: compare each adjacent pair of operands, then AND the
        # pairwise results together.
        pairwise = []
        lhs = evaluate(node.left)
        for op, rhs_node in zip(node.ops, node.comparators):
            rhs = evaluate(rhs_node)
            pairwise.append(_OPERATORS[type(op)](lhs, rhs))
            # The right operand is the left operand of the next comparison.
            lhs = rhs
        return functools.reduce(np.logical_and, pairwise)

    def eval_name(node):
        # A bare name refers to a column of the table.
        if node.id in table.colnames:
            return table[node.id]
        raise KeyError(
            f"Column '{node.id}' not found in the table. "
            f"Available columns: {table.colnames}"
        )

    # Evaluator for each supported node type (exact type match).
    dispatch = {
        ast.BoolOp: eval_boolop,
        ast.Compare: eval_compare,
        ast.Name: eval_name,
        ast.Constant: lambda node: node.value,
        ast.List: lambda node: [evaluate(item) for item in node.elts],
    }

    def evaluate(node):
        node_type = type(node)
        if node_type not in dispatch:
            raise ValueError(f"Unsupported expression type: {type(node)}")
        return dispatch[node_type](node)

    tree = ast.parse(expression, mode="eval")
    return table[evaluate(tree.body)]
def method_wrapper(func):
    """
    Wrap a function for use as a method while preserving its metadata.

    The returned wrapper can be assigned as a method on other classes (or
    instances); it forwards the receiving object as the first argument to
    ``func``, followed by all other positional and keyword arguments. The
    metadata of ``func`` (e.g. docstring, name, module, annotations) is copied
    onto the wrapper, so the wrapped method appears in introspection and
    documentation like the original.

    Parameters
    ----------
    func : callable
        Function taking the receiving object as its first argument.

    Returns
    -------
    wrapper : function
        Function forwarding its arguments to ``func``.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        return func(self, *args, **kwargs)

    return wrapper