Neural Network Back-Propagation and De-Modularizing

I ran across an unusual scenario recently where it was beneficial to “de-modularize” some code. I’m not sure if “de-modularize” is a word or not, but what I mean is refactoring code that was in two functions into one larger function. The code I was working with was neural network training using the back-propagation algorithm. My original code resembled this:

Method Train:

loop until done
  for each data item
    compute output values
    update weight values
  end for
end loop
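
In C#, the original modular version looked roughly like this sketch (reconstructed from the full listing later in this post; the original modular version isn’t shown):

public void Train(double[][] trainData, int maxEpochs,
  double learnRate, double momentum)
{
  int epoch = 0;
  double[] xValues = new double[numInput]; // inputs
  double[] tValues = new double[numOutput]; // targets

  int[] sequence = new int[trainData.Length];
  for (int i = 0; i < sequence.Length; ++i)
    sequence[i] = i;

  while (epoch < maxEpochs)
  {
    double mse = MeanSquaredError(trainData);
    if (mse < 0.040) break;

    Shuffle(sequence); // random order
    for (int ii = 0; ii < trainData.Length; ++ii)
    {
      int idx = sequence[ii];
      Array.Copy(trainData[idx], xValues, numInput);
      Array.Copy(trainData[idx], numInput, tValues, 0,
        numOutput);
      ComputeOutputs(xValues);
      UpdateWeights(tValues, learnRate, momentum);
    }
    ++epoch;
  }
}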

In other words, the training method was a wrapper around two methods, ComputeOutputs and UpdateWeights. Nice and modular. The update-weights method resembled:

Method UpdateWeights:

compute output gradients
compute hidden gradients
use gradients to compute deltas
update weights using deltas
modify weights using previous deltas (momentum)
save deltas for next call

I’m leaving out a lot of details, but the problem with the modular approach was storing the previous deltas, which are held in two matrices and two arrays. One approach is to make these four data structures class members. But that’s ugly because it makes replacing the training method difficult. A second approach is to place the data structures inside method Train. But then I’d have to pass them as four additional parameters to method UpdateWeights, or create a “state-context” data structure that holds all four and pass it as a single parameter.
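
To make the second approach concrete, a “state-context” structure might look something like this hypothetical sketch (the class and field names are mine, not from the actual code):

// hypothetical state-context class holding the four
// previous-delta data structures
public class UpdateContext
{
  public double[][] ihPrevWeightsDelta; // input-to-hidden
  public double[] hPrevBiasesDelta;     // hidden biases
  public double[][] hoPrevWeightsDelta; // hidden-to-output
  public double[] oPrevBiasesDelta;     // output biases
}

// Train would allocate one UpdateContext before the main
// loop and pass it on every call, something like:
// UpdateWeights(tValues, learnRate, momentum, ctx);

That works, but the extra plumbing exists only to carry state between Train and UpdateWeights.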


In the end, the best solution was to de-modularize the code by ditching the UpdateWeights method and placing its code directly into method Train. Here’s the result:

public void Train2(double[][] trainData, int maxEpochs,
  double learnRate, double momentum)
{
  // integrated 'UpdateWeights' version 
  // back-prop specific arrays
  double[] oGrads = new double[numOutput]; // gradients
  double[] hGrads = new double[numHidden];

  // back-prop momentum specific arrays 
  double[][] ihPrevWeightsDelta = MakeMatrix(numInput,
    numHidden);
  double[] hPrevBiasesDelta = new double[numHidden];
  double[][] hoPrevWeightsDelta = MakeMatrix(numHidden,
    numOutput);
  double[] oPrevBiasesDelta = new double[numOutput];

  // train 
  int epoch = 0;
  double[] xValues = new double[numInput]; // inputs
  double[] tValues = new double[numOutput]; // targets

  int[] sequence = new int[trainData.Length];
  for (int i = 0; i < sequence.Length; ++i)
    sequence[i] = i;

  while (epoch < maxEpochs)
  {
    double mse = MeanSquaredError(trainData);
    if (mse < 0.040) break; // hard-coded 'good enough' error threshold

    Shuffle(sequence); // random order
    for (int ii = 0; ii < trainData.Length; ++ii)
    {
      int idx = sequence[ii];
      Array.Copy(trainData[idx], xValues, numInput);
      Array.Copy(trainData[idx], numInput, tValues, 0,
        numOutput);
      ComputeOutputs(xValues);
      //UpdateWeights(tValues, learnRate, momentum);

      // ---- Update-Weights section
      // 1. compute output gradients
      for (int i = 0; i < numOutput; ++i)
      {
        // derivative for softmax = (1 - y) * y 
        double derivative = (1 - outputs[i]) * outputs[i];
        oGrads[i] = derivative * (tValues[i] - outputs[i]);
      }

      // 2. compute hidden gradients
      for (int i = 0; i < numHidden; ++i)
      {
        // derivative of tanh = (1 - y) * (1 + y)
        double derivative = (1 - hOutputs[i]) *
          (1 + hOutputs[i]);
        double sum = 0.0;
        for (int j = 0; j < numOutput; ++j)
        {
          double x = oGrads[j] * hoWeights[i][j];
          sum += x;
        }
        hGrads[i] = derivative * sum;
      }

      // 3a. update input-to-hidden weights
      // (weights can be updated in any order)
      for (int i = 0; i < numInput; ++i) // 0..2 (3)
      {
        for (int j = 0; j < numHidden; ++j) // 0..3 (4)
        {
          double delta = learnRate * hGrads[j] * inputs[i];
          ihWeights[i][j] += delta; 
          // now add momentum using previous delta.
          ihWeights[i][j] += momentum *
            ihPrevWeightsDelta[i][j];
          ihPrevWeightsDelta[i][j] = delta; 
        }
      }

      // 3b. update hidden biases
      for (int i = 0; i < numHidden; ++i)
      {
        double delta = learnRate * hGrads[i]; 
        hBiases[i] += delta;
        hBiases[i] += momentum *
          hPrevBiasesDelta[i]; // momentum
        hPrevBiasesDelta[i] = delta;
      }

      // 4a. update hidden-output weights
      for (int i = 0; i < numHidden; ++i)
      {
        for (int j = 0; j < numOutput; ++j)
        {
          double delta = learnRate * oGrads[j] *
            hOutputs[i];
          hoWeights[i][j] += delta;
          hoWeights[i][j] += momentum *
            hoPrevWeightsDelta[i][j]; // momentum
          hoPrevWeightsDelta[i][j] = delta; // save
        }
      }

      // 4b. update output biases
      for (int i = 0; i < numOutput; ++i)
      {
        double delta = learnRate * oGrads[i] * 1.0;
        oBiases[i] += delta;
        oBiases[i] += momentum *
          oPrevBiasesDelta[i]; // momentum
        oPrevBiasesDelta[i] = delta; // save
      }
      // ---- end Update-Weights
    } // each training item
    ++epoch;
  } // while
} // Train2
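
The listing relies on a few helper methods that aren’t shown, such as MakeMatrix and Shuffle. Here are minimal sketches of what they might look like (my versions; the originals may differ):

private static double[][] MakeMatrix(int rows, int cols)
{
  // allocate a rows x cols matrix as an array of arrays
  double[][] result = new double[rows][];
  for (int i = 0; i < rows; ++i)
    result[i] = new double[cols];
  return result;
}

private void Shuffle(int[] sequence)
{
  // Fisher-Yates shuffle; assumes a Random field named rnd
  for (int i = 0; i < sequence.Length; ++i)
  {
    int r = rnd.Next(i, sequence.Length);
    int tmp = sequence[r];
    sequence[r] = sequence[i];
    sequence[i] = tmp;
  }
}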

De-modularizing makes the training method long — well over one page of code, which is usually bad. But in this rare case, the de-modularized version is superior.
