# Source code for megadetector.visualization.plot_utils

"""

plot_utils.py

Utility functions for plotting, particularly for plotting confusion matrices
and precision-recall curves.

"""

#%% Imports

import numpy as np

# This also imports mpl.{cm, axes, colors}
import matplotlib.figure


#%% Plotting functions

def plot_confusion_matrix(matrix,
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=matplotlib.cm.Blues,
                          vmax=None,
                          use_colorbar=True,
                          y_label=True,
                          fmt='{:.0f}',
                          fig=None):
    """
    Plots a confusion matrix as a heatmap with a numeric annotation in every cell.

    Cell colors always come from the row-normalized matrix (each row divided by
    its sum), regardless of [normalize]; [normalize] only controls which values
    are printed inside the cells.

    Args:
        matrix (np.ndarray): shape [num_classes, num_classes]; rows are ground-truth
            classes and columns are predicted classes
        classes (list of str): class name for each row/column
        normalize (bool, optional): if True, print the row-normalized fractions in
            the cells; otherwise print the raw values from [matrix]
        title (str, optional): axes title
        cmap (matplotlib.colors.Colormap, optional): colormap for cell backgrounds
        vmax (float, optional): value mapped to the top of the colormap; if None,
            matplotlib autoscales to the maximum of the row-normalized matrix
        use_colorbar (bool, optional): whether to draw a colorbar labeled 0%-100%
        y_label (bool, optional): whether to label the y ticks with class names and
            add a y-axis label
        fmt (str, optional): str.format template used to render each cell value
        fig (Figure, optional): existing figure to draw into; a new figure is
            created when None

    Returns:
        matplotlib.figure.Figure: the figure that was drawn into
    """
    n = matrix.shape[0]
    assert matrix.shape[1] == n
    assert len(classes) == n

    # Row-normalize; the tiny epsilon keeps all-zero rows from dividing by zero.
    row_totals = matrix.sum(axis=1, keepdims=True)
    fractions = matrix.astype(np.float64) / (row_totals + 1e-7)
    shown = fractions if normalize else matrix

    # The figure grows with the number of classes; add width for the colorbar.
    height = 3 + 0.3 * n
    width = height + (0.5 if use_colorbar else 0)

    if fig is None:
        fig = matplotlib.figure.Figure(figsize=(width, height), tight_layout=True)
    ax = fig.subplots(1, 1)

    image = ax.imshow(fractions, interpolation='nearest', cmap=cmap, vmax=vmax)
    ax.set_title(title)

    if use_colorbar:
        pct_ticks = [0.0, 0.25, 0.5, 0.75, 1.0]
        colorbar = fig.colorbar(image, fraction=0.046, pad=0.04, ticks=pct_ticks)
        colorbar.set_ticklabels(['0%', '25%', '50%', '75%', '100%'])

    positions = np.arange(n)
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_xlabel('Predicted class')
    if y_label:
        ax.set_yticklabels(classes)
        ax.set_ylabel('Ground-truth class')

    # Annotate each cell. Use white text where the row fraction exceeds 0.5
    # (the dark end of the default Blues colormap); otherwise use black.
    for row, col in np.ndindex(shown.shape):
        text_color = 'black'
        if fractions[row, col] > 0.5:
            text_color = 'white'
        ax.text(col, row, fmt.format(shown[row, col]),
                horizontalalignment='center',
                verticalalignment='center',
                color=text_color)

    return fig
