Training a PyTorch Neural Network Using an Evolutionary Algorithm

I’ve been interested in evolutionary algorithms for a long time. In the back of my mind, I had the idea of experimenting with using an evolutionary algorithm to train a PyTorch neural network. So I did.

I got a demo up and running, but a.) the results weren’t very good, and b.) the demo program was very, very complicated, and so c.) there are many complex details that need further investigation.

I was mostly motivated by curiosity. But in pragmatic terms, being able to train a neural network using an evolutionary algorithm could enable new kinds of neural architectures that can’t be trained using gradient-based techniques.

I used one of my standard multi-class classification problems. The goal is to predict a person’s political leaning (conservative = 0, moderate = 1, liberal = 2) from sex (male = -1, female = +1), age (divided by 100), state (Michigan = 1 0 0, Nebraska = 0 1 0, Oklahoma = 0 0 1), and income (divided by $100,000). The data looks like:

 1   0.24   1   0   0   0.2950   2
-1   0.39   0   0   1   0.5120   1
 1   0.63   0   1   0   0.7580   0
-1   0.36   1   0   0   0.4450   1
. . .
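
For example, the first line above represents a 24-year-old female from Michigan who makes $29,500 and is a political liberal. As a quick illustration (not part of the demo program), that line maps to a predictor tensor and a class label like so:

import torch as T
x = T.tensor([[1, 0.24, 1, 0, 0, 0.2950]],
  dtype=T.float32)  # female, age 24, Michigan, $29,500
y = T.tensor([2], dtype=T.int64)  # politics = liberal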

In very high-level pseudo-code, my program is:

create initial population of solutions / neural weights
sort population from best to worst
loop many times
  pick two parents from population
  create a child
  mutate the child
  replace a weak solution in population with child
end-loop
return best solution / neural weights found

Each part of the pseudo-code has many alternative implementations. For example, how are two parent solutions selected? And the engineering details are tricky too. For example, how are solutions in the population sorted from best to worst?
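
In my demo, I sort the population by error and then select one parent from the better half and one parent from the weaker half. A minimal sketch of those two pieces, where pop is a list of (weights, error) tuples and rnd is a NumPy RandomState object:

pop = sorted(pop, key=lambda tup: tup[1])      # smallest error first
first = rnd.randint(0, pop_size // 2)          # index of a good solution
second = rnd.randint(pop_size // 2, pop_size)  # index of a weak solution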

As I developed my demo program, I was struck by the fact that I had to use many tricks and techniques that I’ve learned over the past six years that PyTorch has been available. For example, one of the helper functions is:

def compute_error(model, train_ds):
  X = train_ds.get_x_data()  # all predictors
  y = train_ds.get_y_data()  # all target labels
  with T.no_grad():
    oupt = model(X)  # all outputs, in log-softmax form
  return -T.mean(T.diag(oupt[:,y]))  # mean negative log-likelihood

There are only a few lines of code here but explaining the function thoroughly would require a full page of text.
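
Briefly: oupt[:,y] builds a matrix where cell [i,j] holds the log-probability that data item i has target label y[j], so the diagonal holds each item’s log-probability for its own target, and the negated mean is the average negative log-likelihood over the dataset. If I haven’t made a mistake, the return value is the same as applying the built-in NLLLoss() to the log-softmax outputs:

loss_func = T.nn.NLLLoss()  # expects log-probabilities
err = loss_func(oupt, y)    # should equal -T.mean(T.diag(oupt[:,y]))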

Anyway, training a PyTorch neural network using an evolutionary algorithm was a fun challenge.



Science fiction stories featuring evolution sometimes imagine that people will grow big brains. Left: In “The Sixth Finger” (1963) episode of the TV series “The Outer Limits”, a scientist discovers a way to accelerate evolution and uses it on an uneducated coal miner. Center: In “This Island Earth” (1955), aliens from the planet Metaluna kidnap Earth scientists to help them in a war against the planet Zagon. Right: In “The Menagerie” (1966) episode of the TV series “Star Trek”, the people of planet Talos IV have big brains and mind control.


Demo code. Under development — almost certainly has many bugs so don’t use as-is.

# people_evo_train.py
# predict politics type from sex, age, state, income
# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   2
  # +1   0.19   0  0  1   0.6550   0
  # sex: -1 = male, +1 = female
  # state: michigan (100), nebraska(010), oklahoma (001)
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter="\t", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6]     # 1-D

    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

  def get_x_data(self):
    return self.x_data

  def get_y_data(self):
    return self.y_data

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss() 
    return z

# -----------------------------------------------------------

def compute_error(model, train_ds):
  X = train_ds.get_x_data()  # all predictors
  y = train_ds.get_y_data()  # all target labels
  with T.no_grad():
    oupt = model(X)  # all outputs, in log-softmax form
  return -T.mean(T.diag(oupt[:,y]))  # mean negative log-likelihood

# -----------------------------------------------------------

def train_evo(model, pop_size, dim, max_gen, init_lim, 
  mutate_prob, mutate_lim, train_ds):
  # 0. get ready
  rnd = np.random.RandomState(1)
  interval = max_gen // 10
  
  # 1. init pop: Tuple(weights, associated errors)
  pop = []  # list of tuples, tuple is (np arr, float)
  for i in range(pop_size):  # each soln / set of wts
    wts = T.tensor(rnd.uniform(low=-init_lim, high=init_lim,
      size=(dim)), dtype=T.float32)
    load_weights(model, wts)
    err = compute_error(model, train_ds)
    pop.append((wts, err))
  pop = sorted(pop, key=lambda tup: tup[1])  # sort by error
 
  # 2. find best set of wts
  best_wts = pop[0][0].clone()
  best_err = pop[0][1]

  # 3. evolve
  for gen in range(max_gen):

    # 3a. pick two parents and make a child
    first = rnd.randint(0, pop_size // 2)  # good one
    second = rnd.randint(pop_size // 2, pop_size)  # weak one
    flip = rnd.randint(2)  # 0 or 1
    if flip == 0:
      parent_idxs = (first, second)
    else:
      parent_idxs = (second, first)
  
    # 3b. create child
    child_wts = T.zeros(dim)
    i = parent_idxs[0]; j = parent_idxs[1]
    parent1 = pop[i][0]
    parent2 = pop[j][0]
    for k in range(0, dim // 2):  # left half
      child_wts[k] = parent1[k]
    for k in range(dim // 2, dim):  # right half
      child_wts[k] = parent2[k]

    # 3c. mutate child
    lo = -mutate_lim; hi = mutate_lim
    for k in range(dim):
      q = rnd.random()  # in [0.0, 1.0)
      if q < mutate_prob:  # mutate this weight
        child_wts[k] += (hi - lo) * rnd.random() + lo
    load_weights(model, child_wts)
    child_err = compute_error(model, train_ds)

    # 3d. is child new best wts?
    if child_err < best_err:
      # print("New best soln found at gen " + str(gen))
      best_wts = child_wts.clone()
      best_err = child_err
    else:
      # print("No improvement at gen " + str(gen))
      pass

    # 3e. replace weak pop wts with child
    idx = rnd.randint(pop_size // 2, pop_size)
    pop[idx] = (child_wts, child_err)  # Tuple

    # 3f. sort solns from best to worst
    pop = sorted(pop, key=lambda tup: tup[1]) 

    # 3g. show progress
    if gen % interval == 0:
      acc = accuracy_quick(model, train_ds)  # acc of current child
      print("gen = %5d  |  err = %10.4f  |  acc = %8.4f " % \
        (gen, best_err, acc))

  # 4. load best weights found back into the model
  load_weights(model, best_wts)
  return best_wts
 
# -----------------------------------------------------------

def load_weights(model, wts):
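  # wts layout: hid1.weight 60, hid1.bias 10, hid2.weight 100,
  # hid2.bias 10, oupt.weight 30, oupt.bias 3 -- 213 total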
  if len(wts) != (10*6) + 10 + (10*10) + 10 + (3*10) + 3:
    print("FATAL: incorrect number wts in load_weights() ")

  model.hid1.weight.data = wts[0:60].reshape((10,6))
  model.hid1.bias.data = wts[60:70]
  model.hid2.weight.data = wts[70:170].reshape((10,10))
  model.hid2.bias.data = wts[170:180]
  model.oupt.weight.data = wts[180:210].reshape((3,10))
  model.oupt.bias.data = wts[210:213]

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # log-softmax outputs

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def accuracy_quick(model, dataset):
  # assumes model.eval()
  X = dataset[0:len(dataset)][0]  # all predictors
  Y = dataset[0:len(dataset)][1]  # all target labels
  with T.no_grad():
    oupt = model(X)  # log-softmax outputs
  arg_maxs = T.argmax(oupt, dim=1)  # predicted class indices
  num_correct = T.sum(Y == arg_maxs)
  acc = (num_correct * 1.0 / len(dataset))
  return acc.item()

# -----------------------------------------------------------

def do_acc(model, dataset, n_classes):
  X = dataset[0:len(dataset)][0]  # all X values
  Y = dataset[0:len(dataset)][1]  # all Y values
  with T.no_grad():
    oupt = model(X)  # all log-softmax outputs

  for c in range(n_classes):
    idxs = np.where(Y==c)  # indices where Y is c
    logits_c = oupt[idxs]  # logits corresponding to Y == c
    arg_maxs_c = T.argmax(logits_c, dim=1)  # predicted class
    num_correct = T.sum(arg_maxs_c == c)
    acc_c = num_correct.item() / len(arg_maxs_c)
    print("%0.4f " % acc_c)

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nPeople predict politics evolutionary training ")
  T.manual_seed(1)
  np.random.seed(1)
  
  # 1. create DataLoader objects
  print("\nCreating People Datasets ")

  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)    # 40 rows

# -----------------------------------------------------------

  # 2. create network
  print("\nCreating 6-(10-10)-3 neural network ")
  net = Net().to(device)
  net.train()

# -----------------------------------------------------------

  # 3. train model
  pop_size = 6
  dim = (10*6) + 10 + (10*10) + 10 + (3*10) + 3  # 213
  max_gen = 10_000
  init_lim = 5.0
  mutate_prob = 0.50
  mutate_lim = 0.20

  print("\nSetting pop_size = %d " % pop_size)
  print("Setting init_lim = %0.1f " % init_lim)
  print("Setting imutate_prob = %0.2f " % mutate_prob)
  print("Setting mutate_lim = %0.2f " % mutate_lim)
  print("\nStarting evolutionary training ")
  train_evo(net, pop_size, dim, max_gen, init_lim, 
    mutate_prob, mutate_lim, train_ds)
  print("Done ")
 
# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy_quick(net, train_ds) 
  print("Accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy_quick(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  # print("\nAccuracy on test by class (fast set technique): ")
  # do_acc(net, test_ds, 3)

  # 5. make a prediction
  print("\nPredicting politics for M  30  oklahoma  $50,000: ")
  X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device) 

  with T.no_grad():
    logits = net(X)  # log-softmax values, do not sum to 1.0
  probs = T.exp(logits)  # sum to 1.0
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True)
  print(probs)

  # 6. save model (state_dict approach)
  print("\nSaving trained model state")
  # fn = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), fn)

  # saved_model = Net()  # requires class definition
  # saved_model.load_state_dict(T.load(fn))
  # use saved_model to make prediction(s)

  print("\nEnd People predict politics demo")

if __name__ == "__main__":
  main()