## Computing and Displaying a Confusion Matrix for a PyTorch Neural Network Multi-Class Classifier

After training a PyTorch multi-class classifier, it’s important to evaluate the accuracy of the trained model. Simple classification accuracy is OK but in many scenarios you want a so-called confusion matrix that gives details of the number of correct and wrong predictions for each target class label. For example, suppose you’re predicting the political leaning (0 = conservative, 1 = moderate, 2 = liberal) of a person based on their sex (-1 = male, 1 = female), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), and income (divided by \$100,000). An example of a formatted confusion matrix might look like:

```text
Computing model accuracy
Accuracy on training data = 0.8150
Accuracy on test data = 0.7500

Computing raw confusion matrix:
[[ 6  4  1]
 [ 1 13  0]
 [ 2  2 11]]

Formatted version:
actual     0:  6  4  1
actual     1:  1 13  0
actual     2:  2  2 11
------------
predicted      0  1  2
```

Here’s my function to compute a raw confusion matrix for a multi-class classifier:

```def confusion_matrix_multi(model, ds, n_classes):
if n_classes "lte" = 2:  # less-than-or-equal
print("ERROR: n_classes must be 3 or greater ")
return None

cm = np.zeros((n_classes,n_classes), dtype=np.int64)
for i in range(len(ds)):
X = ds[i].reshape(1,-1)  # make it a batch
Y = ds[i].reshape(1)  # actual class 0 1 or 2, 1D
oupt = model(X)  # logits form
pred_class = T.argmax(oupt)  # 0,1,2
cm[Y][pred_class] += 1
return cm
```

The function accepts a trained PyTorch classifier and a PyTorch Dataset object whose items are indexable (for example, a Tuple or a Dictionary), where the predictors are at `[0]` and the target labels are at `[1]`. The n_classes could be determined programmatically, but it’s easier to pass that value in as a parameter.

Note: A function to compute a confusion matrix for a PyTorch binary classifier, where there are just two possible outcomes, uses slightly different code.

The raw confusion matrix is difficult to interpret so I wrote a function to format the matrix by adding some labels:

```def show_confusion(cm):
dim = len(cm)
mx = np.max(cm)             # largest count in cm
wid = len(str(mx)) + 1      # width to print
fmt = "%" + str(wid) + "d"  # like "%3d"
for i in range(dim):
print("actual   ", end="")
print("%3d:" % i, end="")
for j in range(dim):
print(fmt % cm[i][j], end="")
print("")
print("------------")
print("predicted    ", end="")
for j in range(dim):
print(fmt % j, end="")
print("")
```

If you do have a confusion matrix, it’s possible to compute an overall accuracy metric from it. The total number of data items is the sum of the entries in the matrix. The number of correct predictions is the sum of the entries on the main diagonal. For example:

```def accuracy_from_confusion_multi(cm):
# return (overall accuracy, list of class accuracies)
N = np.sum(cm)  # total count
dim = len(cm)
row_sums = cm.sum(axis=1)  # collapse on cols, process rows

n_correct = 0
for i in range(dim):
n_correct += cm[i][i]  # on the diagonal
overall = n_correct / N

class_accs = []
for i in range(dim):
class_accs.append(cm[i][i] / row_sums[i])

return (overall, class_accs)
```

The output for the demo data looks like:

```text
Computing test accuracies from confusion matrix
Accuracy on test data = 0.7500
Class accuracies:
   0 = 0.5455
   1 = 0.9286
   2 = 0.7333
```

Good fun. Demo code below. The training and test data can be found at https://jamesmccaffrey.wordpress.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/. Easter is coming soon. I grew up Catholic so Easter is important to me. Here are three photos of child confusion caused by disturbing Easter Bunnies.

