Making scikit Confusion Matrices Easier to Understand

I was looking at logistic regression with the scikit-learn (scikit or sklearn for short) library. The built-in scikit confusion_matrix(y_actuals, y_predicteds) function computes a confusion matrix, but printing its result gives output that isn’t very easy to understand.

I ran a demo with this code:

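from sklearn.metrics import confusion_matrix
# model is a trained scikit classifier (here, logistic regression);
# test_x and test_y hold the test predictors and the actual labels
y_predicteds = model.predict(test_x)
cm = confusion_matrix(test_y, y_predicteds)  # (actuals, predicteds)
print("Confusion matrix raw: ")
print(cm)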

The output was:

Confusion matrix raw:
[[17  9]
 [ 2 12]]

It’s not clear which counts are which. As it turns out, the scikit documentation says, “Confusion matrix whose i-th row and j-th column entry indicates the number of samples with true label being i-th class and predicted label being j-th class.” In other words, rows are actual classes and columns are predicted classes, so the entries are:

actual 0  |   17    9
actual 1  |    2   12
predicted      0    1
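To make the row/column convention concrete, here is a minimal from-scratch sketch in NumPy. The function name confusion_counts() is hypothetical (it is not part of scikit), and it assumes class labels are the integers 0 through n_classes - 1. Under that assumption, and when every class appears in the data, it should produce the same layout as confusion_matrix(y_actuals, y_predicteds).

```python
import numpy as np

def confusion_counts(y_actuals, y_predicteds, n_classes):
  # Hypothetical helper: row i = actual class i, column j = predicted class j,
  # the same layout scikit's confusion_matrix() uses.
  # Assumes integer labels 0 .. n_classes-1.
  cm = np.zeros((n_classes, n_classes), dtype=np.int64)
  for a, p in zip(y_actuals, y_predicteds):
    cm[a][p] += 1  # one more item with actual class a, predicted class p
  return cm

y_actuals    = [0, 0, 0, 1, 1]
y_predicteds = [0, 1, 0, 1, 0]
cm = confusion_counts(y_actuals, y_predicteds, 2)
print(cm)
# [[2 1]
#  [1 1]]
```

The single item with actual 0 and predicted 1 lands in row 0, column 1, which is the top-right cell in the labeled layout.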

I coded up a show_confusion(cm) function to display a confusion matrix cm with some rudimentary labels. The code is:

def show_confusion(cm):
  ct_act0_pred0 = cm[0][0]  # TN
  ct_act0_pred1 = cm[0][1]  # FP wrongly predicted as pos
  ct_act1_pred0 = cm[1][0]  # FN wrongly predicted as neg 
  ct_act1_pred1 = cm[1][1]  # TP
  print("actual 0  | %4d %4d" % (ct_act0_pred0, ct_act0_pred1))
  print("actual 1  | %4d %4d" % (ct_act1_pred0, ct_act1_pred1))
  print("           ----------")
  print("predicted      0    1")
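The four binary cells are what the usual metrics are built from. As a worked example with the demo counts above (TN = 17, FP = 9, FN = 2, TP = 12, treating class 1 as positive), here is a small sketch. The helper name binary_metrics() is hypothetical; scikit's accuracy_score(), precision_score(), and recall_score() applied to the raw labels should give the same values.

```python
import numpy as np

def binary_metrics(cm):
  # Hypothetical helper. Assumes a 2x2 matrix laid out like scikit's:
  # rows = actual (0, 1), columns = predicted (0, 1), class 1 = positive.
  # ravel() flattens row by row: [0][0], [0][1], [1][0], [1][1]
  tn, fp, fn, tp = np.asarray(cm).ravel()
  accuracy  = (tp + tn) / (tn + fp + fn + tp)  # fraction of all items correct
  precision = tp / (tp + fp)  # of items predicted 1, fraction actually 1
  recall    = tp / (tp + fn)  # of items actually 1, fraction predicted 1
  return accuracy, precision, recall

cm = np.array([[17,  9],
               [ 2, 12]])
acc, prec, rec = binary_metrics(cm)
print("accuracy  = %0.4f" % acc)   # (12 + 17) / 40
print("precision = %0.4f" % prec)  # 12 / 21
print("recall    = %0.4f" % rec)   # 12 / 14
```

With these counts, precision (about 0.57) is noticeably lower than recall (about 0.86), because 9 of the 21 items predicted as class 1 are actually class 0.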

This function is hard-coded for binary classification. Here’s a general version that works for both binary classification and multi-class classification (three or more label values):

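import numpy as np

def show_confusion(cm):
  # rows are actual classes, columns are predicted classes
  dim = len(cm)
  mx = max(np.max(cm), dim - 1)  # widest number: a count or a class label
  wid = len(str(mx)) + 1         # width to print, with one space of padding
  fmt = "%" + str(wid) + "d"     # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")                    # end of this row
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")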
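If you want the labeled layout as a string (for logging, or to check it in a test) rather than printed directly, one option is to build the lines and join them. This is a sketch of a hypothetical confusion_to_str() variant of show_confusion(); it follows the same column-width idea, also counting the class labels when choosing the width. The 3-class matrix is made-up illustration data.

```python
import numpy as np

def confusion_to_str(cm):
  # Hypothetical variant of show_confusion() that returns text instead of printing.
  # Rows are actual classes, columns are predicted classes (scikit's layout).
  cm = np.asarray(cm)
  dim = len(cm)
  mx = max(int(np.max(cm)), dim - 1)  # widest number: a count or a class label
  wid = len(str(mx)) + 1              # column width, including one space of padding
  fmt = "%" + str(wid) + "d"          # like "%3d"
  lines = []
  for i in range(dim):
    row = "actual   " + ("%3d:" % i)
    row += "".join(fmt % cm[i][j] for j in range(dim))
    lines.append(row)
  lines.append("predicted    " + "".join(fmt % j for j in range(dim)))
  return "\n".join(lines)

# Made-up 3-class counts, for illustration only
cm3 = [[10,  2,  0],
       [ 1, 13,  3],
       [ 0,  4, 11]]
print(confusion_to_str(cm3))
# actual     0: 10  2  0
# actual     1:  1 13  3
# actual     2:  0  4 11
# predicted      0  1  2
```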

Making a scikit confusion matrix less confusing: good fun.

See for the complete code.

Confused dogs matrix.

This entry was posted in Scikit.
