Source code for gammapy.visualization.heatmap

# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
import matplotlib.pyplot as plt

# taken from the matplotlib documentation
# https://matplotlib.org/3.1.0/gallery/images_contours_and_fields/image_annotated_heatmap.html#sphx-glr-gallery-images-contours-and-fields-image-annotated-heatmap-py

# Public API of this module: the two heatmap helpers defined below.
__all__ = [
    "annotate_heatmap",
    "plot_heatmap",
]


def plot_heatmap(
    data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs
):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Parameters
    ----------
    data : `~numpy.ndarray`
        Data array. Other array-likes (e.g. nested lists) are converted
        with `numpy.asanyarray`.
    row_labels : list or `~numpy.ndarray`
        List or array of labels for the rows.
    col_labels : list or `~numpy.ndarray`
        List or array of labels for the columns.
    ax : `matplotlib.axes.Axes`, optional
        Axis instance to which the heatmap is plotted. Default is None.
        If None, the current one is used.
    cbar_kw : dict, optional
        A dictionary with arguments to `matplotlib.figure.Figure.colorbar`.
        Default is None.
    cbarlabel : str, optional
        The label for the color bar. Default is "".
    **kwargs : dict, optional
        Other keyword arguments forwarded to `matplotlib.axes.Axes.imshow`.

    Returns
    -------
    im
        The image object returned by ``ax.imshow``.
    cbar
        The colorbar object returned by ``ax.figure.colorbar``.
    """
    # The tick placement below reads ``data.shape``, which plain nested
    # lists do not have; ``asanyarray`` leaves ndarray subclasses untouched.
    data = np.asanyarray(data)

    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]

    # We want to show all ticks...
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    # ... and label them with the respective list entries.
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

    # Turn spines off and create white grid.
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Minor ticks sit on the cell boundaries (half a cell off the cell
    # centres), so the minor grid draws the white separators between cells.
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=1.5)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar
def annotate_heatmap(
    im,
    data=None,
    valfmt="{x:.2f}",
    textcolors=("black", "white"),
    threshold=None,
    **textkw,
):
    """
    A function to annotate a heatmap.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data : list or `~numpy.ndarray`, optional
        Data used to annotate. Default is None.
        If it is neither a list nor an array, the image's data is used.
    valfmt : str format or `matplotlib.ticker.Formatter`, optional
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}" or be a
        `matplotlib.ticker.Formatter` instance. Default is "{x:.2f}".
    textcolors : list or `~numpy.ndarray`, optional
        Two color specifications. The first is used for values below a
        threshold, the second for those above.
        Default is ("black", "white").
    threshold : float, optional
        Value in data units according to which the colors from textcolors
        are applied. Default is None.
        If None, half of the normalized data maximum is used as separation.
    **textkw : dict, optional
        Other keyword arguments forwarded to each call to `text` used to
        create the text labels. A ``color`` entry is always overridden by
        the color chosen from ``textcolors``.

    Returns
    -------
    texts : list
        The objects returned by ``im.axes.text``, one per annotated cell,
        in row-major order.
    """
    if isinstance(data, (list, np.ndarray)):
        # Nested lists lack ``.max()``, ``.shape`` and 2D indexing, all of
        # which are used below. ``asanyarray`` keeps ndarray subclasses
        # (such as masked arrays) as they are.
        data = np.asanyarray(data)
    else:
        data = im.get_array()

    # Normalize the threshold to the images color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.0

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        # Import the submodule explicitly: a bare ``import matplotlib`` does
        # not by itself guarantee that ``matplotlib.ticker`` is loaded.
        from matplotlib.ticker import StrMethodFormatter

        valfmt = StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            value = data[i, j]
            kw.update(color=textcolors[int(im.norm(value) > threshold)])
            # Image coordinates are (x, y) = (column, row).
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))

    return texts