Cross Validation and Neural Networks and Over-Fitting

Most of the information I see on the Internet about the relationship between cross validation and neural networks is either incomplete or just plain incorrect. Let me explain. Most references, including Wikipedia, state that the purpose of performing cross validation is to get an estimate of the accuracy (or error rate) of a neural network model. That is true, but what the majority of references fail to add is that, in most cases, the reason you want an error estimate is so that you can do model selection: pick the best of several candidate neural networks. And in most cases, when you have several candidate models, you are also dealing with over-fitting. The explanation is a bit subtle.

Suppose you want to create a neural network to predict something. You can start by generating several variations of neural networks, with different combinations of learning rate, momentum, number of hidden nodes, and possibly other features. Let’s say you have 12 variations. Next, you perform k-fold cross validation on each of the 12 variations and get an error estimate for each one. (I’ll assume that you understand the mechanics of k-fold cross validation.) Now you select the neural network variation that generated the smallest average error on the k held-out test folds, that is, on the data that each trained network did not see. You have selected the best model in some sense. Finally, you use the learning rate, momentum, and number of hidden nodes of the best variation, with the entire data set as training data, to generate the final set of weights and biases for the neural network.
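Here is a minimal sketch of that procedure in Python, using only the standard library. It rests on assumptions that are not part of the post: the data is a hypothetical synthetic set of 60 labeled 1-D points with some labels deliberately flipped, and a k-nearest-neighbor classifier stands in for the neural network, so its number of neighbors (`n_neighbors`, the odd values 1 through 23) plays the role of the 12 variations. All helper names (`make_data`, `train`, `predict`, `k_fold_error`) are illustrative.

```python
import random

def make_data(n=60, noise=0.15, seed=1):
    """Hypothetical synthetic data: label 1 when x > 0.5, with some labels flipped."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.random()
        y = 1 if x > 0.5 else 0
        if rng.random() < noise:
            y = 1 - y  # label noise: fitting these points exactly is over-fitting
        data.append((x, y))
    return data

def train(n_neighbors, data):
    # Stand-in for neural network training: k-NN simply stores the data.
    return (n_neighbors, list(data))

def predict(model, x):
    n_neighbors, data = model
    nearest = sorted(data, key=lambda p: abs(p[0] - x))[:n_neighbors]
    votes = sum(y for _, y in nearest)
    return 1 if 2 * votes > n_neighbors else 0  # majority vote (odd n_neighbors)

def error_rate(model, data):
    return sum(predict(model, x) != y for x, y in data) / len(data)

def k_fold_error(n_neighbors, data, num_folds=5):
    # Each fold is scored by a model trained on the other num_folds - 1 folds.
    # Folds are assigned by position, which is fine here because the
    # synthetic points are already in random order.
    errors = []
    for f in range(num_folds):
        held_out = data[f::num_folds]
        train_part = [p for i, p in enumerate(data) if i % num_folds != f]
        errors.append(error_rate(train(n_neighbors, train_part), held_out))
    return sum(errors) / num_folds

data = make_data()
variations = list(range(1, 24, 2))          # 12 candidate settings
cv_errors = {v: k_fold_error(v, data) for v in variations}
best = min(variations, key=cv_errors.get)   # smallest average held-out error
final_model = train(best, data)             # retrain the winner on all the data
print("best n_neighbors:", best, "cv error:", round(cv_errors[best], 3))
```

Swapping in a real neural network would only change `train` and `predict`; the fold loop, the selection by smallest average held-out error, and the final retraining on all the data stay the same.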

[Image: cross validation spreadsheet]

This cross validation approach deals with the problem of over-fitting. If, instead of following the procedure above, you just tried different neural network variations on the entire data set and measured each one’s error on that same data, you would likely find a variation that fit your data very well (perhaps a variation with many hidden nodes). But that neural network would likely over-fit, and would perform poorly when trying to predict using new, previously unseen data.
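The trap can be demonstrated with the same kind of hypothetical stand-in: a k-nearest-neighbor classifier on synthetic 1-D data, not a neural network. With `n_neighbors=1`, each training point is its own nearest neighbor (assuming no duplicate inputs), so the model reproduces every training label and its error on the data it was fit to is zero. Choosing by that error therefore always picks the most flexible variation, the analogue of the network with many hidden nodes, while cross validation scores each variation only on points it did not see. The helper names are illustrative.

```python
import random

def make_data(n=60, noise=0.15, seed=1):
    # Hypothetical synthetic data: label 1 when x > 0.5, some labels flipped.
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.random()
        y = 1 if x > 0.5 else 0
        data.append((x, 1 - y if rng.random() < noise else y))
    return data

def error_rate(train_data, n_neighbors, eval_data):
    # Stand-in model: majority vote of the n_neighbors nearest training points.
    wrong = 0
    for x, y in eval_data:
        nearest = sorted(train_data, key=lambda p: abs(p[0] - x))[:n_neighbors]
        pred = 1 if 2 * sum(label for _, label in nearest) > n_neighbors else 0
        wrong += pred != y
    return wrong / len(eval_data)

def k_fold_error(data, n_neighbors, num_folds=5):
    # Average error on each held-out fold, using a model fit to the other folds.
    errors = []
    for f in range(num_folds):
        held_out = data[f::num_folds]
        train_part = [p for i, p in enumerate(data) if i % num_folds != f]
        errors.append(error_rate(train_part, n_neighbors, held_out))
    return sum(errors) / num_folds

data = make_data()
variations = list(range(1, 24, 2))
# The trap: score each variation on the same data it was fit to.
train_errors = {v: error_rate(data, v, data) for v in variations}
# Cross validation: score each variation only on data it did not see.
cv_errors = {v: k_fold_error(data, v) for v in variations}

pick_by_train = min(variations, key=train_errors.get)
pick_by_cv = min(variations, key=cv_errors.get)
print("picked by training error:", pick_by_train)
print("picked by cross validation:", pick_by_cv)
```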

Now, here is another possibility, different from the scenario I just described. You decide a priori on the neural network’s learning rate, number of hidden nodes, and so on. You train on all the available data (first, rather than at the end as above), and only then do k-fold cross validation to get an estimate of how well your neural network will perform on new data. In this scenario, cross validation is being used just to get an estimate of the error/accuracy of your neural network. But you have ignored the over-fitting issue.
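A minimal sketch of this second scenario, again under stated assumptions: the synthetic data is hypothetical, and a simple threshold classifier (decision threshold halfway between the two class means of x) stands in for a neural network whose settings were fixed in advance. Note the order: the final model is trained on all the data first, and the k-fold loop then trains separate throwaway models on subsets purely to produce the error estimate. The helper names are illustrative.

```python
import random

def make_data(n=60, noise=0.15, seed=1):
    # Hypothetical synthetic data: label 1 when x > 0.5, some labels flipped.
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.random()
        y = 1 if x > 0.5 else 0
        data.append((x, 1 - y if rng.random() < noise else y))
    return data

def train(data):
    # Stand-in for a network with settings fixed a priori: place the
    # threshold halfway between the mean x of each class.
    xs0 = [x for x, y in data if y == 0]
    xs1 = [x for x, y in data if y == 1]
    if not xs0 or not xs1:
        return None, (1 if xs1 else 0)  # one class only: constant prediction
    m0, m1 = sum(xs0) / len(xs0), sum(xs1) / len(xs1)
    return (m0 + m1) / 2, (1 if m1 > m0 else 0)

def predict(model, x):
    threshold, high_label = model
    if threshold is None:
        return high_label
    return high_label if x > threshold else 1 - high_label

def error_rate(model, data):
    return sum(predict(model, x) != y for x, y in data) / len(data)

def k_fold_estimate(data, num_folds=5):
    # Each fold is scored by a model trained on the other folds only.
    errors = []
    for f in range(num_folds):
        held_out = data[f::num_folds]
        train_part = [p for i, p in enumerate(data) if i % num_folds != f]
        errors.append(error_rate(train(train_part), held_out))
    return sum(errors) / num_folds

data = make_data()
final_model = train(data)          # step 1: train the one model on all data
estimate = k_fold_estimate(data)   # step 2: estimate its error on new data
print("final model:", final_model, "estimated error:", round(estimate, 3))
```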

Before I finish, let me mention hold-out validation. You begin up front by separating the data set into a training set (typically 80% of the data) and a test set (the remaining 20%). Then the ideas are similar. You can train several neural network variations using the training set only and pick the one variation that performs best on the test set. Or you can just train one variation and then estimate its generalization error using the test set.
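Hold-out validation can be sketched the same way, with the same hypothetical assumptions: synthetic 1-D data, and a k-nearest-neighbor classifier standing in for the neural network. The data is shuffled before the 80/20 split so the split does not depend on the original ordering; the seed values and the fixed choice of 7 neighbors in the second option are arbitrary, illustrative choices.

```python
import random

def make_data(n=60, noise=0.15, seed=1):
    # Hypothetical synthetic data: label 1 when x > 0.5, some labels flipped.
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.random()
        y = 1 if x > 0.5 else 0
        data.append((x, 1 - y if rng.random() < noise else y))
    return data

def error_rate(train_data, n_neighbors, eval_data):
    # Stand-in model: majority vote of the n_neighbors nearest training points.
    wrong = 0
    for x, y in eval_data:
        nearest = sorted(train_data, key=lambda p: abs(p[0] - x))[:n_neighbors]
        pred = 1 if 2 * sum(label for _, label in nearest) > n_neighbors else 0
        wrong += pred != y
    return wrong / len(eval_data)

data = make_data()
shuffled = data[:]
random.Random(2).shuffle(shuffled)   # shuffle so the split ignores data order
cut = int(0.8 * len(shuffled))       # 80% training, 20% test
train_set, test_set = shuffled[:cut], shuffled[cut:]

# Option 1: train several variations on train_set, keep the best on test_set.
variations = list(range(1, 24, 2))
test_errors = {v: error_rate(train_set, v, test_set) for v in variations}
best = min(variations, key=test_errors.get)

# Option 2: train one variation chosen in advance (hypothetical choice: 7)
# and use test_set to estimate its generalization error.
estimate = error_rate(train_set, 7, test_set)
print("best n_neighbors:", best, "| estimate for n_neighbors=7:", round(estimate, 3))
```

Option 1 mirrors the model-selection use of cross validation; option 2 mirrors the estimate-only use.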

Finally, there is another related process called train-validate-test. It is used in conjunction with early stopping, but that’s another story.

This entry was posted in Machine Learning.