```python
# people_politics.py
# predict politics type from sex, age, state, income
# PyTorch 1.13.1-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11

import numpy as np
import torch as T
device = T.device("cpu")  # target device; tensors and the model are moved here via .to(device)

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
    """People data: predict politics type from sex, age, state, income.

    Each data line of src_file holds seven numbers:

        sex  age    state    income   politics
        -1   0.27   0  1  0   0.7610   2
        +1   0.19   0  0  1   0.6550   0

    sex: -1 = male, +1 = female
    state: one-hot (Michigan = 100, Nebraska = 010, Oklahoma = 001)
    politics: 0 = conservative, 1 = moderate, 2 = liberal

    Columns 0-5 become float32 predictors and column 6 becomes the int64
    class label. Both tensors are moved to the module-level ``device``.
    Lines starting with '#' are skipped.
    """

    def __init__(self, src_file):
        # Load all seven columns. delimiter=None splits on any whitespace,
        # so tab- and space-separated files both parse. ndmin=2 keeps a
        # single-row file 2-D so the column slicing below still works.
        all_xy = np.loadtxt(src_file, usecols=range(0, 7), delimiter=None,
                            comments="#", dtype=np.float32, ndmin=2)
        tmp_x = all_xy[:, 0:6]   # cols [0,6) = [0,5]
        tmp_y = all_xy[:, 6]     # 1-D

        self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
        self.y_data = T.tensor(tmp_y, dtype=T.int64).to(device)  # 1-D

    def __len__(self):
        """Return the number of data rows."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return (predictors, target) for idx.

        Because idx indexes the tensors directly, a slice returns batched
        tensors.
        """
        preds = self.x_data[idx]
        trgts = self.y_data[idx]
        return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class Net(T.nn.Module):
    """6-(10-10)-3 feed-forward classifier.

    Two tanh hidden layers of 10 units each. The output layer returns
    log-probabilities (log_softmax over dim 1), so it pairs with NLLLoss.
    """

    def __init__(self):
        super().__init__()
        self.hid1 = T.nn.Linear(6, 10)
        self.hid2 = T.nn.Linear(10, 10)
        self.oupt = T.nn.Linear(10, 3)
        # Glorot-uniform weights and zero biases, applied layer by layer in
        # construction order (keeps the random-number stream unchanged).
        for layer in (self.hid1, self.hid2, self.oupt):
            T.nn.init.xavier_uniform_(layer.weight)
            T.nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map a (batch, 6) tensor to (batch, 3) log-probabilities."""
        hidden = T.tanh(self.hid2(T.tanh(self.hid1(x))))
        return T.log_softmax(self.oupt(hidden), dim=1)

# -----------------------------------------------------------

def accuracy(model, ds):
    """Return the fraction of items in ds that model classifies correctly.

    Item-by-item version: each item is fed to the model as a batch of one.

    Args:
        model: trained classifier; assumes model.eval() has been called.
            The argmax of its output is taken as the predicted class.
        ds: indexable dataset whose item i has predictors at ds[i][0] and
            class label at ds[i][1].

    Returns:
        float in [0, 1]. Raises ZeroDivisionError if ds is empty.
    """
    n_correct = 0
    n_wrong = 0
    with T.no_grad():  # evaluation only; skip building the autograd graph
        for i in range(len(ds)):
            X = ds[i][0].reshape(1, -1)  # make it a batch of one
            Y = int(ds[i][1])            # actual class 0, 1, or 2
            oupt = model(X)              # class scores (Net: log-probabilities)
            big_idx = int(T.argmax(oupt))  # predicted class 0, 1, or 2
            if big_idx == Y:
                n_correct += 1
            else:
                n_wrong += 1

    acc = (n_correct * 1.0) / (n_correct + n_wrong)
    return acc

# -----------------------------------------------------------

def accuracy_quick(model, dataset):
    """Return classification accuracy of model on dataset in one batched pass.

    Assumes model.eval() has been called and that slicing the dataset,
    dataset[0:len(dataset)], yields a (predictors, targets) pair of tensors.
    PeopleDataset.__getitem__ satisfies this because it indexes its tensors
    with idx directly.
    """
    with T.no_grad():  # evaluation only
        X, Y = dataset[0:len(dataset)]    # all predictors, all labels
        oupt = model(X)
        arg_maxs = T.argmax(oupt, dim=1)  # predicted class per row
        num_correct = T.sum(Y == arg_maxs)
    acc = (num_correct * 1.0 / len(dataset))
    return acc.item()

# -----------------------------------------------------------

def confusion_matrix_multi(model, ds, n_classes):
    """Build an n_classes x n_classes confusion matrix for a multi-class model.

    Rows are actual classes and columns are predicted classes, so cm[a][p]
    counts items whose true label is a and whose predicted label is p.

    Args:
        model: trained classifier (assumes model.eval() has been called);
            the argmax of its output is taken as the predicted class.
        ds: indexable dataset whose item i has predictors at ds[i][0] and
            an integer class label at ds[i][1].
        n_classes: number of classes; must be 3 or more.

    Returns:
        numpy int64 array of shape (n_classes, n_classes), or None (after
        printing an error message) when n_classes <= 2.
    """
    if n_classes <= 2:  # binary classifiers use different code
        print("ERROR: n_classes must be 3 or greater ")
        return None

    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    with T.no_grad():  # evaluation only; no autograd graph needed
        for i in range(len(ds)):
            X = ds[i][0].reshape(1, -1)       # make it a batch of one
            Y = int(ds[i][1])                 # actual class 0, 1, or 2
            oupt = model(X)
            pred_class = int(T.argmax(oupt))  # predicted class 0, 1, or 2
            cm[Y, pred_class] += 1
    return cm

# -----------------------------------------------------------

def show_confusion(cm):
    """Print a labeled confusion matrix.

    One "actual" line per row, then a dashed rule, then a "predicted" line
    of column labels. Cells are right-aligned to one character more than
    the widest count in cm.
    """
    n = len(cm)
    cell = "%" + str(len(str(np.max(cm))) + 1) + "d"
    for r in range(n):
        counts = "".join(cell % cm[r][c] for c in range(n))
        print("actual   " + ("%3d:" % r) + counts)
    print("------------")
    labels = "".join(cell % c for c in range(n))
    print("predicted    " + labels)

# -----------------------------------------------------------

def accuracy_from_confusion_multi(cm):
    """Derive overall and per-class accuracies from a square confusion matrix.

    Args:
        cm: numpy array with actual classes on rows, predicted on columns.

    Returns:
        (overall, class_accs), where:
        - overall is the diagonal (correct predictions) divided by all
          predictions.
        - class_accs[i] is cm[i][i] divided by the number of items whose
          actual class is i. A class with no items gives nan, and numpy
          emits a RuntimeWarning.
    """
    per_class_totals = cm.sum(axis=1)   # items per actual class
    overall = np.trace(cm) / np.sum(cm)
    class_accs = list(np.diag(cm) / per_class_totals)
    return (overall, class_accs)

# -----------------------------------------------------------

def main():
# 0. get started
print("\nBegin People predict politics type ")
T.manual_seed(1)
np.random.seed(1)

print("\nCreating People Datasets ")

train_file = ".\\Data\\people_train.txt"
train_ds = PeopleDataset(train_file)  # 200 rows

test_file = ".\\Data\\people_test.txt"
test_ds = PeopleDataset(test_file)    # 40 rows

bat_size = 10
batch_size=bat_size, shuffle=True)

# -----------------------------------------------------------

# 2. create and compile network
print("\nCreating 6-(10-10)-3 neural network ")
net = Net().to(device)
net.train()

# -----------------------------------------------------------

# 3. train model
max_epochs = 1000
ep_log_interval = 200
lrn_rate = 0.01

loss_func = T.nn.NLLLoss()  # assumes log_softmax()
optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

print("\nbat_size = %3d " % bat_size)
print("loss = " + str(loss_func))
print("optimizer = SGD")
print("max_epochs = %3d " % max_epochs)
print("lrn_rate = %0.3f " % lrn_rate)

print("\nStarting training ")
for epoch in range(0, max_epochs):
# T.manual_seed(epoch+1)  # checkpoint reproducibility
epoch_loss = 0  # for one full epoch

for (batch_idx, batch) in enumerate(train_ldr):
X = batch  # inputs
Y = batch  # correct class/label/politics

oupt = net(X)
loss_val = loss_func(oupt, Y)  # a tensor
epoch_loss += loss_val.item()  # accumulate
loss_val.backward()
optimizer.step()

if epoch % ep_log_interval == 0:
print("epoch = %5d  |  loss = %10.4f" % \
(epoch, epoch_loss))

print("Training done ")

# -----------------------------------------------------------

# 4. evaluate model accuracy
print("\nComputing model accuracy")
net.eval()
acc_train = accuracy(net, train_ds)  # item-by-item
print("Accuracy on training data = %0.4f" % acc_train)
acc_test = accuracy(net, test_ds)
print("Accuracy on test data = %0.4f" % acc_test)

# 4b. confusion matrix
print("\nComputing confusion matrix ")
cm = confusion_matrix_multi(net, test_ds, n_classes=3)
# print(cm)  # raw matrix
print("Formatted version: \n")
show_confusion(cm)

# 4c. acuracy metrics from confusion
print("\nComputing test accuracies from confusion matrix ")
(test_acc, class_accs) = accuracy_from_confusion_multi(cm)
print("Accuracy on test data = %0.4f" % test_acc)
print("Class accuracies: ")
for i in range(len(cm)):
print("%4d = %0.4f " % (i, class_accs[i]))

# 5. make a prediction
print("\nPredicting politics for M  30  oklahoma  \$50,000: ")
X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
X = T.tensor(X, dtype=T.float32).to(device)

logits = net(X)  # do not sum to 1.0
probs = T.exp(logits)  # sum to 1.0
probs = probs.numpy()  # numpy vector prints better
np.set_printoptions(precision=4, suppress=True)
print(probs)

# 6. save model (state_dict approach)
print("\nSaving trained model state ")
# fn = ".\\Models\\people_model.pt"
# T.save(net.state_dict(), fn)

# model = Net()  # requires class definition