You can now log precision-recall (PR) curves, ROC curves, and confusion matrices natively using Weights & Biases. You can also use our heatmaps to create attention maps.


ROC and PR curves in wandb.log()

Confusion Matrix

A confusion matrix evaluates the accuracy of a classification by counting, for each actual label, how often the model predicted each label. It's useful for assessing the quality of model predictions and for finding patterns in the predictions the model gets wrong.

The diagonal holds the predictions the model got right, i.e., where the actual label equals the predicted label.


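import wandb
# Confusion matrix: nb is the fitted classifier, nb.classes_ holds its class labels.
# plot_confusion_matrix logs the chart to the active run itself.
wandb.sklearn.plot_confusion_matrix(y_test, y_pred, nb.classes_)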
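To make the layout concrete, here is a minimal pure-Python sketch of how a confusion matrix is counted. It assumes the scikit-learn convention (rows are actual labels, columns are predicted labels); the helper names, labels, and predictions are hypothetical toy data, and this is not how W&B builds its chart internally.

```python
def confusion_matrix(y_true, y_pred, labels):
    """Count predictions into a len(labels) x len(labels) grid.

    Row i is the actual label labels[i]; column j is the predicted label labels[j].
    """
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for actual, predicted in zip(y_true, y_pred):
        matrix[index[actual]][index[predicted]] += 1
    return matrix


def accuracy_from_matrix(matrix):
    """Fraction of correct predictions: the diagonal sum over the total count."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


# Toy data (hypothetical labels, not the model discussed in this post).
labels = ["negative", "neutral", "positive"]
y_true = ["negative", "negative", "neutral", "positive", "positive", "neutral"]
y_pred = ["negative", "neutral", "neutral", "positive", "negative", "neutral"]

cm = confusion_matrix(y_true, y_pred, labels)
print(cm)                        # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(accuracy_from_matrix(cm))  # 4 of 6 correct
```

The off-diagonal cells show the error patterns: here one positive example was predicted as negative (row "positive", column "negative").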

Precision Recall Curve

The precision-recall curve shows the tradeoff between precision and recall across different decision thresholds. A high area under the curve indicates both high precision and high recall: high precision means few false positives among the samples predicted positive, and high recall means few false negatives (a low false negative rate).

High scores for both show that the classifier returns accurate results (high precision) and also returns the majority of all positive results (high recall). PR curves are especially useful when the classes are very imbalanced, because neither precision nor recall depends on the number of true negatives.


# Precision Recall
# y_probas: predicted class probabilities, e.g. y_probas = nb.predict_proba(X_test)
wandb.log({'pr': wandb.plots.precision_recall(y_test, y_probas, nb.classes_)})
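The threshold sweep behind a PR curve can be sketched in a few lines of plain Python for a single binary class (one class treated as positive, all others as negative). This is an illustrative assumption, not the implementation behind wandb.plots.precision_recall; precision_recall_points is a hypothetical helper and the labels and scores are toy data.

```python
def precision_recall_points(y_true, scores):
    """Return (threshold, precision, recall) for each distinct score used as a threshold.

    y_true holds 1 for the positive class and 0 otherwise (at least one 1 is
    assumed); scores are the predicted probabilities of the positive class.
    A sample counts as predicted positive when its score >= threshold.
    """
    total_positives = sum(y_true)
    points = []
    for threshold in sorted(set(scores), reverse=True):
        predicted = [y for y, s in zip(y_true, scores) if s >= threshold]
        tp = sum(predicted)            # real positives among predicted positives
        fp = len(predicted) - tp       # negatives wrongly predicted positive
        precision = tp / (tp + fp)     # share of predicted positives that are real
        recall = tp / total_positives  # share of real positives that were found
        points.append((threshold, precision, recall))
    return points


# Toy data: 4 positives and 4 negatives, listed in descending score order.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]

for threshold, precision, recall in precision_recall_points(y_true, scores):
    print(f"threshold={threshold:.1f}  precision={precision:.2f}  recall={recall:.2f}")
```

Lowering the threshold trades precision for recall: on this toy data the strictest threshold (0.9) gives precision 1.00 and recall 0.25, while the loosest (0.1) gives precision 0.50 and recall 1.00.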

ROC Curve

ROC curves plot the true positive rate (TPR, y-axis) against the false positive rate (FPR, x-axis). The ideal point is TPR = 1 and FPR = 0, at the top left of the plot. We typically summarize the curve by the area under it (AUC-ROC): the greater the AUC-ROC, the better.

Here we can see our model is slightly better at predicting the class Negative emotion, as shown by the larger area under its ROC curve.


wandb.log({'roc': wandb.plots.ROC(y_test, y_probas, nb.classes_)})
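As a sketch of how ROC points and AUC-ROC are obtained, the snippet below sweeps a threshold over predicted scores for one positive class and integrates the curve with the trapezoidal rule. For a multiclass model, one common approach is to compute one such curve per class (one-vs-rest). The helper names and data are hypothetical, and W&B may compute its curves differently.

```python
def roc_points(y_true, scores):
    """Return (FPR, TPR) points from (0, 0) up to (1, 1), one per distinct threshold.

    y_true holds 1 for the positive class and 0 otherwise; both classes must be present.
    """
    positives = sum(y_true)
    negatives = len(y_true) - positives
    points = [(0.0, 0.0)]  # threshold above every score: nothing predicted positive
    for threshold in sorted(set(scores), reverse=True):
        predicted = [y for y, s in zip(y_true, scores) if s >= threshold]
        tp = sum(predicted)
        fp = len(predicted) - tp
        points.append((fp / negatives, tp / positives))
    return points


def auc(points):
    """Area under a piecewise-linear curve, using the trapezoidal rule."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


# Toy data: 4 positives and 4 negatives.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]
print(auc(roc_points(y_true, scores)))  # 0.6875

# A perfect ranking puts every positive above every negative and reaches the top-left corner.
print(auc(roc_points([1, 1, 0, 0], [0.9, 0.8, 0.2, 0.1])))  # 1.0
```

The 0.6875 also equals the fraction of (positive, negative) pairs that the scores rank correctly (11 of 16), which is another common reading of AUC-ROC.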

Heat Maps

Heatmaps can be used to draw attention maps, confusion matrices, and more.

# Heat maps
# Arguments:
#   matrix_values (arr): 2D dataset of shape x_labels * y_labels, containing
#                        heatmap values that can be coerced into an ndarray.
#   x_labels  (list): Named labels for rows (x_axis).
#   y_labels  (list): Named labels for columns (y_axis).
#   show_text (bool): Show text values in heatmap cells.
wandb.log({'heatmap_with_text': wandb.plots.HeatMap(x_labels, y_labels, matrix_values, show_text=True)})
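matrix_values can come from anywhere. One common source for attention maps is a matrix of raw attention scores that is softmax-normalized row by row, so each row sums to 1. The sketch below assumes that setup; attention_matrix, the scores, and the word labels are hypothetical, and the wandb call is left as a comment so the snippet runs without an active W&B run.

```python
import math


def softmax(row):
    """Turn a list of raw scores into probabilities that sum to 1."""
    m = max(row)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_matrix(scores, x_labels, y_labels):
    """Softmax-normalize each row of a hypothetical raw score matrix.

    scores[i][j] is the raw score for x_labels[i] (a row) attending to
    y_labels[j] (a column), matching the len(x_labels) x len(y_labels)
    shape that matrix_values expects.
    """
    if len(scores) != len(x_labels) or any(len(row) != len(y_labels) for row in scores):
        raise ValueError("scores must have shape len(x_labels) x len(y_labels)")
    return [softmax(row) for row in scores]


# Hypothetical example: French output words (rows) attending to English input words (columns).
x_labels = ["le", "chat", "noir"]
y_labels = ["the", "black", "cat"]
scores = [[4.0, 0.5, 0.1],
          [0.2, 0.3, 3.5],
          [0.1, 3.0, 0.4]]

matrix_values = attention_matrix(scores, x_labels, y_labels)
for label, row in zip(x_labels, matrix_values):
    print(label, [round(p, 2) for p in row])

# Inside an active run, the result could then be logged with:
# wandb.log({'attention': wandb.plots.HeatMap(x_labels, y_labels, matrix_values, show_text=True)})
```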

Here's an example of the attention maps for a neural machine translation model that translates English → French. We draw attention maps at the 2nd, 20th, and 100th epochs. The model starts out not knowing which words to pay attention to (it uses <res> to predict all words) and gradually learns which ones to attend to by the 100th epoch.
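One way to quantify what these attention maps show is to measure how spread out each row is. The sketch below uses Shannon entropy per row (high for a uniform, undecided row; low when attention concentrates on one word) and reads off the most-attended source word for each output word. The helper names and toy matrices are invented for illustration, not taken from the model above.

```python
import math


def row_entropy(row):
    """Shannon entropy (in bits) of one attention row whose weights sum to 1."""
    return -sum(p * math.log2(p) for p in row if p > 0)


def strongest_alignments(matrix, x_labels, y_labels):
    """Map each row label to the column label with the highest attention weight."""
    return {
        x: y_labels[max(range(len(row)), key=row.__getitem__)]
        for x, row in zip(x_labels, matrix)
    }


# Hypothetical toy maps for a 3-word sentence (not real model output).
x_labels = ["le", "chat", "noir"]
y_labels = ["the", "black", "cat"]
early = [[1 / 3] * 3 for _ in range(3)]  # uniform: the model has not learned where to look
late = [[0.9, 0.05, 0.05],
        [0.05, 0.05, 0.9],
        [0.05, 0.9, 0.05]]               # sharp: each word attends mostly to one source word

print([round(row_entropy(r), 3) for r in early])  # [1.585, 1.585, 1.585]
print([round(row_entropy(r), 3) for r in late])   # [0.569, 0.569, 0.569]
print(strongest_alignments(late, x_labels, y_labels))
# {'le': 'the', 'chat': 'cat', 'noir': 'black'}
```

Falling row entropy across epochs is one rough numerical counterpart to the attention map going from a blurry grid to a sharp alignment.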