Simplex Optimization using Python

Simplex optimization is a technique to find the minimum value of some function. In most situations the goal is to find values that minimize some sort of error.

In simplex optimization, you have three virtual points, where each point represents a possible solution. Each point has an associated error, so at any given time there’s a best point, a worst point, and an “other” point. The three points form a triangle (in fancy math terms a simplex is the generalization of a triangle to any number of dimensions, hence the name of the technique).


In very high-level pseudo-code, simplex optimization resembles:

loop
  compute a centroid from the other and best
  create expanded, reflected,
    and contracted points
  if any are better than worst, replace worst
  otherwise shrink worst, other towards best
end loop

Geometrically, the centroid is a point midway between the best point and the other point. The expanded and reflected points search outside the current triangle. The contracted point searches inside the triangle.
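To make the geometry concrete, here is a minimal sketch that computes the centroid and the three candidate points for a made-up triangle in two dimensions. The point coordinates and the helper name `candidate` are assumptions for illustration only; the coefficients 2.0, 1.0 and -0.5 are the ones used in the demo program in this post.

```python
# Hypothetical 2-D example: three points forming a triangle.
best = [1.0, 1.0]
other = [3.0, 1.0]
worst = [2.0, 4.0]

def candidate(centroid, worst, alpha):
    # Start at the centroid and move along the direction that points
    # away from the worst point. A positive alpha lands outside the
    # triangle; a negative alpha lands inside it.
    return [c + alpha * (c - w) for c, w in zip(centroid, worst)]

# The centroid is midway between the best point and the other point.
centroid = [(b + o) / 2.0 for b, o in zip(best, other)]

reflected = candidate(centroid, worst, 1.0)
expanded = candidate(centroid, worst, 2.0)
contracted = candidate(centroid, worst, -0.5)

print(centroid)    # [2.0, 1.0]
print(reflected)   # [2.0, -2.0]
print(expanded)    # [2.0, -5.0]
print(contracted)  # [2.0, 2.5]
```

The worst point sits above the centroid, so the reflected and expanded points land below the triangle, while the contracted point lands halfway between the centroid and the worst point.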

If you graph the positions of the three points in each iteration, you get what looks like a triangle creeping along until it surrounds the target location, and then the triangle shrinks around the target. It sort of resembles the movement of a single-celled amoeba, so simplex optimization is also called amoeba method optimization. One specific variation is called the Nelder-Mead algorithm.
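The shrinking behavior can be sketched in a few lines. This is an illustration only, with made-up 2-D coordinates and hypothetical helper names (`shrink_toward`, `triangle_area`); it moves the worst and other points halfway toward the best point, as in the last step of the pseudo-code, and prints the triangle's area before and after.

```python
def shrink_toward(point, best):
    # Move a point halfway toward the best point.
    return [(p + b) / 2.0 for p, b in zip(point, best)]

def triangle_area(a, b, c):
    # Shoelace formula for the area of a 2-D triangle.
    cross = (b[0] - a[0]) * (c[1] - a[1]) - (c[0] - a[0]) * (b[1] - a[1])
    return abs(cross) / 2.0

best, other, worst = [1.0, 1.0], [3.0, 1.0], [2.0, 4.0]
print(triangle_area(best, other, worst))  # 3.0

other = shrink_toward(other, best)
worst = shrink_toward(worst, best)
print(other, worst)                       # [2.0, 1.0] [1.5, 2.5]
print(triangle_area(best, other, worst))  # 0.75
```

Because both edges leaving the best point are halved, one shrink step cuts the area of the triangle to a quarter while the best point stays where it is.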

I coded up an implementation of simplex optimization to solve f(x) = x0^2 + x1^2 + x2^2 + x3^2 + x4^2 = 0 in Python:

# simplex.py
# python 3.4.3
# demo of simplex optimization
# aka amoeba method optimization
# solves x0^2 + x1^2 + x2^2 + . . . = 0
# (the 'Sphere' function)

import random
import math    # sqrt

# ------------------------------------

def show_vector(vector):
  for i in range(len(vector)):
    if i % 8 == 0: # 8 columns
      print("\n", end="")
    if vector[i] >= 0.0:
      print(' ', end="")
    print("%.4f" % vector[i], end="") # 4 decimals
    print(" ", end="")
  print("\n")

# ------------------------------------

def error(position):
  # Euclidean distance to (0, 0, .. 0)
  # (the square root of the Sphere function, so the
  # minimum is at the same position)
  dim = len(position)
  target = [0.0 for i in range(dim)]
  dist = 0.0
  for i in range(dim):
    dist += (position[i] - target[i])**2
  return math.sqrt(dist)

# ------------------------------------

class Point:
  def __init__(self, dim, minx, maxx):
    self.position = [0.0 for i in range(dim)]

    for i in range(dim):
      self.position[i] = ((maxx - minx) *
        random.random() + minx)

    self.error = error(self.position) # curr error

# ------------------------------------

def Solve(dim, max_epochs, minx, maxx):
  points = [Point(dim, minx, maxx) for i in range(3)] # 3 points

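  # pin two of the points to the corners of the search range,
  # then recompute their errors so they match the new positions
  for i in range(dim): points[0].position[i] = minx
  for i in range(dim): points[2].position[i] = maxx
  points[0].error = error(points[0].position)
  points[2].error = error(points[2].position)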

  best_idx = -1
  other_idx = -1
  worst_idx = -1

  centroid = [0.0 for i in range(dim)]
  expanded = [0.0 for i in range(dim)]
  reflected = [0.0 for i in range(dim)]
  contracted = [0.0 for i in range(dim)]
  arbitrary = [0.0 for i in range(dim)]

  epoch = 0
  while epoch < max_epochs:
    epoch += 1
        
    # identify best, other, worst
    if (points[0].error < points[1].error and
    points[0].error < points[2].error):
      if points[1].error < points[2].error:
        best_idx = 0; other_idx = 1; worst_idx = 2
      else:
        best_idx = 0; other_idx = 2; worst_idx = 1
    elif (points[1].error < points[0].error and
    points[1].error < points[2].error):
      if points[0].error < points[2].error:
        best_idx = 1; other_idx = 0; worst_idx = 2
      else:
        best_idx = 1; other_idx = 2; worst_idx = 0
    else:
      if points[0].error < points[1].error:
        best_idx = 2; other_idx = 0; worst_idx = 1
      else:
        best_idx = 2; other_idx = 1; worst_idx = 0

    if epoch <= 9 or epoch >= 30:
      print("--------------------")
      print("epoch = " + str(epoch) + " ", end="")
      print("best error = ", end="")
      print("%.6f" % points[best_idx].error, end="")

    if epoch == 10:
      print("--------------------")
      print(" . . . ")

    if points[best_idx].error < 1.0e-4:
      if epoch <= 9 or epoch >= 30:
        print(" reached small error. halting")
      break

    # make the centroid
    for i in range(dim):
      centroid[i] = (points[other_idx].position[i] +
      points[best_idx].position[i]) / 2.0

    # try the expanded point
    for i in range(dim):
      expanded[i] = centroid[i] + (2.0 * (centroid[i] -
      points[worst_idx].position[i]))
    expanded_err = error(expanded)
    if expanded_err < points[worst_idx].error:
      if epoch <= 9 or epoch >= 30:
        print(" expanded found better error than worst error")
      for i in range(dim): 
        points[worst_idx].position[i] = expanded[i]
      points[worst_idx].error = expanded_err
      continue

    # try the reflected point
    for i in range(dim):
      reflected[i] = centroid[i] + (1.0 * (centroid[i] -
      points[worst_idx].position[i]))
    reflected_err = error(reflected)
    if reflected_err < points[worst_idx].error:
      if epoch <= 9 or epoch >= 30:
        print(" reflected found better error than worst error")
      for i in range(dim):
        points[worst_idx].position[i] = reflected[i]
      points[worst_idx].error = reflected_err
      continue

    # try the contracted point
    for i in range(dim):
      contracted[i] = centroid[i] + (-0.5 * (centroid[i] -
      points[worst_idx].position[i]))
    contracted_err = error(contracted)
    if contracted_err < points[worst_idx].error:
      if epoch <= 9 or epoch >= 30:
        print(" contracted found better error than worst error")
      for i in range(dim):
        points[worst_idx].position[i] = contracted[i]
      points[worst_idx].error = contracted_err
      continue

    # try a random point
    for i in range(dim):
      arbitrary[i] = ((maxx - minx) * random.random() + minx)
    arbitrary_err = error(arbitrary)
    if arbitrary_err < points[worst_idx].error:
      if epoch <= 9 or epoch >= 30:
        print(" arbitrary found better error than worst error")
      for i in range(dim):
        points[worst_idx].position[i] = arbitrary[i]
      points[worst_idx].error = arbitrary_err
      continue

    # could not find better point so shrink worst and other
    if epoch <= 9 or epoch >= 30:
      print(" shrinking")
    # 1. worst -> best
    for i in range(dim):
      points[worst_idx].position[i] = (points[worst_idx].position[i]
      + points[best_idx].position[i]) / 2.0
    points[worst_idx].error = error(points[worst_idx].position)

    # 2. other -> best
    for i in range(dim):
      points[other_idx].position[i] = (points[other_idx].position[i]
      + points[best_idx].position[i]) / 2.0
    points[other_idx].error = error(points[other_idx].position)

  # end-while

  print("--------------------")
  print("\nBest position found=")
  show_vector(points[best_idx].position)

# ------------------------------------

print("\nBegin simplex optimization using Python demo\n")
dim = 5
random.seed(0)

print("Goal is to solve the Sphere function in " +
 str(dim) + " variables")
print("Function has known min = 0.0 at (", end="")
for i in range(dim-1):
  print("0, ", end="")
print("0)")

max_epochs = 1000

print("Setting max_epochs    = " + str(max_epochs))
print("\nStarting simplex algorithm\n")

Solve(dim, max_epochs, -10.0, 10.0)

print("\nSimplex algorithm complete")

print("\nEnd simplex optimization demo\n")

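There are a lot of small but important details in the code. Simplex optimization is one of the oldest optimization techniques that doesn’t need derivatives (the Nelder-Mead paper dates from 1965). It’s sometimes grouped with swarm optimization techniques because it works with a set of candidate points, and it’s not used too often.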
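One detail that could be written more compactly is the if/elif ladder that identifies the best, other, and worst points. The sketch below is an alternative, not what the demo program does: it sorts the point indices by error. The error values are made up for illustration.

```python
errors = [4.2, 0.7, 2.9]  # made-up errors for points 0, 1, 2

# Sort the indices 0..2 by their error, smallest error first.
order = sorted(range(len(errors)), key=lambda i: errors[i])
best_idx, other_idx, worst_idx = order

print(best_idx, other_idx, worst_idx)  # 1 2 0
```

When two errors are exactly equal, `sorted` keeps the lower index first, which may break ties differently from the if/elif ladder in the demo.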
