# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
import matplotlib.pyplot as plt
# Adapted from the matplotlib documentation:
# https://matplotlib.org/3.1.0/gallery/images_contours_and_fields/image_annotated_heatmap.html#sphx-glr-gallery-images-contours-and-fields-image-annotated-heatmap-py
# Public API of this module.
__all__ = ["annotate_heatmap", "plot_heatmap"]
def plot_heatmap(
    data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs
):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Parameters
    ----------
    data : `~numpy.ndarray`
        Data array (any array-like accepted by `matplotlib.axes.Axes.imshow`).
        Its first axis is mapped to the rows and its second axis to the columns.
    row_labels : list or `~numpy.ndarray`
        List or array of labels for the rows.
    col_labels : list or `~numpy.ndarray`
        List or array of labels for the columns.
    ax : `matplotlib.axes.Axes`, optional
        Axis instance to which the heatmap is plotted. Default is None.
        If None, the current one (``plt.gca()``) is used.
    cbar_kw : dict, optional
        A dictionary with arguments to `matplotlib.figure.Figure.colorbar`.
        Default is None (no extra arguments).
    cbarlabel : str, optional
        The label for the color bar. Default is "".
    **kwargs : dict, optional
        Other keyword arguments forwarded to `matplotlib.axes.Axes.imshow`.

    Returns
    -------
    im : `matplotlib.image.AxesImage`
        The image returned by ``ax.imshow``.
    cbar : `matplotlib.colorbar.Colorbar`
        The color bar created for the image.
    """
    if ax is None:
        ax = plt.gca()
    if cbar_kw is None:
        cbar_kw = {}

    # Plot the heatmap.
    im = ax.imshow(data, **kwargs)

    # np.shape works for every array-like imshow accepts (nested lists,
    # masked arrays, ...), whereas ``data.shape`` fails for plain lists.
    # Computed after imshow so invalid input still fails inside imshow first.
    shape = np.shape(data)
    n_rows, n_cols = shape[0], shape[1]

    # Create the colorbar.
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show one major tick per cell and label it with the matching entry.
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Put the column labels on top and rotate them so long labels do not overlap.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

    # Hide the frame and draw a white grid on minor ticks placed at the cell
    # boundaries (half-integer positions) to visually separate the cells.
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=1.5)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar
def annotate_heatmap(
    im,
    data=None,
    valfmt="{x:.2f}",
    textcolors=("black", "white"),
    threshold=None,
    **textkw,
):
    """
    A function to annotate a heatmap.

    Parameters
    ----------
    im : `matplotlib.image.AxesImage`
        The AxesImage to be labeled.
    data : `~numpy.ndarray` or list, optional
        Data used to annotate. Default is None. If None (or any object that is
        neither a list nor an array), the image's data is used.
    valfmt : str format or `matplotlib.ticker.Formatter`, optional
        The format of the annotations inside the heatmap. This should either
        be a `str.format` template using the field ``x``, e.g. "$ {x:.2f}",
        or a callable ``valfmt(x, pos)`` such as a `matplotlib.ticker.Formatter`
        instance. Default is "{x:.2f}".
    textcolors : list or tuple, optional
        Two color specifications. The first is used for values at or below the
        threshold, the second for those above. Default is ("black", "white").
    threshold : float, optional
        Value in data units according to which the colors from ``textcolors``
        are applied. Default is None. If None the middle of the colormap
        (0.5 in normalized units) is used as separation.
    **textkw : dict, optional
        Other keyword arguments forwarded to each call to
        `matplotlib.axes.Axes.text` used to create the text labels.

    Returns
    -------
    texts : list
        The objects returned by ``im.axes.text`` for each cell, in row-major
        order.
    """
    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()
    # Lists have neither ``.shape`` nor 2D indexing; asanyarray converts them
    # while leaving ndarray subclasses (e.g. masked arrays) untouched.
    data = np.asanyarray(data)

    # The threshold is expressed in normalized colormap units (0..1).
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        # Middle of the colormap. The former ``im.norm(data.max()) / 2`` only
        # equals 0.5 when vmax coincides with the data maximum, and became NaN
        # whenever the data contained NaN.
        threshold = 0.5

    # Center the text by default, but let textkw override the alignment.
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Turn a format string into a formatter callable as valfmt(x, pos).
    if isinstance(valfmt, str):
        from matplotlib.ticker import StrMethodFormatter

        valfmt = StrMethodFormatter(valfmt)

    # Create one text per "pixel"; its color depends on the normalized value.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            value = data[i, j]
            kw.update(color=textcolors[int(im.norm(value) > threshold)])
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))
    return texts