# ...def plot_confusion_matrix(...)
def plot_precision_recall_curve(precisions,
                                recalls,
                                title='Precision/recall curve',
                                xlim=(0.0, 1.05),
                                ylim=(0.0, 1.05)):
    """
    Plots a precision/recall curve as a post-step line with the area under it shaded.

    Args:
        precisions (list of float): precision values, one per entry in [recalls]
        recalls (list of float): recall values (x coordinates), expected to be
            ordered, one per entry in [precisions]
        title (str, optional): axes title
        xlim (tuple, optional): (left, right) x-axis limits
        ylim (tuple, optional): (bottom, top) y-axis limits

    Returns:
        matplotlib.figure.Figure: a newly-created figure containing the plot
    """
    assert len(precisions) == len(recalls)

    fig = matplotlib.figure.Figure(tight_layout=True)
    ax = fig.subplots(1, 1)

    # The outline and the shaded region share one color and transparency.
    style = dict(color='b', alpha=0.2)
    ax.step(recalls, precisions, where='post', **style)
    ax.fill_between(recalls, precisions, step='post', **style)

    ax.set_title(title)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(xlim[0], xlim[1])
    ax.set_ylim(ylim[0], ylim[1])

    return fig
# ...def plot_precision_recall_curve(...)
def plot_stacked_bar_chart(data,
                           series_labels=None,
                           col_labels=None,
                           x_label=None,
                           y_label=None,
                           log_scale=False):
    """
    Plots a stacked bar chart, e.g. for species distribution across locations.

    Each row of [data] is one series (e.g. a species), drawn as one colored
    segment in every bar; each column (e.g. a location) is one bar.

    Reference: https://stackoverflow.com/q/44309507

    Args:
        data (np.ndarray or list of list): 2-D data of shape [num_series, num_columns]
        series_labels (list of str, optional): legend label for each row; defaults
            to 'series_00', 'series_01', ...
        col_labels (list of str, optional): x tick label for each column; when there
            are 25 or more labels, only every 20th column is labeled
        x_label (str, optional): x-axis label
        y_label (str, optional): y-axis label
        log_scale (bool, optional): whether to use a log-scale y axis

    Returns:
        matplotlib.figure.Figure: a newly-created figure containing the chart
    """
    values = np.asarray(data)
    n_series, n_cols = values.shape
    x_positions = np.arange(n_cols)

    fig = matplotlib.figure.Figure(tight_layout=True)
    ax = fig.subplots(1, 1)
    palette = matplotlib.cm.rainbow(np.linspace(0, 1, n_series))

    # Each series' segment sits on top of the running total of the series drawn so far.
    baseline = np.zeros(n_cols)
    for idx in range(n_series):
        segment = values[idx]
        label = ('series_{}'.format(str(idx).zfill(2)) if series_labels is None
                 else series_labels[idx])
        ax.bar(x_positions, segment, bottom=baseline, label=label, color=palette[idx])
        baseline += segment

    if col_labels is not None:
        if len(col_labels) < 25:
            ax.set_xticks(x_positions)
            ax.set_xticklabels(col_labels, rotation=90)
        else:
            # Too many columns to label legibly: label every 20th one.
            ax.set_xticks(list(range(0, len(col_labels), 20)))
            ax.set_xticklabels(col_labels[::20], rotation=90)

    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if log_scale:
        ax.set_yscale('log')

    # Shrink the axes to 80% of their width so the legend fits to their right.
    pos = ax.get_position()
    ax.set_position([pos.x0, pos.y0, pos.width * 0.8, pos.height])
    ax.legend(loc='center left', bbox_to_anchor=(0.99, 0.5), frameon=False)

    return fig
