K-Fold Cross-Validation for Neural Networks

I wrote an article, “Understanding and Using K-Fold Cross-Validation for Neural Networks,” that appears in the October 2013 issue of Visual Studio Magazine. See http://visualstudiomagazine.com/articles/2013/10/01/understanding-and-using-kfold.aspx. Exactly what k-fold cross-validation is, and why it is used, are somewhat difficult to explain clearly. Let me try. The main technical challenge when working with a neural network is training the network, which means finding values for the network’s many weights and biases so that, for a given set of input values, the network’s computed output values closely match the known output values of a set of training data.

Because neural networks are universal function approximators, given enough time it is always possible (in theory) to find a set of weights and biases so that computed outputs exactly match the training data outputs. But if you use those weight and bias values on new, previously unseen data, your neural network will predict very poorly. This is called over-fitting. (I breezed through this, but over-fitting is a very deep concept.)

OK, so the problem is over-fitting. There are many ways to deal with over-fitting, and k-fold cross-validation is one of them. The idea is to break the training data into k subsets, where k is usually 10. Then you run your training algorithm (the three most common approaches are back-propagation, particle swarm optimization, and genetic algorithm optimization) 10 times. On the first training run you use 9/10 of the training data to train, and then compute the network’s accuracy on the remaining 1/10. This process is repeated so that each 1/10 subset is used exactly once as the validation set. When finished, you take the average of the 10 accuracies and use it as the overall estimate of the accuracy of the network. In short, k-fold cross-validation gives you an estimate of a neural network’s accuracy when the network is constructed using particular values for the number of hidden nodes and the training parameters.
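To make the fold mechanics concrete, here is a minimal sketch in Python. This is my own illustration, not code from the article; train and accuracy are hypothetical placeholders for whatever training algorithm and scoring function you actually use:

import random

def k_fold_accuracy(data, k=10):
    # shuffle a copy so the folds aren't biased by the data's ordering
    items = list(data)
    random.shuffle(items)
    # split into k roughly equal folds using stride slicing
    folds = [items[i::k] for i in range(k)]
    accuracies = []
    for i in range(k):
        # fold i is held out; the other k-1 folds form the training set
        train_items = [x for j, f in enumerate(folds) if j != i for x in f]
        net = train(train_items)                    # hypothetical trainer
        accuracies.append(accuracy(net, folds[i]))  # hypothetical scorer
    # the average over the k held-out folds is the accuracy estimate
    return sum(accuracies) / k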

How does this help? Well, if you do k-fold cross-validation repeatedly, using different values for the training technique’s parameters on each pass (different techniques have different parameters: back-propagation needs a learning rate and momentum; particle swarm needs inertia, cognitive, and social weights; and so on), and also trying different numbers of hidden nodes, you can find the best values for the number of hidden nodes and the training parameters. Then, with these in hand, you can finally train your network using all your data, with the best number of hidden nodes and the best training parameters.

In pseudo-code:

loop "many" times
  pick a number of hidden nodes
  pick training parameters (learning rate, etc.)
  // k-fold
  divide train data into 10 parts
  for i = 1 to 10
    train network using 9 parts
    compute accuracy using 1 part
  end for
  compute average accuracy of the 10 runs

  if avg accuracy best found so far
    save number hidden nodes used
    save training parameters used
    save best average accuracy value
  end if
end loop

train network using all data
  (using best number hidden nodes,
   and best training parameters)
estimated accuracy is best accuracy found above
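For readers who want something they can actually run, here is a rough scikit-learn equivalent of the pseudo-code. Again, this is my sketch, not code from the article; the Iris data set and the small grid of candidate values are just illustrative:

from itertools import product
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neural_network import MLPClassifier

X, y = load_iris(return_X_y=True)
kf = KFold(n_splits=10, shuffle=True, random_state=0)

best_acc, best_params = 0.0, None
for hidden, lr in product([4, 8, 16], [0.01, 0.05]):
    net = MLPClassifier(hidden_layer_sizes=(hidden,), solver='sgd',
                        learning_rate_init=lr, momentum=0.9,
                        max_iter=2000, random_state=0)
    # 10-fold: train on 9/10, score on the held-out 1/10, ten times
    avg_acc = cross_val_score(net, X, y, cv=kf).mean()
    if avg_acc > best_acc:
        best_acc, best_params = avg_acc, (hidden, lr)

hidden, lr = best_params
print("best avg accuracy %.4f: %d hidden nodes, learn rate %.2f"
      % (best_acc, hidden, lr))

# finally, train on all the data using the best settings found
final_net = MLPClassifier(hidden_layer_sizes=(hidden,), solver='sgd',
                          learning_rate_init=lr, momentum=0.9,
                          max_iter=2000, random_state=0).fit(X, y)

Note that cross_val_score handles the split/train/score loop for you, so the inner for-loop from the pseudo-code is implicit.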

Anyway, k-fold cross-validation was difficult for me to grasp because there are so many inter-related issues, but after I thought the process over enough times, it finally made sense.

