Computing Precision and Recall from Scratch for PyTorch Binary Classifiers

I was talking to some relatively young colleagues who had recently joined my company. We were looking at the results of a binary classifier. My colleagues used library code to compute precision and recall metrics, but they didn’t really understand what those metrics were.

Precision and recall are alternative forms of accuracy. Accuracy for a binary classifier is easy: the number of correct predictions made divided by the total number of predictions. Precision and recall are defined in terms of “true positives”, “true negatives”, “false positives”, and “false negatives”. For a binary classifier (class 0 = negative, class 1 = positive), these are the only four possible outcomes of a prediction.

It’s easy to mix the four possible results up. In my mind, “positive” and “negative” refer to the predicted class (class 1 = positive, class 0 = negative), and “true” means the prediction is correct while “false” means it’s wrong. Therefore, a “false positive” is a wrong prediction of class 1: the model predicted class 1 but the item is actually class 0.

TP = true positive   = predicted class 1, actual class 1 (correct)
FP = false positive  = predicted class 1, actual class 0 (wrong)

TN = true negative   = predicted class 0, actual class 0 (correct)
FN = false negative  = predicted class 0, actual class 1 (wrong)

note: N = total number of items = TP + FP + TN + FN

accuracy  = (TP + TN)  / N
precision = TP / (TP + FP)
recall    = TP / (TP + FN)
F1        = harmonic_mean(precision, recall)
          = 2 / [(1 / precision) + (1 / recall)]

Suppose you have 100 data items. You build a binary classifier and get these results on some data that has known class labels:

--------------------------------------------------
                     predicted
                  class 1  class 0
actual   class 1    40       20
         class 0    10       30
--------------------------------------------------

accuracy  = (40 + 30) / 100 = 0.70
precision = 40 / (40 + 10)  = 0.80
recall    = 40 / (40 + 20)  = 0.67
F1        = 2 / ((1 / 0.80) + (1 / 0.67)) = 0.73

The table of results is called a confusion matrix. Unfortunately, there is no standard format (some sources put actual classes on the rows and predicted classes on the columns, others do the reverse), so you have to interpret a confusion matrix carefully.

The F1 score is the harmonic mean of precision and recall — just a way of summarizing them together.
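If you want to double-check that arithmetic, here is the worked example redone in plain Python, with the four counts read directly off the confusion matrix above:

tp = 40   # actual class 1, predicted class 1
fn = 20   # actual class 1, predicted class 0
fp = 10   # actual class 0, predicted class 1
tn = 30   # actual class 0, predicted class 0

n = tp + fp + tn + fn                        # 100
accuracy  = (tp + tn) / n                    # 0.70
precision = tp / (tp + fp)                   # 40 / 50 = 0.80
recall    = tp / (tp + fn)                   # 40 / 60 = 0.67
f1 = 2 / ((1 / precision) + (1 / recall))    # 0.73
print(accuracy, precision, recall, f1)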


import torch as T  # PyTorch

def metrics(model, ds, thresh=0.5):
  # accuracy  = (TP + TN)  / N
  # precision = TP / (TP + FP)
  # recall    = TP / (TP + FN)
  # F1        = 2 / [(1 / precision) + (1 / recall)]

  tp = 0; tn = 0; fp = 0; fn = 0
  for i in range(len(ds)):
    inpts = ds[i]['predictors']  # dictionary-style Dataset item
    target = ds[i]['sex']        # float32, [0.0] or [1.0]
    with T.no_grad():
      p = model(inpts)           # output between 0.0 and 1.0

    # comparing floats with '==' is normally risky; it works here only
    # because the targets are stored as exactly 0.0 or 1.0
    if target == 1.0 and p >= thresh:    # TP: actual 1, predicted 1
      tp += 1
    elif target == 1.0 and p < thresh:   # FN: actual 1, predicted 0
      fn += 1
    elif target == 0.0 and p < thresh:   # TN: actual 0, predicted 0
      tn += 1
    elif target == 0.0 and p >= thresh:  # FP: actual 0, predicted 1
      fp += 1

  N = tp + fp + tn + fn
  if N != len(ds):
    print("FATAL LOGIC ERROR")

  accuracy = (tp + tn) / N
  precision = tp / (tp + fp)  # fails if the model never predicts class 1
  recall = tp / (tp + fn)     # fails if there are no actual class 1 items
  f1 = 2.0 / ((1.0 / precision) + (1.0 / recall))
  return (accuracy, precision, recall, f1)

Computing precision and recall from scratch is easy. You don’t need a library function that introduces a system dependency.
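
To show how the function might be called, here is a minimal sketch. The TinyNet “model” and TinyDataset below are made-up stand-ins, not a real trained network; they exist only so the call runs end to end, and the 'predictors' and 'sex' keys simply mirror the comments inside metrics() above. Your own network and Dataset keys will differ.

import torch as T

class TinyNet(T.nn.Module):
  # toy stand-in "model": predicts class 1 when the mean of the
  # predictor values is 0.5 or more
  def forward(self, x):
    return T.sigmoid(8.0 * (T.mean(x, dim=-1, keepdim=True) - 0.5))

class TinyDataset(T.utils.data.Dataset):
  # hand-made data in the dictionary style that metrics() expects
  def __init__(self):
    self.x = T.tensor([[0.1, 0.2], [0.8, 0.9], [0.7, 0.6],
                       [0.2, 0.1], [0.3, 0.4], [0.9, 0.7]],
                      dtype=T.float32)
    self.y = T.tensor([[0.0], [1.0], [1.0],
                       [0.0], [1.0], [0.0]], dtype=T.float32)
  def __len__(self):
    return len(self.x)
  def __getitem__(self, idx):
    return {'predictors': self.x[idx], 'sex': self.y[idx]}

model = TinyNet()
model.eval()
ds = TinyDataset()
acc, prec, rec, f1 = metrics(model, ds, thresh=0.5)
print("acc %0.4f  prec %0.4f  rec %0.4f  F1 %0.4f" % (acc, prec, rec, f1))

The only requirements are that the Dataset items are indexable dictionaries with those two keys, and that the model emits a single value between 0.0 and 1.0.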


So why are precision and recall even needed? They’re mostly useful in situations where the true class labels are highly skewed. For example, suppose you have training data with 95 class 0 items and 5 class 1 items. You can get 95% accuracy by just predicting class 0 for every data item.

With skewed data, if you just guess one class or the other, you can get excellent accuracy but either precision or recall will be very poor.
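
Here is a tiny plain-Python sketch of that scenario, using made-up labels, so you can see the numbers directly:

# a "classifier" that always predicts class 0, on skewed data:
# 95 actual class 0 items and 5 actual class 1 items
actuals = [0] * 95 + [1] * 5
preds   = [0] * 100        # every prediction is class 0

tp = sum(1 for a, p in zip(actuals, preds) if a == 1 and p == 1)  # 0
fp = sum(1 for a, p in zip(actuals, preds) if a == 0 and p == 1)  # 0
tn = sum(1 for a, p in zip(actuals, preds) if a == 0 and p == 0)  # 95
fn = sum(1 for a, p in zip(actuals, preds) if a == 1 and p == 0)  # 5

accuracy = (tp + tn) / (tp + fp + tn + fn)   # 0.95, which looks great
recall = tp / (tp + fn)                      # 0.00, which is terrible
# precision = tp / (tp + fp) would be 0 / 0 here: undefined, another red flag
print(accuracy, recall)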

Now here’s where it gets a bit tricky. Most binary classifiers generate a single output value between 0.0 and 1.0. It’s standard practice to set a threshold value of 0.5: a result value that’s less than 0.5 is a prediction of class 0, and a result value that’s greater than or equal to 0.5 is a prediction of class 1.

You can adjust the accuracy of a binary classification model by adjusting the value of the threshold. For example, suppose you have 95 class 0 items and just 5 class 1 items. If you set the threshold to something like 0.9 then most result values will be less than 0.9 and so most predictions will be class 0 and you’ll get good accuracy. But again, if you use this threshold-adjusting cheat, either precision or recall will be poor.

Finally, if you run your data through your classifier and set the threshold to 0.0, 0.10, 0.20, . . ., 0.90, 1.0, then each value of the threshold will generate a different number of TP, FP, TN, FN. For each of the 11 threshold values, if you make a graph with the FP count (or, conventionally, the false positive rate) on the x-axis and the TP count (or true positive rate) on the y-axis, you’ll get 11 dots. If you connect the dots you get what’s called a ROC curve (receiver operating characteristic).
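
Here is a minimal sketch of that threshold sweep. The output values and labels below are made up for illustration; in practice you would collect them by running your dataset through the trained model once.

# made-up model outputs (0.0 to 1.0) and true class labels
probs  = [0.05, 0.20, 0.35, 0.45, 0.55, 0.60, 0.70, 0.80, 0.90, 0.95]
labels = [0,    0,    0,    1,    0,    1,    1,    0,    1,    1   ]

thresholds = [t / 10.0 for t in range(0, 11)]   # 0.0, 0.1, . . ., 1.0
for thresh in thresholds:
  tp = sum(1 for p, y in zip(probs, labels) if y == 1 and p >= thresh)
  fp = sum(1 for p, y in zip(probs, labels) if y == 0 and p >= thresh)
  print("thresh = %0.2f   TP = %d   FP = %d" % (thresh, tp, fp))
# plotting the 11 (FP, TP) pairs, or the rates FP/5 and TP/5 (there are
# 5 actual negatives and 5 actual positives), traces out the ROC curve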

The key point I made to my new colleagues is that precision and recall aren’t scary magic that needs a complex library. You can easily compute precision and recall, which gives you complete control over your code (no dependencies) and more importantly, gives you a full understanding of what you’re doing.


American, Chinese, and Russian scary-magic witches. Left: The Wicked Witch of the West, from “The Wizard of Oz” (1939). Center: The White Bone Witch from “The Monkey King 2” (2016). Right: Pannochka from “Forbidden Empire” (2014).
