# Licensed under a 3-clause BSD style license - see LICENSE.rst"""Utilities to create scripts and command-line tools."""importastimportcodecsimportoperatorimportos.pathimportfunctoolsimporttypesimportwarningsimportnumpyasnpfrombase64importurlsafe_b64encodefrompathlibimportPathfromuuidimportuuid4importyamlfromgammapy.utils.checkimportadd_checksum,verify_checksum__all__=["from_yaml","get_images_paths","make_path","read_yaml","recursive_merge_dicts","to_yaml","write_yaml",]PATH_DOCS=Path(__file__).resolve().parent/".."/".."/"docs"SKIP=["_static","_build","_checkpoints","docs/user-guide/model-gallery/"]YAML_FORMAT=dict(sort_keys=False,indent=4,width=80,default_flow_style=False)
def get_images_paths(folder=PATH_DOCS):
    """Generator yielding a resolved Path for each image used in the notebooks.

    Every file matching ``images/*`` anywhere below ``folder`` is considered.
    Paths containing one of the ``SKIP`` substrings are ignored.

    Parameters
    ----------
    folder : str
        Folder where to search.

    Yields
    ------
    path : `~pathlib.Path`
        Resolved path of an image file.
    """
    for image in Path(folder).rglob("images/*"):
        # Compare against the POSIX form so that SKIP entries written with
        # forward slashes (e.g. "docs/user-guide/model-gallery/") also match
        # on Windows, where str(path) uses backslashes.
        posix = image.as_posix()
        if not any(pattern in posix for pattern in SKIP):
            yield image.resolve()
def from_yaml(text, sort_keys=False, checksum=False):
    """Read YAML file.

    Parse a YAML string. A top-level ``checksum`` entry, if present, is removed
    from the returned data.

    When ``checksum`` is True, the remaining content is re-serialized with
    ``yaml.dump``. It is then passed to ``verify_checksum`` together with the
    removed checksum value.

    Parameters
    ----------
    text : str
        yaml str
    sort_keys : bool, optional
        Whether to sort keys when re-serializing for checksum verification.
        Default is False.
    checksum : bool
        Whether to perform checksum verification. Default is False.

    Returns
    -------
    data : dict
        YAML file content as a dictionary. An empty document yields an empty
        dictionary.

    Warns
    -----
    UserWarning
        If ``checksum`` is True and ``verify_checksum`` reports a mismatch.
    """
    data = yaml.safe_load(text)
    if data is None:
        # yaml.safe_load returns None for an empty document. Normalise it so
        # the pop below, and callers expecting a dict, do not fail.
        data = {}
    checksum_str = data.pop("checksum", None)
    if checksum:
        yaml_format = YAML_FORMAT.copy()
        yaml_format["sort_keys"] = sort_keys
        yaml_str = yaml.dump(data, **yaml_format)
        if not verify_checksum(yaml_str, checksum_str):
            warnings.warn("Checksum verification failed.", UserWarning)
    return data
def read_yaml(filename, logger=None, checksum=False):
    """Read a YAML file from disk and return its content.

    Parameters
    ----------
    filename : `~pathlib.Path`
        Filename. Environment variables in it are expanded via `make_path`.
    logger : `~logging.Logger`
        Logger. When given, the path being read is logged at INFO level.
    checksum : bool
        Whether to perform checksum verification. Default is False.

    Returns
    -------
    data : dict
        YAML file content as a dictionary.
    """
    yaml_path = make_path(filename)
    if logger is not None:
        logger.info(f"Reading {yaml_path}")
    return from_yaml(yaml_path.read_text(), checksum=checksum)
def to_yaml(dictionary, sort_keys=False):
    """Serialize a dictionary to a YAML string.

    The YAML produced by ``CreatorMetaData().to_yaml()`` is appended after the
    serialized dictionary.

    Parameters
    ----------
    dictionary : dict
        Python dictionary.
    sort_keys : bool, optional
        Whether to sort keys. Default is False.

    Returns
    -------
    text : str
        YAML text.
    """
    from gammapy.utils.metadata import CreatorMetaData

    dump_kwargs = {**YAML_FORMAT, "sort_keys": sort_keys}
    body = yaml.safe_dump(dictionary, **dump_kwargs)
    footer = CreatorMetaData().to_yaml()
    return body + footer
