Computing and Displaying a Confusion Matrix for a PyTorch Neural Network Multi-Class Classifier

After training a PyTorch multi-class classifier, it’s important to evaluate the accuracy of the trained model. Simple classification accuracy is OK, but in many scenarios you want a so-called confusion matrix, which gives the number of correct and wrong predictions for each target class label.

For example, suppose you’re predicting the political leaning (0 = conservative, 1 = moderate, 2 = liberal) of a person based on their sex (-1 = male, 1 = female), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), and income (divided by $100,000). An example of a formatted confusion matrix might look like:

Computing model accuracy
Accuracy on training data = 0.8150
Accuracy on test data = 0.7500

Computing raw confusion matrix:
[[ 6  4  1]
 [ 1 13  0]
 [ 2  2 11]]

Formatted version:
actual     0:  6  4  1
actual     1:  1 13  0
actual     2:  2  2 11
------------
predicted      0  1  2

Here’s my function to compute a raw confusion matrix for a multi-class classifier:

def confusion_matrix_multi(model, ds, n_classes):
  """Compute the raw confusion matrix of a multi-class classifier.

  Each item ds[i] is read as a pair: ds[i][0] holds the predictors and
  is reshaped into a one-row batch for the model; ds[i][1] must hold a
  single integer class index.

  Args:
    model: callable that maps a (1, n_features) tensor to per-class
      scores; the predicted class is the argmax of its output.
    ds: indexable dataset supporting len(ds) and ds[i].
    n_classes: number of classes; must be 3 or greater.

  Returns:
    An (n_classes, n_classes) np.int64 array in which cm[a][p] counts
    the items whose actual class is a and predicted class is p, or
    None (after printing an error message) when n_classes is 2 or less.
  """
  if n_classes <= 2:
    print("ERROR: n_classes must be 3 or greater ")
    return None

  cm = np.zeros((n_classes, n_classes), dtype=np.int64)
  for i in range(len(ds)):
    item = ds[i]
    X = item[0].reshape(1, -1)  # make it a batch of one
    # reshape(1) insists on exactly one label value; convert to a plain
    # int so NumPy does ordinary integer indexing rather than relying
    # on implicit tensor-to-index coercion
    actual = int(item[1].reshape(1).item())
    with T.no_grad():  # evaluation only, no gradients needed
      oupt = model(X)
    predicted = int(T.argmax(oupt).item())
    cm[actual, predicted] += 1
  return cm

The function accepts a trained PyTorch classifier and a PyTorch Dataset object whose items are either a Tuple or a Dictionary, with the predictors at [0] and the target label at [1]. The n_classes value could be determined programmatically, but it’s easier to pass it in as a parameter.

Note: A function to compute a confusion matrix for a PyTorch binary classifier, where there are just two possible outcomes, uses slightly different code.

The raw confusion matrix is difficult to interpret so I wrote a function to format the matrix by adding some labels:

def show_confusion(cm):
  """Print a confusion matrix with row and column labels.

  Each row is printed as "actual   <i>:" followed by that row's
  counts; a footer line labels the columns "predicted" with the
  class indices.

  Args:
    cm: square matrix of counts, indexable as cm[i][j].
  """
  dim = len(cm)
  # A column must fit both the largest count and the largest class
  # index; sizing by the counts alone lets the footer labels run
  # together once there are 10 or more classes with small counts.
  digits = 1
  if dim > 0:  # np.max() raises on an empty matrix
    digits = max(len(str(np.max(cm))), len(str(dim - 1)))
  wid = digits + 1            # one leading space per column
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    cells = "".join(fmt % cm[i][j] for j in range(dim))
    print("actual   " + ("%3d:" % i) + cells)
  print("------------")
  print("predicted    " + "".join(fmt % j for j in range(dim)))

If you do have a confusion matrix, it’s possible to compute an overall accuracy metric from it. The total number of data items is the sum of the entries in the matrix. The number of correct predictions is the sum of the entries on the main diagonal. For example:

def accuracy_from_confusion_multi(cm):
  """Derive accuracy figures from a confusion matrix.

  Returns a tuple (overall, class_accs). overall is the sum of the
  main diagonal divided by the sum of all entries. class_accs is a
  list whose i-th entry is cm[i][i] divided by the sum of row i.
  """
  size = len(cm)
  grand_total = np.sum(cm)
  per_row_totals = cm.sum(axis=1)  # one total per actual class

  hits = [cm[k][k] for k in range(size)]  # correct predictions
  per_class = [hits[k] / per_row_totals[k] for k in range(size)]

  return (sum(hits) / grand_total, per_class)

The output for the demo data looks like:

Computing test accuracies from confusion matrix
Accuracy on test data = 0.7500
Class accuracies:
   0 = 0.5455
   1 = 0.9286
   2 = 0.7333

Good fun. Demo code below. The training and test data can be found at https://jamesmccaffrey.wordpress.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.