# ...def plot_stacked_bar_chart(...)
def calibration_ece(true_scores, pred_scores, num_bins):
    r"""
    Computes the expected calibration error (ECE), eq. (3) of Guo et al.,
    "On Calibration of Modern Neural Networks" (2017), together with the
    per-bin accuracy and mean confidence.

    Adapted from sklearn.calibration.calibration_curve() to also compute ECE.
    See: https://github.com/scikit-learn/scikit-learn/issues/18268

    Args:
        true_scores (list of int): true values, length N, binary-valued
            (0 = neg, 1 = pos)
        pred_scores (list of float): predicted confidence values, length N;
            pred_scores[i] is the predicted confidence that example i is positive
        num_bins (int): number of equal-width bins (`M` in eq. (3) of Guo 2017)

    Returns:
        tuple: (accs, confs, ece) where
            - accs: np.ndarray of float64, fraction of positives in each non-empty
              bin; empty bins are omitted, so its length may be less than num_bins
            - confs: np.ndarray of float64, mean predicted confidence in each
              non-empty bin (aligned with accs)
            - ece: float, expected calibration error
    """
    assert len(true_scores) == len(pred_scores)

    # Equal-width bin edges over [0, 1]. Nudging the top edge just above 1 puts
    # a score of exactly 1.0 into the last bin instead of past it.
    edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
    bin_index = np.digitize(pred_scores, edges) - 1
    counts_len = len(edges)

    conf_sum = np.bincount(bin_index, weights=pred_scores, minlength=counts_len)
    pos_sum = np.bincount(bin_index, weights=true_scores, minlength=counts_len)
    count = np.bincount(bin_index, minlength=counts_len)

    # Only non-empty bins contribute (and this avoids dividing by zero).
    occupied = count != 0
    n_occupied = count[occupied]
    accs = pos_sum[occupied] / n_occupied
    confs = conf_sum[occupied] / n_occupied
    bin_weight = n_occupied / len(true_scores)

    # ECE = sum over bins m of (|B_m| / N) * |acc(B_m) - conf(B_m)|
    ece = np.abs(accs - confs) @ bin_weight
    return accs, confs, ece
# ...def calibration_ece(...)
def plot_calibration_curve(true_scores,
                           pred_scores,
                           num_bins,
                           name='calibration',
                           plot_perf=True,
                           plot_hist=True,
                           ax=None,
                           **fig_kwargs):
    """
    Plots a calibration curve: per-bin accuracy against per-bin mean confidence,
    with the expected calibration error (ECE) shown in the title.

    Args:
        true_scores (list of int): true values, length N, binary-valued
            (0 = neg, 1 = pos)
        pred_scores (list of float): predicted confidence values, length N;
            pred_scores[i] is the predicted confidence that example i is positive
        num_bins (int): number of equal-width bins (`M` in eq. (3) of Guo 2017)
        name (str, optional): legend label for the calibration curve
        plot_perf (bool, optional): whether to draw the y=x line indicating
            perfect calibration
        plot_hist (bool, optional): whether to overlay a histogram of
            [pred_scores] on a secondary y axis
        ax (Axes, optional): axes to draw into; if given, no legend is drawn and
            [fig_kwargs] are ignored
        fig_kwargs (dict): keyword arguments for matplotlib.figure.Figure, used
            only when [ax] is None

    Returns:
        matplotlib.figure.Figure: the figure containing the plot; newly created
        when [ax] is None, otherwise the figure that [ax] belongs to
    """
    accs, confs, ece = calibration_ece(true_scores, pred_scores, num_bins)

    owns_figure = ax is None
    if owns_figure:
        fig = matplotlib.figure.Figure(**fig_kwargs)
        ax = fig.subplots(1, 1)

    ax.plot(confs, accs, 's-', label=name)  # square markers joined by a line
    ax.set(xlabel='Model confidence', ylabel='Actual accuracy',
           title=f'Calibration plot (ECE: {ece:.02g})')
    ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05])
    if plot_perf:
        ax.plot([0, 1], [0, 1], color='black', label='perfect calibration')
    ax.grid(True)

    if plot_hist:
        # Counts go on a twin y axis so they don't share the [0, 1] accuracy scale.
        ax_count = ax.twinx()
        hist_edges = np.linspace(0., 1. + 1e-8, num=num_bins + 1)
        counts, _, _ = ax_count.hist(pred_scores, alpha=0.5,
                                     label='histogram of examples',
                                     bins=hist_edges, color='tab:red')
        peak = np.max(counts)
        ax_count.set_ylim([-0.05 * peak, 1.05 * peak])
        ax_count.set_ylabel('Count')

    if owns_figure:
        # A figure-level legend gathers labeled artists from both the main and twin axes.
        fig.legend(loc='upper left', bbox_to_anchor=(0.15, 0.85))

    return ax.figure
# ...def plot_calibration_curve(...)