Implementing the BFGS Minimization Function using C#

Last weekend, just for fun, I decided I’d take a look at the BFGS (Broyden–Fletcher–Goldfarb–Shanno) numeric minimization algorithm. The algorithm finds the minimum value of a function. To get started, I examined the C language implementation presented in the book “Numerical Recipes in C” (NR). The function is listed in section 10.7 and is called dfpmin (“DFP minimization”) because BFGS is really just a slight variation of the earlier DFP (Davidon-Fletcher-Powell) algorithm. The NR implementation calls helper function lnsrch (“line search”) in section 9.7.

So, I refactored the NR C language code to C#. It was quite interesting in the sense that I had to sharpen up my C skills to understand what was going on. The NR code is, basically, crazy.

BFGS

The BFGS algorithm (not to be confused with the even more complicated L-BFGS algorithm, the “limited memory” version) is based on calculus techniques such as function gradients (first derivatives) and the Hessian matrix of second partial derivatives. The most striking thing about BFGS is the number of ways that the function can fail. As it turns out, robust implementations of L-BFGS and BFGS exist but they have a huge amount of special-case checks in the code. My point is that the NR code for BFGS is an OK way to start understanding the algorithm, but, in addition to being copyrighted, the NR code is not suitable for anything but experimentation because the code can fail in dozens of situations.

using System;
namespace BFGSExperiments
{
  class Program
  {
    static void Main(string[] args)
    {
      // Demo driver: minimizes FunctionToMinimize starting from (1, 1)
      // with gradient tolerance 0.0001, then prints the iteration count,
      // the minimum value found and the point where it was found.
      // Any exception is reported by message only; in both the success
      // and the failure case the program waits for Enter before exiting.
      try
      {
        Console.WriteLine("\nBegin\n");

        // p is the start point on input and is overwritten by Minimize
        // with the location of the minimum.
        double[] p = new double[] { 1, 1 };
        int iter;
        double fret;

        Minimize(p, 2, 0.0001, out iter, out fret,
          FunctionToMinimize, GradientOfFunction);

        Console.WriteLine("Iterations = {0}", iter);
        Console.WriteLine("Function min value = {0}", fret);
        Console.WriteLine("Minimized at {0} {1}", p[0], p[1]);

        Console.WriteLine("\nEnd\n");
        Console.ReadLine();
      }
      catch (Exception ex)
      {
        Console.WriteLine(ex.Message);
        Console.ReadLine();
      }
    } // Main

    // Objective function: takes a point (array of coordinates) and
    // returns the function value at that point. Minimize and LineSearch
    // call it as func(p) / func(x).
    public delegate double SomeFunction(double[] fvals);
    // Gradient callback: given the point x, writes the partial
    // derivatives into the caller-supplied array grads (no return value).
    // Minimize calls it as dfunc(p, g).
    public delegate void GradFunction(double[] x,
      double[] grads);

    // -----------------------------------------------------------

    public static double FunctionToMinimize(double[] x)
    {
      // Test objective: the negated sum of two exponential bumps, one per
      // coordinate. Only x[0] and x[1] are read.
      // Minimum is at x[0] = 1, x[1] = 2, where the value is -2.
      double offset0 = x[0] - 1;
      double offset1 = x[1] - 2;

      double bump0 = Math.Exp(-Math.Pow(offset0, 2));
      double bump1 = Math.Exp(-0.5 * Math.Pow(offset1, 2));

      return -bump0 - bump1;
    }

    public static void GradientOfFunction(double[] x,
      double[] grads)
    {
      // Analytic gradient of FunctionToMinimize at x, written into
      // grads[0] and grads[1]. Both components are zero at x = (1, 2).

      // d/dx0 of -exp(-(x0 - 1)^2)
      double offset0 = x[0] - 1;
      double bump0 = Math.Exp(-Math.Pow(offset0, 2));
      grads[0] = 2 * bump0 * offset0;

      // d/dx1 of -exp(-0.5 * (x1 - 2)^2)
      double offset1 = x[1] - 2;
      double bump1 = Math.Exp(-0.5 * Math.Pow(offset1, 2));
      grads[1] = bump1 * offset1;
    }

    // -----------------------------------------------------------

