Cross Entropy for Binary Classification

Suppose you are using a neural network to try to predict the color of a traffic signal. The possible values are encoded as red = (1, 0, 0), yellow = (0, 1, 0), green = (0, 0, 1). Your neural network spits out an answer like (0.25, 0.60, 0.15), which can be interpreted as probabilities, so in this case your prediction is “yellow”.

The most common way to measure the error during training is cross entropy error, which for the example above is CE error = -1 * [ ln(0.25)*0 + ln(0.60)*1 + ln(0.15)*0 ] = 0.5108. In other words, it’s just minus one times the log of the prediction probability associated with the class encoded with a 1.
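
Here’s a minimal sketch of that calculation in Python with NumPy (the target and probability values are just the made-up ones from the example above):

import numpy as np

target = np.array([0, 1, 0])           # one-hot encoded "yellow"
probs = np.array([0.25, 0.60, 0.15])   # predicted probabilities
ce = -np.sum(target * np.log(probs))   # only the slot encoded with a 1 contributes
print(ce)  # 0.5108...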

But suppose you have a binary prediction problem, for example, predicting whether a person is Male (class = 0) or Female (class = 1). The most common approach is to spit out a single value that represents the probability of the class encoded as 1. For example, if the predicted output is 0.65, then because that value is greater than 0.5, the prediction is Female.
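
As a tiny sketch (the z value here is made up just so the sigmoid lands near 0.65), the single-output decision rule looks like this:

import numpy as np

z = 0.619                               # hypothetical pre-activation sum
p = 1.0 / (1.0 + np.exp(-z))            # logistic sigmoid output, about 0.65
print("Female" if p > 0.5 else "Male")  # greater than 0.5, so prediction is Female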

Cross entropy error works only with two or more values that sum to 1.0 (a probability distribution), so you can’t directly use CE error if you are doing binary classification with a single output node.

One solution is just to create a dummy second output node. For example, if the computed output is p = 0.65, then set q = 1 - p = 0.35, so the virtual output is (0.35, 0.65), corresponding to Male = (1, 0) and Female = (0, 1), and you can compute CE error as usual. Another solution is to compute binary cross entropy directly, using code like this Python implementation:

import numpy as np

def mce_error(fm, lm, W, b):
  # mean binary cross entropy error
  # fm is feature matrix, lm is labels array
  # W is weights array, b is bias
  err = 0.0
  for i in range(len(fm)):  # walk thru each item
    X = fm[i]
    y = lm[i]  # target = 0 or 1
    z = 0.0
    for k in range(len(X)):
      z += X[k] * W[k]
    z += b
    p = 1.0 / (1.0 + np.exp(-z))  # computed P(class 1), in (0.0, 1.0)

    if y == 1:
      err += -np.log(p)      # p is P(y = 1)
    else:
      err += -np.log(1 - p)  # 1 - p is P(y = 0)

  return err / len(fm)  # mean CE error
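
A quick usage sketch (the feature matrix, labels, weights, and bias below are made up purely to show the call):

fm = np.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])     # three items, two features each
lm = np.array([0, 1, 1])        # target classes
W = np.array([0.1, -0.2])       # weights
b = 0.3                         # bias
print(mce_error(fm, lm, W, b))  # mean binary CE error over the three items

Note that this gives exactly the same value as the dummy second node trick: for a class 0 item the two-node CE picks out -ln(q) = -ln(1 - p), and for a class 1 item it picks out -ln(p).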

To monitor error during training, I typically avoid using cross entropy error for binary classification when I’m using a single output node. I’ll usually use two explicit output nodes with one-hot encoding, except when I’m using logistic regression, in which case I typically use a single output node with ordinary squared error for monitoring (the learning rule itself comes from gradient ascent on the log-likelihood).
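
For reference, a minimal sketch of that squared error monitoring for a single output node (a hypothetical function, using the same data layout as the CE example above) might look like:

def mean_sq_error(fm, lm, W, b):
  # mean squared error for monitoring a single-output binary classifier
  err = 0.0
  for i in range(len(fm)):
    z = np.dot(fm[i], W) + b
    p = 1.0 / (1.0 + np.exp(-z))  # computed probability of class 1
    err += (p - lm[i]) ** 2       # target y is 0 or 1
  return err / len(fm)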
