After training a PyTorch multi-class classifier, it’s important to evaluate the accuracy of the trained model. Simple classification accuracy is OK but in many scenarios you want a so-called confusion matrix that gives details of the number of correct and wrong predictions for each target class label.
For example, suppose you’re predicting the political leaning (0 = conservative, 1 = moderate, 2 = liberal) of a person based on their sex (-1 = male, 1 = female), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), and income (divided by $100,000). An example of a formatted confusion matrix might look like:
Computing model accuracy
Accuracy on training data = 0.8150
Accuracy on test data = 0.7500

Computing raw confusion matrix:
[[ 6  4  1]
 [ 1 13  0]
 [ 2  2 11]]
Formatted version:
actual   0:  6  4  1
actual   1:  1 13  0
actual   2:  2  2 11
------------
predicted   0  1  2
Here’s my function to compute a raw confusion matrix for a multi-class classifier:
def confusion_matrix_multi(model, ds, n_classes):
    """Compute the raw confusion matrix of a multi-class classifier.

    Entry cm[a, p] counts the items whose actual class is a and whose
    predicted class is p.  The predicted class is the index of the
    largest model output.

    Args:
      model: trained classifier.  It is called on a batch holding a
        single item, of shape (1, n_features).  Presumably it should
        already be in eval mode -- confirm at the call site.
      ds: indexable dataset.  ds[i][0] is the predictor tensor and
        ds[i][1] is the integer class label (a one-element tensor or
        an int).
      n_classes: number of classes; must be 3 or greater.

    Returns:
      An (n_classes, n_classes) np.int64 array.  Returns None (after
      printing an error) when n_classes is 2 or less.
    """
    if n_classes <= 2:  # binary classifiers need different code
        print("ERROR: n_classes must be 3 or greater ")
        return None
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for i in range(len(ds)):
        X = ds[i][0].reshape(1, -1)  # make it a batch of one
        actual = int(ds[i][1])       # actual class 0, 1, 2, ...
        with T.no_grad():
            oupt = model(X)
        pred_class = int(T.argmax(oupt))  # index of the largest output
        # Plain ints give a single, unambiguous (row, col) element index.
        cm[actual, pred_class] += 1
    return cm
The function accepts a trained PyTorch classifier and a PyTorch Dataset object whose items are either tuples or dictionaries, with the predictors at [0] and the target label at [1]. The number of classes could be determined programmatically, but it's easier to pass that value in as the n_classes parameter.
Note: A function to compute a confusion matrix for a PyTorch binary classifier, where there are just two possible outcomes, uses slightly different code.
The raw confusion matrix is difficult to interpret so I wrote a function to format the matrix by adding some labels:
def show_confusion(cm):
    """Print a confusion matrix with row and column labels.

    Each row is labeled "actual k:".  A dashed rule and a
    "predicted" footer carrying the column indices follow the rows.
    Every cell is one character wider than the largest count, so the
    columns line up.
    """
    size = len(cm)
    cell_fmt = "%" + str(len(str(np.max(cm))) + 1) + "d"  # e.g. "%3d"
    for row in range(size):
        cells = "".join(cell_fmt % cm[row][col] for col in range(size))
        print("actual " + ("%3d:" % row) + cells)
    print("------------")
    print("predicted " + "".join(cell_fmt % col for col in range(size)))
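If you already have a confusion matrix, you can compute an overall accuracy metric from it. The total number of data items is the sum of all the entries in the matrix, and the number of correct predictions is the sum of the entries on the main diagonal. The accuracy for each individual class is that class's diagonal entry divided by the sum of its row, which is the number of items that actually belong to that class. For example: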
def accuracy_from_confusion_multi(cm):
    """Return (overall accuracy, list of per-class accuracies).

    cm is a square NumPy array whose rows are actual classes and whose
    columns are predicted classes, as produced by
    confusion_matrix_multi().

    Overall accuracy is the diagonal total (correct predictions)
    divided by the grand total.  The accuracy for class k is cm[k][k]
    divided by the sum of row k.  A class with an all-zero row yields
    nan.
    """
    hits = np.diag(cm)                 # correct predictions per class
    per_class_total = cm.sum(axis=1)   # row sums: actual items per class
    overall = np.trace(cm) / np.sum(cm)
    return (overall, list(hits / per_class_total))
The output for the demo data looks like:
Computing test accuracies from confusion matrix
Accuracy on test data = 0.7500
Class accuracies:
   0 = 0.5455
   1 = 0.9286
   2 = 0.7333
Good fun. Demo code below. The training and test data can be found at https://jamesmccaffrey.wordpress.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.
Easter is coming soon. I grew up Catholic so Easter is important to me. Here are three photos of child confusion caused by disturbing Easter Bunnies.
# people_politics.py
# predict politics type from sex, age, state, income
# PyTorch 1.13.1-CPU Anaconda3-2022.10 Python 3.9.13
# Windows 10/11

import numpy as np
import torch as T

device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------


class PeopleDataset(T.utils.data.Dataset):
    """People data loaded from a tab-delimited text file.

    Each row holds six predictor values followed by one politics label:

        sex  age   state   income  politics
        -1   0.27  0 1 0   0.7610  2
        +1   0.19  0 0 1   0.6550  0

    sex: -1 = male, +1 = female
    state: michigan, nebraska, oklahoma (three 0/1 columns)
    politics: conservative, moderate, liberal (labels 0, 1, 2)

    Indexing returns a (predictors, label) tuple.  predictors is a
    float32 tensor of 6 values and label is an int64 scalar tensor.
    Slicing returns the matching batch tensors.
    """

    def __init__(self, src_file):
        all_xy = np.loadtxt(src_file, usecols=range(0, 7),
                            delimiter="\t", comments="#",
                            dtype=np.float32)
        tmp_x = all_xy[:, 0:6]  # cols [0,6) = [0,5]
        tmp_y = all_xy[:, 6]    # 1-D
        self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
        self.y_data = T.tensor(tmp_y, dtype=T.int64).to(device)  # 1-D

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        preds = self.x_data[idx]
        trgts = self.y_data[idx]
        return preds, trgts  # as a Tuple

# -----------------------------------------------------------


class Net(T.nn.Module):
    """6-(10-10)-3 tanh network whose forward() returns log-softmax values."""

    def __init__(self):
        super().__init__()
        self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
        self.hid2 = T.nn.Linear(10, 10)
        self.oupt = T.nn.Linear(10, 3)

        T.nn.init.xavier_uniform_(self.hid1.weight)
        T.nn.init.zeros_(self.hid1.bias)
        T.nn.init.xavier_uniform_(self.hid2.weight)
        T.nn.init.zeros_(self.hid2.bias)
        T.nn.init.xavier_uniform_(self.oupt.weight)
        T.nn.init.zeros_(self.oupt.bias)

    def forward(self, x):
        z = T.tanh(self.hid1(x))
        z = T.tanh(self.hid2(z))
        # log-probabilities, to pair with NLLLoss() during training
        z = T.log_softmax(self.oupt(z), dim=1)
        return z

# -----------------------------------------------------------


def accuracy(model, ds):
    """Return classification accuracy over ds, one item at a time.

    Assumes model.eval() has already been called.  ds[i] must be a
    (predictors, label) pair.
    """
    n_correct = 0
    n_wrong = 0
    for i in range(len(ds)):
        X = ds[i][0].reshape(1, -1)  # make it a batch
        Y = ds[i][1].reshape(1)      # 0 1 or 2, 1D
        with T.no_grad():
            oupt = model(X)  # log-probabilities
        big_idx = T.argmax(oupt)  # 0 or 1 or 2
        if big_idx == Y:
            n_correct += 1
        else:
            n_wrong += 1
    acc = (n_correct * 1.0) / (n_correct + n_wrong)
    return acc

# -----------------------------------------------------------


def accuracy_quick(model, dataset):
    """Return accuracy over dataset using one batched forward pass.

    Assumes model.eval() has already been called, and that slicing the
    dataset returns (all predictors, all labels), as PeopleDataset does.
    """
    X = dataset[0:len(dataset)][0]
    Y = dataset[0:len(dataset)][1]
    with T.no_grad():
        oupt = model(X)
    arg_maxs = T.argmax(oupt, dim=1)  # predicted class per row
    num_correct = T.sum(Y == arg_maxs)
    acc = (num_correct * 1.0 / len(dataset))
    return acc.item()

# -----------------------------------------------------------


def confusion_matrix_multi(model, ds, n_classes):
    """Compute the raw confusion matrix of a multi-class classifier.

    Entry cm[a, p] counts the items whose actual class is a and whose
    predicted class is p (the index of the largest model output).

    Returns an (n_classes, n_classes) np.int64 array.  Returns None
    (after printing an error) when n_classes is 2 or less.
    """
    if n_classes <= 2:  # binary classifiers need different code
        print("ERROR: n_classes must be 3 or greater ")
        return None
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for i in range(len(ds)):
        X = ds[i][0].reshape(1, -1)  # make it a batch of one
        actual = int(ds[i][1])       # actual class 0, 1, 2, ...
        with T.no_grad():
            oupt = model(X)
        pred_class = int(T.argmax(oupt))  # index of the largest output
        # Plain ints give a single, unambiguous (row, col) element index.
        cm[actual, pred_class] += 1
    return cm

# -----------------------------------------------------------


def show_confusion(cm):
    """Print cm with "actual k:" row labels and a "predicted" footer."""
    dim = len(cm)
    mx = np.max(cm)  # largest count in cm
    wid = len(str(mx)) + 1  # width to print
    fmt = "%" + str(wid) + "d"  # like "%3d"
    for i in range(dim):
        print("actual ", end="")
        print("%3d:" % i, end="")
        for j in range(dim):
            print(fmt % cm[i][j], end="")
        print("")
    print("------------")
    print("predicted ", end="")
    for j in range(dim):
        print(fmt % j, end="")
    print("")

# -----------------------------------------------------------


def accuracy_from_confusion_multi(cm):
    """Return (overall accuracy, list of per-class accuracies).

    Rows of cm are actual classes and columns are predicted classes.
    The accuracy for class k is cm[k][k] divided by the sum of row k.
    """
    N = np.sum(cm)  # total count
    dim = len(cm)
    row_sums = cm.sum(axis=1)  # actual items per class
    n_correct = 0
    for i in range(dim):
        n_correct += cm[i][i]  # on the diagonal
    overall = n_correct / N
    class_accs = []
    for i in range(dim):
        class_accs.append(cm[i][i] / row_sums[i])
    return (overall, class_accs)

# -----------------------------------------------------------


def main():
    # 0. get started
    print("\nBegin People predict politics type ")
    T.manual_seed(1)
    np.random.seed(1)

    # 1. create Dataset and DataLoader objects
    print("\nCreating People Datasets ")

    train_file = ".\\Data\\people_train.txt"
    train_ds = PeopleDataset(train_file)  # 200 rows

    test_file = ".\\Data\\people_test.txt"
    test_ds = PeopleDataset(test_file)  # 40 rows

    bat_size = 10
    train_ldr = T.utils.data.DataLoader(train_ds,
                                        batch_size=bat_size, shuffle=True)

    # -----------------------------------------------------------

    # 2. create network
    print("\nCreating 6-(10-10)-3 neural network ")
    net = Net().to(device)
    net.train()

    # -----------------------------------------------------------

    # 3. train model
    max_epochs = 1000
    ep_log_interval = 200
    lrn_rate = 0.01

    loss_func = T.nn.NLLLoss()  # assumes log_softmax()
    optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

    print("\nbat_size = %3d " % bat_size)
    print("loss = " + str(loss_func))
    print("optimizer = SGD")
    print("max_epochs = %3d " % max_epochs)
    print("lrn_rate = %0.3f " % lrn_rate)

    print("\nStarting training ")
    for epoch in range(0, max_epochs):
        # T.manual_seed(epoch+1)  # checkpoint reproducibility
        epoch_loss = 0  # for one full epoch

        for batch in train_ldr:
            X = batch[0]  # inputs
            Y = batch[1]  # correct class/label/politics

            optimizer.zero_grad()
            oupt = net(X)
            loss_val = loss_func(oupt, Y)  # a tensor
            epoch_loss += loss_val.item()  # accumulate
            loss_val.backward()
            optimizer.step()

        if epoch % ep_log_interval == 0:
            print("epoch = %5d | loss = %10.4f" % (epoch, epoch_loss))

    print("Training done ")

    # -----------------------------------------------------------

    # 4. evaluate model accuracy
    print("\nComputing model accuracy")
    net.eval()
    acc_train = accuracy(net, train_ds)  # item-by-item
    print("Accuracy on training data = %0.4f" % acc_train)
    acc_test = accuracy(net, test_ds)
    print("Accuracy on test data = %0.4f" % acc_test)

    # 4b. confusion matrix
    print("\nComputing confusion matrix ")
    cm = confusion_matrix_multi(net, test_ds, n_classes=3)
    # print(cm)  # raw matrix
    print("Formatted version: \n")
    show_confusion(cm)

    # 4c. accuracy metrics from confusion matrix
    print("\nComputing test accuracies from confusion matrix ")
    (test_acc, class_accs) = accuracy_from_confusion_multi(cm)
    print("Accuracy on test data = %0.4f" % test_acc)
    print("Class accuracies: ")
    for i, class_acc in enumerate(class_accs):
        print("%4d = %0.4f " % (i, class_acc))

    # 5. make a prediction
    print("\nPredicting politics for M 30 oklahoma $50,000: ")
    X = np.array([[-1, 0.30, 0, 0, 1, 0.5000]], dtype=np.float32)
    X = T.tensor(X, dtype=T.float32).to(device)

    with T.no_grad():
        log_probs = net(X)  # log-softmax values; do not sum to 1.0
    probs = T.exp(log_probs)  # sum to 1.0
    probs = probs.cpu().numpy()  # numpy vector prints better
    np.set_printoptions(precision=4, suppress=True)
    print(probs)

    # 6. save model (state_dict approach)
    print("\nSaving trained model state ")
    # NOTE: saving is disabled in this demo; uncomment to save and reload.
    # fn = ".\\Models\\people_model.pt"
    # T.save(net.state_dict(), fn)

    # model = Net()  # requires class definition
    # model.load_state_dict(T.load(fn))
    # use model to make prediction(s)

    print("\nEnd People predict politics demo ")


if __name__ == "__main__":
    main()
You must be logged in to post a comment.