PyTorch Multi-Class Accuracy by Class

I was presenting an all-day PyTorch workshop and one of my examples was multi-class classification. The goal was to predict the job type of an employee (0 = mgmt, 1 = supp, 2 = tech) from sex, age, city (anaheim, boulder, concord), and income.

My demo had a program-defined accuracy() function to compute the overall classification accuracy of the trained model. The result was 81.50% accuracy on the training data.

One of the attendees in the workshop, Alex, pointed out that in a non-demo scenario, you should compute accuracy for each of the three job types. This lets you spot situations where your model does well on most classes but fails badly on one of them.

I coded up a demo. I defined an accuracy() function as:

import numpy as np
import torch as T

def accuracy(model, ds, num_classes):
  # assumes model.eval() has been called
  # by class, item-by-item version
  counts = np.zeros((num_classes,2), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1]  # 0, 1 or 2
    with T.no_grad():
      oupt = model(X)  # logits form

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y.item():
      counts[Y.item()][0] += 1  # correct
    else:
      counts[Y.item()][1] += 1  # wrong

  pcts = np.zeros(num_classes, dtype=np.float32)
  for c in range(num_classes):
    pcts[c] = counts[c][0] * 1.0 / (counts[c][0] + counts[c][1])

  num_correct = 0; num_wrong = 0
  for c in range(num_classes):
    num_correct += counts[c][0]
    num_wrong += counts[c][1]
  overall = num_correct * 1.0 / (num_correct + num_wrong)

  return counts, pcts, overall
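To show how the function is called, here is a minimal sketch that exercises it with an untrained stand-in model and a synthetic TensorDataset. Everything in it (the layer sizes, the random data, the condensed copy of accuracy()) is invented for illustration and is not the workshop demo code:

```python
import numpy as np
import torch as T
from torch.utils.data import TensorDataset

# condensed version of accuracy(), included so this sketch runs standalone
def accuracy(model, ds, num_classes):
  counts = np.zeros((num_classes, 2), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1, -1)      # make item a batch of 1
    Y = ds[i][1]                     # true class label
    with T.no_grad():
      oupt = model(X)                # logits
    big_idx = T.argmax(oupt)         # predicted class
    if big_idx == Y.item():
      counts[Y.item()][0] += 1       # correct
    else:
      counts[Y.item()][1] += 1       # wrong
  pcts = counts[:, 0] / counts.sum(axis=1)       # per-class accuracy
  overall = counts[:, 0].sum() / counts.sum()    # overall accuracy
  return counts, pcts, overall

T.manual_seed(1)
xs = T.randn(9, 6)                   # 9 items, 6 features each (made up)
ys = T.tensor([0, 1, 2] * 3)         # 3 items per class
ds = TensorDataset(xs, ys)

model = T.nn.Linear(6, 3)            # untrained stand-in network
model.eval()

counts, pcts, overall = accuracy(model, ds, 3)
print(counts)    # 3x2 correct / wrong counts
print(pcts)      # per-class accuracies
print(overall)   # overall accuracy
```

The output is meaningless here because the model is untrained, but the shapes of the three return values match those in the demo below.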

The demo code iterates through data one item at a time. This approach is slow but allows you to diagnose problems. A much faster approach is to send all inputs to the model, then fetch all outputs, and then analyze for correct / wrong. The code would look something like:

X = ds[0:len(ds)][0]             # all inputs
Y = T.flatten(ds[0:len(ds)][1])  # all targets

with T.no_grad():
  oupt = model(X)
arg_maxs = T.argmax(oupt, dim=1) # all predicteds
. . . 
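One way the elided analysis step could be completed is with vectorized per-class counting. This is a sketch, not the workshop code; it assumes that slicing ds with ds[0:len(ds)] returns an (all-inputs, all-targets) tuple of tensors, which is true for a standard TensorDataset:

```python
import torch as T
from torch.utils.data import TensorDataset

def accuracy_quick(model, ds, num_classes):
  # all-at-once version; assumes model.eval() has been called
  X = ds[0:len(ds)][0]              # all inputs
  Y = T.flatten(ds[0:len(ds)][1])   # all targets
  with T.no_grad():
    oupt = model(X)                 # logits for every item
  arg_maxs = T.argmax(oupt, dim=1)  # predicted class per item

  counts = T.zeros((num_classes, 2), dtype=T.int64)
  for c in range(num_classes):
    mask = (Y == c)                           # items whose true class is c
    n_correct = T.sum(arg_maxs[mask] == c)
    counts[c][0] = n_correct                  # correct for class c
    counts[c][1] = T.sum(mask) - n_correct    # wrong for class c

  pcts = counts[:, 0].float() / counts.sum(dim=1).float()
  overall = counts[:, 0].sum().item() / counts.sum().item()
  return counts.numpy(), pcts.numpy(), overall

# tiny made-up demo data: 8 items, 4 features, 3 classes
T.manual_seed(1)
xs = T.randn(8, 4)
ys = T.tensor([0, 1, 2, 0, 1, 2, 0, 1])
ds = TensorDataset(xs, ys)
model = T.nn.Linear(4, 3)   # untrained stand-in for a trained network
model.eval()
counts, pcts, overall = accuracy_quick(model, ds, 3)
```

The only loop left is over the handful of classes, so this version avoids the per-item Python overhead of the diagnostic version.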

My program-defined accuracy() function returns three values for the demo data:

[[46 15]
 [77 11]
 [40 11]]

[0.7541  0.8750  0.7843]

0.8150
The first value is the counts of correct and wrong predictions. So, for class 0 = mgmt, there were 46 correct predictions and 15 wrong predictions. For class 1 = supp, there were 77 correct and 11 wrong. And so on.

The second value is the percent accuracy by class. So, for class 0 the accuracy was 75.41%, and 87.50% accuracy for class 1, and 78.43% accuracy for class 2. The fact that the accuracy metrics were similar for all classes is good.

The third return value is the overall result of the model across all classes: (46 + 77 + 40) correct out of 200 items, which is 81.50% accuracy.
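As a quick sanity check, the three return values are consistent with one another; the per-class and overall percentages follow directly from the counts:

```python
# counts of (correct, wrong) per class, from the demo output
counts = [[46, 15], [77, 11], [40, 11]]

per_class = [c / (c + w) for (c, w) in counts]
print([round(p, 4) for p in per_class])   # [0.7541, 0.875, 0.7843]

overall = sum(c for (c, w) in counts) / sum(c + w for (c, w) in counts)
print(overall)                            # 0.815
```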

Good fun.

This entry was posted in PyTorch.