Easter is coming soon. I grew up Catholic so Easter is important to me. Here are three photos of child confusion caused by disturbing Easter Bunnies.


# people_politics.py
# predict politics type from sex, age, state, income
# PyTorch 1.13.1-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  """People records loaded from a tab-delimited text file.

  Each data line holds seven numeric columns:

    sex  age    state    income   politics
    -1   0.27   0  1  0   0.7610   2
    +1   0.19   0  0  1   0.6550   0

  sex: -1 = male, +1 = female
  state: michigan, nebraska, oklahoma (three columns)
  politics: conservative, moderate, liberal

  Columns [0,5] become the float32 predictors (x_data) and column 6
  becomes the int64 class label (y_data). Text after a "#" in the
  file is ignored as a comment.
  """

  def __init__(self, src_file):
    # ndmin=2 keeps the array 2-D even when the file has a single
    # data row; without it loadtxt returns a 1-D array and the
    # column slicing below raises IndexError.
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter="\t", comments="#", dtype=np.float32, ndmin=2)
    tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6]     # 1-D

    self.x_data = T.tensor(tmp_x,
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    """Return the number of rows loaded."""
    return len(self.x_data)

  def __getitem__(self, idx):
    """Return (predictors, target) for idx as a Tuple."""
    preds = self.x_data[idx]
    trgts = self.y_data[idx]
    return preds, trgts

# -----------------------------------------------------------

class Net(T.nn.Module):
  """Fully connected 6-(10-10)-3 classifier with tanh hidden layers.

  forward() returns log-probabilities (log_softmax over dim 1), the
  form expected by NLLLoss().
  """

  def __init__(self):
    super().__init__()
    self.hid1 = T.nn.Linear(6, 10)
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    # Xavier-uniform weights and zero biases, in layer order
    for layer in (self.hid1, self.hid2, self.oupt):
      T.nn.init.xavier_uniform_(layer.weight)
      T.nn.init.zeros_(layer.bias)

  def forward(self, x):
    first = T.tanh(self.hid1(x))
    second = T.tanh(self.hid2(first))
    return T.log_softmax(self.oupt(second), dim=1)

# -----------------------------------------------------------

def accuracy(model, ds):
  """Return the fraction of items in ds that model classifies correctly.

  Items are scored one at a time: ds[i][0] is reshaped into a one-row
  batch, and the argmax of the model output is compared with the
  label ds[i][1]. Assumes the caller has already put the model in
  eval() mode.
  """
  total = len(ds)
  hits = 0
  for idx in range(total):
    features = ds[idx][0].reshape(1, -1)  # batch of one
    label = ds[idx][1].reshape(1)         # 1-D, single class index
    with T.no_grad():
      scores = model(features)
    if T.argmax(scores) == label:
      hits += 1

  return (hits * 1.0) / total

# -----------------------------------------------------------

def accuracy_quick(model, dataset):
  """Return classification accuracy computed in one batched pass.

  The whole dataset is fetched with a slice, so dataset must support
  slice indexing that yields predictors at [0] and labels at [1].
  Assumes the caller has already put the model in eval() mode.
  """
  n_items = len(dataset)
  inputs = dataset[0:n_items][0]
  labels = dataset[0:n_items][1]
  with T.no_grad():
    scores = model(inputs)
  winners = T.argmax(scores, dim=1)  # predicted class per row
  n_hits = T.sum(labels == winners)
  fraction = n_hits * 1.0 / n_items  # still a tensor here
  return fraction.item()

# -----------------------------------------------------------

def confusion_matrix_multi(model, ds, n_classes):
  """Compute the raw confusion matrix of a multi-class classifier.

  Each item ds[i] is read as a pair: ds[i][0] holds the predictors and
  is reshaped into a one-row batch for the model; ds[i][1] must hold a
  single integer class index.

  Args:
    model: callable that maps a (1, n_features) tensor to per-class
      scores; the predicted class is the argmax of its output.
    ds: indexable dataset supporting len(ds) and ds[i].
    n_classes: number of classes; must be 3 or greater.

  Returns:
    An (n_classes, n_classes) np.int64 array in which cm[a][p] counts
    the items whose actual class is a and predicted class is p, or
    None (after printing an error message) when n_classes is 2 or less.
  """
  if n_classes <= 2:
    print("ERROR: n_classes must be 3 or greater ")
    return None

  cm = np.zeros((n_classes, n_classes), dtype=np.int64)
  for i in range(len(ds)):
    item = ds[i]
    X = item[0].reshape(1, -1)  # make it a batch of one
    # reshape(1) insists on exactly one label value; convert to a plain
    # int so NumPy does ordinary integer indexing rather than relying
    # on implicit tensor-to-index coercion
    actual = int(item[1].reshape(1).item())
    with T.no_grad():  # evaluation only, no gradients needed
      oupt = model(X)
    predicted = int(T.argmax(oupt).item())
    cm[actual, predicted] += 1
  return cm

# -----------------------------------------------------------

def show_confusion(cm):
  """Print a confusion matrix with row and column labels.

  Each row is printed as "actual   <i>:" followed by that row's
  counts; a footer line labels the columns "predicted" with the
  class indices.

  Args:
    cm: square matrix of counts, indexable as cm[i][j].
  """
  dim = len(cm)
  # A column must fit both the largest count and the largest class
  # index; sizing by the counts alone lets the footer labels run
  # together once there are 10 or more classes with small counts.
  digits = 1
  if dim > 0:  # np.max() raises on an empty matrix
    digits = max(len(str(np.max(cm))), len(str(dim - 1)))
  wid = digits + 1            # one leading space per column
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    cells = "".join(fmt % cm[i][j] for j in range(dim))
    print("actual   " + ("%3d:" % i) + cells)
  print("------------")
  print("predicted    " + "".join(fmt % j for j in range(dim)))

# -----------------------------------------------------------  

def accuracy_from_confusion_multi(cm):
  """Derive accuracy figures from a confusion matrix.

  Returns a tuple (overall, class_accs). overall is the sum of the
  main diagonal divided by the sum of all entries. class_accs is a
  list whose i-th entry is cm[i][i] divided by the sum of row i.
  """
  size = len(cm)
  grand_total = np.sum(cm)
  per_row_totals = cm.sum(axis=1)  # one total per actual class

  hits = [cm[k][k] for k in range(size)]  # correct predictions
  per_class = [hits[k] / per_row_totals[k] for k in range(size)]

  return (sum(hits) / grand_total, per_class)

# -----------------------------------------------------------  

def main():
  """Run the demo: load data, train, evaluate, predict.

  Reads people_train.txt and people_test.txt from a Data directory
  under the current working directory, trains a Net with SGD and
  NLLLoss, prints accuracy, the confusion matrix and the accuracies
  derived from it, then prints class probabilities for one
  hand-built input. All results go to stdout; nothing is returned.
  """
  import os  # local import: only needed to build the data file paths

  # 0. get started
  print("\nBegin People predict politics type ")
  T.manual_seed(1)
  np.random.seed(1)

  # 1. create DataLoader objects
  print("\nCreating People Datasets ")

  # os.path.join gives ".\Data\..." on Windows (same as before) and
  # "./Data/..." elsewhere, so the demo is not tied to Windows
  train_file = os.path.join(".", "Data", "people_train.txt")
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = os.path.join(".", "Data", "people_test.txt")
  test_ds = PeopleDataset(test_file)    # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create and compile network
  print("\nCreating 6-(10-10)-3 neural network ")
  net = Net().to(device)
  net.train()

  # 3. train model
  max_epochs = 1000
  ep_log_interval = 200
  lrn_rate = 0.01

  loss_func = T.nn.NLLLoss()  # assumes log_softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training ")
  for epoch in range(0, max_epochs):
    # T.manual_seed(epoch+1)  # checkpoint reproducibility
    epoch_loss = 0  # for one full epoch

    for batch in train_ldr:
      X = batch[0]  # inputs
      Y = batch[1]  # correct class/label/politics

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %5d  |  loss = %10.4f" %
        (epoch, epoch_loss))

  print("Training done ")

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds)
  print("Accuracy on test data = %0.4f" % acc_test)

  # 4b. confusion matrix
  print("\nComputing confusion matrix ")
  cm = confusion_matrix_multi(net, test_ds, n_classes=3)
  # print(cm)  # raw matrix
  print("Formatted version: \n")
  show_confusion(cm)

  # 4c. accuracy metrics from confusion
  print("\nComputing test accuracies from confusion matrix ")
  (test_acc, class_accs) = accuracy_from_confusion_multi(cm)
  print("Accuracy on test data = %0.4f" % test_acc)
  print("Class accuracies: ")
  for i in range(len(cm)):
    print("%4d = %0.4f " % (i, class_accs[i]))

  # 5. make a prediction
  print("\nPredicting politics for M  30  oklahoma  $50,000: ")
  X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device)

  with T.no_grad():
    log_probs = net(X)  # log_softmax output, does not sum to 1.0
  probs = T.exp(log_probs)  # sum to 1.0
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True)
  print(probs)

  # 6. save model (state_dict approach)
  # NOTE: the save/load calls below are disabled; only the message
  # is printed
  print("\nSaving trained model state ")
  # fn = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), fn)

  # model = Net()  # requires class definition
  # model.load_state_dict(T.load(fn))
  # use model to make prediction(s)

  print("\nEnd People predict politics demo ")

# run the demo only when this file is executed as a script, not on import
if __name__ == "__main__":
  main()
This entry was posted in PyTorch. Bookmark the permalink.

Leave a Reply

Please log in using one of these methods to post your comment:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out /  Change )

Twitter picture

You are commenting using your Twitter account. Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. Log Out /  Change )

Connecting to %s