def write_yaml(
    text, filename, logger=None, sort_keys=False, checksum=False, overwrite=False
):
    """Write YAML file.

    Parameters
    ----------
    text : str
        yaml str
    filename : `~pathlib.Path`
        Filename. Environment variables in it are expanded via `make_path`.
    logger : `~logging.Logger`, optional
        Logger. Default is None.
    sort_keys : bool, optional
        Whether to sort keys when computing the checksum. Only used when
        ``checksum`` is True. Default is False.
    checksum : bool, optional
        Whether to add checksum keyword. Default is False.
    overwrite : bool, optional
        Overwrite existing file. Default is False.

    Raises
    ------
    IOError
        If the file already exists and ``overwrite`` is False.
    """
    path = make_path(filename)
    # Fail fast before doing any work if we are not allowed to overwrite.
    if path.exists() and not overwrite:
        raise IOError(f"File exists already: {path}")

    if checksum:
        text = add_checksum(text, sort_keys=sort_keys)

    # Create all missing intermediate directories, not only the direct parent.
    path.parent.mkdir(parents=True, exist_ok=True)

    if logger is not None:
        logger.info(f"Writing {path}")
    path.write_text(text)
def make_name(name=None):
    """Make a dataset name.

    If ``name`` is None, generate an 8-character URL-safe base64 string from a
    random UUID4. Regenerate until the result does not start with an underscore.

    Parameters
    ----------
    name : str, optional
        Explicit name. Returned unchanged if it is a string.

    Returns
    -------
    name : str
        The given or generated name.

    Raises
    ------
    ValueError
        If ``name`` is given but is not a string.
    """
    if name is None:
        while True:
            raw_bytes = codecs.decode(uuid4().hex, "hex")
            name = urlsafe_b64encode(raw_bytes).decode()[:8]
            if name[0] != "_":
                break

    if isinstance(name, str):
        return name

    raise ValueError(
        "Name argument must be a string, "
        f"got '{name}', which is of type '{type(name)}'"
    )
def make_path(path):
    """Build a `~pathlib.Path`, expanding environment variables first.

    Parameters
    ----------
    path : str, `pathlib.Path` or None
        Path to expand. None is passed through unchanged.

    Returns
    -------
    path : `~pathlib.Path` or None
        The expanded path, or None if ``path`` is None.
    """
    # TODO: raise error or warning if environment variables that don't resolve are used
    # e.g. "spam/$DAMN/ham" where `$DAMN` is not defined
    # Otherwise this can result in cryptic errors later on
    if path is None:
        return None
    expanded = os.path.expandvars(path)
    return Path(expanded)
def recursive_merge_dicts(a, b):
    """Recursively merge two dictionaries.

    Entries in 'b' override entries in 'a'. The built-in update function
    cannot be used for hierarchical dicts, see:
    http://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth/3233356#3233356

    Nested dictionaries are merged only when the value is a dictionary in both
    'a' and 'b'. Otherwise the value from 'b' replaces the one from 'a'. This
    includes replacing a nested dictionary with a scalar or None.

    Neither input is modified; nested values that are not merged are shared,
    not copied.

    Parameters
    ----------
    a : dict
        Dictionary to be merged.
    b : dict
        Dictionary to be merged.

    Returns
    -------
    c : dict
        Merged dictionary.

    Examples
    --------
    >>> from gammapy.utils.scripts import recursive_merge_dicts
    >>> a = dict(a=42, b=dict(c=43, e=44))
    >>> b = dict(d=99, b=dict(c=50, g=98))
    >>> c = recursive_merge_dicts(a, b)
    >>> print(c)
    {'a': 42, 'b': {'c': 50, 'e': 44, 'g': 98}, 'd': 99}
    """
    c = a.copy()
    for k, v in b.items():
        if isinstance(c.get(k), dict) and isinstance(v, dict):
            # Both sides are mappings: merge them level by level.
            c[k] = recursive_merge_dicts(c[k], v)
        else:
            # New key, non-dict value, or type mismatch: 'b' wins.
            c[k] = v
    return c
def requires_module(module_name):
    """
    Decorator that conditionally enables a method or property based on the availability of a module.

    If the specified module is available, the decorated method or property is returned as-is.
    If the module is not available:

    - For methods: replaces the method with one that raises ImportError when called.
    - For properties: replaces the property with one that raises ImportError when accessed.
      The original property docstring is kept.

    The availability check runs once, when the decorator is applied.

    Parameters
    ----------
    module_name : str
        The name of the module to check for.

    Returns
    -------
    function or property
        The original object if the module is available, otherwise a fallback.

    Raises
    ------
    TypeError
        If the module is missing and the decorated object is neither a
        function/method nor a property.
    """

    def decorator(obj):
        try:
            __import__(module_name)
            return obj  # Module is available
        except ImportError:
            if isinstance(obj, property):
                # Keep the docstring so documentation tools still see it.
                return property(
                    lambda self: raise_import_error(module_name, is_property=True),
                    doc=obj.__doc__,
                )
            elif isinstance(obj, (types.FunctionType, types.MethodType)):

                @functools.wraps(obj)
                def wrapper(*args, **kwargs):
                    raise_import_error(module_name)

                return wrapper
            else:
                raise TypeError(
                    "requires_module can only be used on methods or properties."
                )

    return decorator