    public static void Minimize(double[] p, int n,
      double gtol, out int iter, out double fret,
      SomeFunction func, GradFunction dfunc)
    {
      // aka dfpmin, aka BFGS (quasi-Newton minimization).
      // Starting from point p[] of length n, minimizes func using its
      // gradient dfunc.
      //   p     - in: starting point; out: location of the minimum found
      //   n     - number of coordinates of p that are used
      //   gtol  - convergence tolerance on the scaled gradient
      //   iter  - out: number of iterations performed
      //   fret  - out: function value returned by the last line search
      //   func  - function to minimize
      //   dfunc - writes the gradient of func at a point into an array
      // Throws Exception when ITMAX iterations pass without either
      // convergence test succeeding; LineSearch may also throw.
      // Side effects: prints progress to the console and waits for Enter
      // (Console.ReadLine) once before the loop and once per iteration.
      const int ITMAX = 200;          // iteration limit
      const double EPS = 3.0e-8;      // small threshold; guards the update below
      const double TOLX = 4.0 * EPS;  // tolerance on the relative change in p
      const double STPMX = 100.0;     // scale factor for the max line-search step

      int check, i, its, j;
      double den, fac, fad, fae, fp, stpmax, sum,
        sumdg, sumxi, temp, test;
      double[] dg, g, hdg, pnew, xi;
      double[][] hessin;

      iter = 0;
      sum = 0.0;

      dg = new double[n];     // gradient difference between iterations
      g = new double[n];      // current gradient
      hdg = new double[n];    // hessin * dg
      hessin = MakeMatrix(n, n);  // current inverse-Hessian approximation
      pnew = new double[n];   // point returned by the line search
      xi = new double[n];     // search direction, then the step actually taken
      fp = func(p);
      Console.WriteLine("starting fp = " + fp);
      dfunc(p, g);

      // Print every gradient component; for n = 2 this yields
      // "starting Grads: g[0] = ... g[1] = ...".
      string gradText = "starting Grads:";
      for (i = 0; i < n; ++i)
        gradText += " g[" + i + "] = " + g[i];
      Console.WriteLine(gradText);
      Console.ReadLine();

      // Start with the identity matrix and the steepest-descent direction.
      for (i = 0; i < n; ++i)  {
        for (j = 0; j < n; ++j) hessin[i][j] = 0.0;
        hessin[i][i] = 1.0;
        xi[i] = -g[i];
        sum += p[i] * p[i];
      }

      stpmax = STPMX * Math.Max(Math.Sqrt(sum),
        (double)n);
      for (its = 1; its <= ITMAX; ++its) // main loop
      {
        iter = its;
        LineSearch(n, p, fp, g, xi, pnew, out fret,
          stpmax, out check, func);

        fp = fret;
        Console.WriteLine("fp in loop = " + fp);
        for (i = 0; i < n; ++i)  {
          xi[i] = pnew[i] - p[i];  // step actually taken
          p[i] = pnew[i];
        }

        string pointText = "New p0 p1 =";
        for (i = 0; i < n; ++i)
          pointText += " " + p[i];
        Console.WriteLine(pointText);
        Console.ReadLine();

        // Convergence test 1: largest relative change in any coordinate.
        test = 0.0;
        for (i = 0; i < n; ++i) {
          temp = Math.Abs(xi[i]) / Math.Max(Math.Abs(p[i]), 1.0);
          if (temp > test) test = temp;
        }

        if (test < TOLX) {
          Console.WriteLine("Exiting when test = " +
            test + " < tolx = " + TOLX);
          return;
        }

        for (i = 0; i < n; ++i) dg[i] = g[i];  // save the old gradient
        dfunc(p, g);

        // Convergence test 2: largest scaled gradient component.
        test = 0.0;
        den = Math.Max(fret, 1.0);
        for (i = 0; i < n; ++i) {
          temp = Math.Abs(g[i]) * Math.Max(Math.Abs(p[i]), 1.0) / den;
          if (temp > test) test = temp;
        }
        if (test < gtol)  {
          Console.WriteLine("Exiting when test = " +
            test + " < gtol = " + gtol);
          return;
        }

        for (i = 0; i < n; ++i) dg[i] = g[i] - dg[i];
        for (i = 0; i < n; ++i) {
          hdg[i] = 0.0;
          for (j = 0; j < n; ++j)
            hdg[i] += hessin[i][j] * dg[j];
        }

        // Dot products needed for the update denominators.
        fac = fae = sumdg = sumxi = 0.0;
        for (i = 0; i < n; ++i) {
          fac += dg[i] * xi[i];
          fae += dg[i] * hdg[i];
          sumdg += dg[i] * dg[i];
          sumxi += xi[i] * xi[i];
        }

        // Skip the update when fac is not sufficiently positive.
        if (fac > Math.Sqrt(EPS * sumdg * sumxi))
        {
          fac = 1.0 / fac;
          fad = 1.0 / fae;
          for (i = 0; i < n; ++i)
            dg[i] = fac * xi[i] - fad * hdg[i];
          for (i = 0; i < n; ++i) {
            // Start at j = i: each symmetric pair is updated once and
            // mirrored. Starting at 0 would add the correction to every
            // off-diagonal entry twice.
            for (j = i; j < n; ++j) {
              hessin[i][j] += fac * xi[i] * xi[j]
                - fad * hdg[i] * hdg[j] +
                  fae * dg[i] * dg[j];
              hessin[j][i] = hessin[i][j];
            }
          }
        }

        // Next search direction: -hessin * g.
        for (i = 0; i < n; ++i) {
          xi[i] = 0.0;
          for (j = 0; j < n; ++j)
            xi[i] -= hessin[i][j] * g[j];
        }

      } // main loop
      throw new Exception("Too many iterations " + iter +
        " in Minimize");

    } // Minimize

