# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
import matplotlib.pyplot as plt

# Taken from the matplotlib documentation:
# https://matplotlib.org/3.1.0/gallery/images_contours_and_fields/image_annotated_heatmap.html#sphx-glr-gallery-images-contours-and-fields-image-annotated-heatmap-py

__all__ = [
    "annotate_heatmap",
    "plot_heatmap",
]
def plot_heatmap(
    data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs
):
    """Draw a labelled heatmap with a colorbar from a 2D array.

    Parameters
    ----------
    data : `~numpy.ndarray`
        Data array.
    row_labels : list or `~numpy.ndarray`
        Labels for the rows (y tick labels).
    col_labels : list or `~numpy.ndarray`
        Labels for the columns (x tick labels).
    ax : `matplotlib.axes.Axes`, optional
        Axes the heatmap is drawn on. Default is None, in which case the
        current axes (``plt.gca()``) are used.
    cbar_kw : dict, optional
        Keyword arguments passed on to the figure's ``colorbar`` call.
        Default is None, which is treated as an empty dict.
    cbarlabel : str, optional
        Label of the color bar. Default is "".
    **kwargs : dict, optional
        Other keyword arguments forwarded to ``ax.imshow``.

    Returns
    -------
    im
        The image object returned by ``ax.imshow``.
    cbar
        The colorbar object returned by ``ax.figure.colorbar``.
    """
    axes = plt.gca() if ax is None else ax
    colorbar_options = {} if cbar_kw is None else cbar_kw

    # Draw the image and attach its colorbar.
    image = axes.imshow(data, **kwargs)
    colorbar = axes.figure.colorbar(image, ax=axes, **colorbar_options)
    colorbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # One major tick per cell, labelled with the supplied entries.
    column_ticks = np.arange(data.shape[1])
    row_ticks = np.arange(data.shape[0])
    axes.set_xticks(column_ticks)
    axes.set_yticks(row_ticks)
    axes.set_xticklabels(col_labels)
    axes.set_yticklabels(row_labels)

    # Column labels go above the image rather than below it.
    axes.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Slant the column labels so that long entries do not collide.
    plt.setp(
        axes.get_xticklabels(),
        rotation=-30,
        ha="right",
        rotation_mode="anchor",
    )

    # Hide the frame; the white minor grid below separates the cells instead.
    for spine in axes.spines.values():
        spine.set_visible(False)

    # Minor ticks sit on the cell boundaries (half a cell off the centers).
    column_edges = np.arange(data.shape[1] + 1) - 0.5
    row_edges = np.arange(data.shape[0] + 1) - 0.5
    axes.set_xticks(column_edges, minor=True)
    axes.set_yticks(row_edges, minor=True)
    axes.grid(which="minor", color="w", linestyle="-", linewidth=1.5)
    axes.tick_params(which="minor", bottom=False, left=False)

    return image, colorbar
def annotate_heatmap(
    im,
    data=None,
    valfmt="{x:.2f}",
    textcolors=("black", "white"),
    threshold=None,
    **textkw,
):
    """Annotate a heatmap with one text label per cell.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data : list or `~numpy.ndarray`, optional
        Data used to annotate. Lists are converted to arrays. Default is
        None. If it is neither a list nor an array, the image's data
        (``im.get_array()``) is used.
    valfmt : str format or `matplotlib.ticker.Formatter`, optional
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}" or be a
        `matplotlib.ticker.Formatter` instance. Default is "{x:.2f}".
    textcolors : sequence of two color specifications, optional
        The first is used for values at or below the threshold, the second
        for those above. Default is ("black", "white").
    threshold : float, optional
        Value in data units according to which the colors from textcolors
        are applied. Default is None. If None, half of the normalized
        maximum of the data is used as separation.
    **textkw : dict, optional
        Other keyword arguments forwarded to each call to ``im.axes.text``
        used to create the text labels. A ``color`` entry is overridden by
        the color picked from ``textcolors``.

    Returns
    -------
    texts : list
        The objects returned by ``im.axes.text``, one per cell in row-major
        order.
    """
    # Import the submodule explicitly: a bare ``import matplotlib`` does not
    # guarantee that ``matplotlib.ticker`` has been loaded.
    from matplotlib import ticker

    if isinstance(data, (list, np.ndarray)):
        # Lists pass the type check but lack ``.max()``/``.shape``/2D
        # indexing, so coerce them. ``asanyarray`` returns ndarray
        # subclasses (e.g. masked arrays) unchanged.
        data = np.asanyarray(data)
    else:
        data = im.get_array()

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.0

    # Center the labels by default, but let the caller override via textkw.
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Wrap a format string in a formatter so both cases are called alike.
    if isinstance(valfmt, str):
        valfmt = ticker.StrMethodFormatter(valfmt)

    # Create one text label per "pixel"; its color depends on which side of
    # the threshold the normalized value falls.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            value = data[i, j]
            kw["color"] = textcolors[int(im.norm(value) > threshold)]
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))

    return texts