def raise_import_error(module_name, is_property=False):
    """
    Raises an ImportError with a descriptive message about a missing module.

    Parameters
    ----------
    module_name : str
        The name of the required module.
    is_property : bool
        Whether the error is for a property (affects the error message).
    """
    kind = "property" if is_property else "method"
    raise ImportError(f"The '{module_name}' module is required to use this {kind}.")


# Mapping of AST operators to NumPy functions
_OPERATORS = {
    ast.And: np.logical_and,
    ast.Or: np.logical_or,
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
    ast.In: lambda a, b: np.isin(a, b),
    ast.NotIn: lambda a, b: ~np.isin(a, b),
}


def logic_parser(table, expression):
    """
    Parse and apply a logical expression to filter rows from an Astropy Table.

    This function evaluates a logical expression on each row of the input table
    and returns a new table containing only the rows that satisfy the expression.
    The expression can reference any column in the table by name. It supports:

    - logical operators (`and`, `or`);
    - comparison operators (`<`, `<=`, `>`, `>=`, `==`, `!=`, `in`, `not in`);
    - chained comparisons with Python semantics (e.g. `1 < OBS_ID < 4`);
    - lists, tuples and constants.

    Parameters
    ----------
    table : `~astropy.table.Table`
        The input table to filter.
    expression : str
        The logical expression to evaluate on each row. The expression can
        reference any column in the table by name.

    Returns
    -------
    table : `~astropy.table.Table`
        A table view containing only the rows that satisfy the expression.
        If no rows match the condition, an empty table with the same column
        names and data types as the input table is returned.

    Raises
    ------
    KeyError
        If the expression references a column that is not in the table, or
        uses an operator that is not in the supported set.
    ValueError
        If the expression contains an unsupported node type.

    Examples
    --------
    Given a table with columns 'OBS_ID' and 'EVENT_TYPE':

    >>> from astropy.table import Table
    >>> from gammapy.utils.scripts import logic_parser
    >>> data = {'OBS_ID': [1, 2, 3, 4], 'EVENT_TYPE': ['1', '3', '4', '2']}
    >>> table = Table(data)
    >>> expression = '(OBS_ID < 3) and (OBS_ID > 1) and ((EVENT_TYPE in ["3", "4"]) or (EVENT_TYPE == "3"))'
    >>> filtered_table = logic_parser(table, expression)
    >>> print(filtered_table)
    OBS_ID EVENT_TYPE
    ------ ----------
         2          3
    """

    def eval_node(node):
        if isinstance(node, ast.BoolOp):
            op_func = _OPERATORS[type(node.op)]
            values = [eval_node(v) for v in node.values]
            result = values[0]
            for v in values[1:]:
                result = op_func(result, v)
            return result
        elif isinstance(node, ast.Compare):
            # Python semantics: ``a < b < c`` means ``(a < b) and (b < c)``.
            # Each comparison therefore takes the previous *operand* as its
            # left side, and the partial results are and-ed together.
            left = eval_node(node.left)
            result = None
            for op, comparator in zip(node.ops, node.comparators):
                right = eval_node(comparator)
                op_func = _OPERATORS[type(op)]
                current = op_func(left, right)
                result = current if result is None else np.logical_and(result, current)
                left = right
            return result
        elif isinstance(node, ast.Name):
            if node.id not in table.colnames:
                raise KeyError(
                    f"Column '{node.id}' not found in the table. Available columns: {table.colnames}"
                )
            return table[node.id]
        elif isinstance(node, ast.Constant):
            return node.value
        elif isinstance(node, (ast.List, ast.Tuple)):
            # Tuples are accepted as well as lists, e.g. ``X in ("a", "b")``.
            return [eval_node(elt) for elt in node.elts]
        else:
            raise ValueError(f"Unsupported expression type: {type(node)}")

    expr_ast = ast.parse(expression, mode="eval")
    mask = eval_node(expr_ast.body)
    return table[mask]