    public static double[][] MakeMatrix(int rows, int cols)
    {
      // Allocates a rows-by-cols jagged array; each row is its own
      // freshly allocated array of cols doubles (all 0.0 by default).
      double[][] matrix = new double[rows][];
      int r = 0;
      while (r < rows)
      {
        matrix[r] = new double[cols];
        ++r;
      }
      return matrix;
    }

    public static void LineSearch(int n, double[] xold,
      double fold, double[] g, double[] p, double[] x,
      out double f, double stpmax, out int check,
       SomeFunction func)
    {
      // aka lnsrch: backtracking line search along direction p.
      //   n      - number of coordinates
      //   xold   - starting point (not modified)
      //   fold   - function value at xold
      //   g      - gradient at xold
      //   p      - search direction; rescaled in place if longer than stpmax
      //   x      - out: new point xold + alam * p (or xold when check = 1)
      //   f      - out: function value at the returned x
      //   stpmax - maximum allowed length of the step p
      //   check  - out: 0 when a step with sufficient decrease was found,
      //            1 when the step shrank below alamin and x was reset to xold
      //   func   - function being minimized
      // Throws Exception when p is not a descent direction (slope >= 0)
      // or when the loop runs maxSanity times without finishing.

      const double ALF = 1.0e-4;   // required fraction of the predicted decrease
      const double TOLX = 1.0e-7;  // tolerance used to derive the minimum step
      const int maxSanity = 1000000;

      int i;
      double a, alam, alam2, alamin, b, disc, f2, rhs1,
      rhs2, slope, sum, temp, test, tmplam;

      check = 0;

      // Scale p down if its length exceeds stpmax.
      sum = 0.0;
      for (i = 0; i < n; ++i)
        sum += p[i] * p[i];
      sum = Math.Sqrt(sum);
      if (sum > stpmax) {
        for (i = 0; i < n; ++i)
          p[i] *= stpmax / sum;
      }

      // Directional derivative of func along p at xold.
      slope = 0.0;
      for (i = 0; i < n; ++i)
        slope += g[i] * p[i];
      if (slope >= 0.0)
        throw new Exception("Roundoff problem LineSearch");

      // Smallest step multiplier worth trying.
      test = 0.0;
      for (i = 0; i < n; ++i) {
        temp = Math.Abs(p[i]) / Math.Max(Math.Abs(xold[i]), 1.0);
        if (temp > test) test = temp;
      }
      alamin = TOLX / test;
      alam = 1.0;  // always try the full step first

      int sanityCt = 0;
      f = 0.0; // keeps compiler happy
      f2 = 0.0;
      alam2 = 0.0;
      while (sanityCt < maxSanity)
      {
        ++sanityCt;

        for (i = 0; i < n; ++i)
          x[i] = xold[i] + alam * p[i];
        f = func(x);
        if (alam < alamin) {
          // Step too small: give up and hand back the starting point.
          for (i = 0; i < n; ++i)
            x[i] = xold[i];
          f = fold;  // keep f consistent with the returned x
          check = 1;
          return;
        } else if (f <= fold + ALF * alam * slope)
            return;  // sufficient decrease
        else  {
          if (alam == 1.0)
            // First backtrack: minimizer of a quadratic model.
            tmplam = -slope / (2.0 * (f - fold - slope));
          else  {
            // Later backtracks: cubic model through the last two trials.
            rhs1 = f - fold - alam * slope;
            rhs2 = f2 - fold - alam2 * slope;
            a = (rhs1 / (alam * alam) - rhs2 /
              (alam2 * alam2)) / (alam - alam2);
            b = (-alam2 * rhs1 / (alam * alam) + alam *
              rhs2 / (alam2 * alam2)) /
              (alam - alam2);
            if (a == 0.0) tmplam = -slope / (2.0 * b);
            else  {
              disc = b * b - 3.0 * a * slope;
              if (disc < 0.0) tmplam = 0.5 * alam;
              else if (b <= 0.0)
                tmplam = (-b + Math.Sqrt(disc)) / (3.0 * a);
              else
                tmplam = -slope / (b + Math.Sqrt(disc));
            }
            if (tmplam > 0.5 * alam)
              tmplam = 0.5 * alam;  // never more than half the last step
          } // else

        } // else
        alam2 = alam;
        f2 = f;
        alam = Math.Max(tmplam, 0.1 * alam);  // never less than a tenth
      } // while

      // The loop only falls through when the iteration cap is reached.
      throw new Exception("Insane in LineSearch");
    } // LineSearch
  } // Program
} // ns
Advertisements
This entry was posted in Machine Learning. Bookmark the permalink.