PyTorch Multi-Class Accuracy By Class Using a Set-Wise Approach

I recently revisited multi-class classification using PyTorch. My demo predicts a person’s political type (conservative, moderate, liberal) from sex, age, state (Michigan, Nebraska, Oklahoma), and annual income. See jamesmccaffrey.wordpress.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.

My demo computed overall model accuracy, so I decided to add a function that computes accuracy by class. I usually compute accuracy by class with a simple item-by-item iteration. For example, see https://jamesmccaffrey.wordpress.com/2022/07/12/pytorch-multi-class-accuracy-by-class/. This time I wanted a set-wise approach that processes all items at once instead of looping over them one at a time.
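For contrast, here is a minimal sketch of the item-by-item style. It assumes the predicted and true class labels are already available as plain Python lists of ints. The function name acc_by_class_iter is hypothetical and is not the code from the earlier post.

```python
def acc_by_class_iter(y_true, y_pred, n_classes):
  # hypothetical helper: tally correct and total counts one item at a time
  n_correct = [0] * n_classes
  n_total = [0] * n_classes
  for t, p in zip(y_true, y_pred):
    n_total[t] += 1
    if p == t:
      n_correct[t] += 1

  # accuracy for class c = fraction of items with true class c predicted as c
  result = []
  for c in range(n_classes):
    if n_total[c] == 0:
      result.append(None)  # no items of class c, accuracy undefined
    else:
      result.append(n_correct[c] / n_total[c])
  return result

print(acc_by_class_iter([0, 0, 1, 2, 2, 2], [0, 1, 1, 2, 0, 2], 3))
# prints [0.5, 1.0, 0.6666666666666666]
```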


Note: This blog post essentially computes a subset of a confusion matrix. A slightly more versatile approach is to just go ahead and do the entire confusion matrix. See the post at https://jamesmccaffrey.wordpress.com/2023/03/15/computing-and-displaying-a-confusion-matrix-for-a-pytorch-neural-network-classifier/.
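Here is a minimal sketch of computing a confusion matrix set-wise. It assumes the true and predicted labels are 1-D int64 tensors, and the function name confusion_matrix is hypothetical; the linked post may build its matrix differently. Each (true, predicted) pair is mapped to a single flat cell index, and T.bincount counts every cell in one call. With rows indexed by true class, per-class accuracy is the diagonal divided by the row sums.

```python
import torch as T

def confusion_matrix(y_true, y_pred, n_classes):
  # hypothetical helper: flat index = true * n_classes + predicted
  idx = y_true * n_classes + y_pred
  counts = T.bincount(idx, minlength=n_classes * n_classes)
  return counts.reshape(n_classes, n_classes)  # rows = true, cols = predicted

y_true = T.tensor([0, 0, 1, 2, 2, 2])  # made-up labels for illustration
y_pred = T.tensor([0, 1, 1, 2, 0, 2])
cm = confusion_matrix(y_true, y_pred, 3)
print(cm)

# per-class accuracy: correct predictions (diagonal) / items per true class
print(cm.diag() / cm.sum(dim=1))
```

A class with no items has a row sum of zero, so its accuracy comes out as nan rather than raising an error.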


Here’s the result:

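import numpy as np
import torch as T

def do_acc(model, dataset, n_classes):
  # assumes dataset slicing returns an (X, Y) tuple of CPU tensors
  # and that Y is a 1-D tensor of int64 class labels
  X = dataset[0:len(dataset)][0]  # all X values
  Y = dataset[0:len(dataset)][1]  # all Y values
  model.eval()  # disable dropout and batch norm updates
  with T.no_grad():
    oupt = model(X)  # all logits, shape [n_items, n_classes]

  for c in range(n_classes):
    idxs = np.where(Y==c)  # indices where Y is c
    logits_c = oupt[idxs]  # logits corresponding to Y == c
    arg_maxs_c = T.argmax(logits_c, dim=1)  # predicted class
    num_correct = T.sum(arg_maxs_c == c)
    n_c = len(arg_maxs_c)  # number of items whose true class is c
    if n_c == 0:  # avoid division by zero for an absent class
      print("class %d: no items" % c)
      continue
    acc_c = num_correct.item() / n_c
    print("class %d: %0.4f" % (c, acc_c))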
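The loop over classes can also be removed. Here is a minimal sketch, assuming oupt holds the logits and Y is a 1-D int64 tensor of true labels. The function name acc_by_class_vec and the sample logits are hypothetical. A weighted T.bincount counts the correct predictions per true class, and an unweighted T.bincount counts the items per class.

```python
import torch as T

def acc_by_class_vec(oupt, Y, n_classes):
  # oupt: logits, shape [n_items, n_classes]; Y: 1-D int64 true labels
  pred = T.argmax(oupt, dim=1)
  correct = (pred == Y).to(T.float32)  # 1.0 where the prediction is right
  n_correct = T.bincount(Y, weights=correct, minlength=n_classes)
  n_total = T.bincount(Y, minlength=n_classes)
  return n_correct / n_total  # nan for a class with no items

oupt = T.tensor([[2.0, 0.1, 0.3],   # made-up logits for illustration
                 [0.2, 1.5, 0.1],
                 [0.1, 3.0, 0.2],
                 [0.0, 0.1, 2.2],
                 [1.9, 0.3, 0.4],
                 [0.3, 0.2, 1.1]])
Y = T.tensor([0, 0, 1, 2, 2, 2])
print(acc_by_class_vec(oupt, Y, 3))
```

The trade-off is readability. The per-class loop in do_acc makes the logic easy to follow, while this version does everything in a few tensor operations.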

Writing the function took me a bit longer than I had expected. The coding part of my brain thinks iteratively rather than set-wise. This is why I’m most comfortable with languages like C# and standard Python, and less comfortable with SQL and things like Python list comprehensions.

Good fun.



Three books where it’s difficult to classify the accuracy of the title without more information. Left: “It Must’ve Been the Fish Sticks”. Center: “How to Talk to Your Cat About Gun Safety”. Right: “Mommy Drinks Because You’re Bad”.


This entry was posted in PyTorch.