A Neural Network Mini-Batcher

When training a neural network, you can update weights after reading each training data item (usually called online training) or you can update weights after reading all data items (usually called batch training). Both types of training have their pros and cons.

An intermediate approach is to use mini-batch training. For example, if you have 100 training items, you could process 25 at a time.
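In code, a mini-batch training loop updates the weights once per batch of items. Here is a minimal sketch of what such a loop might look like; the names train_x, train_y, and update_weights are hypothetical placeholders rather than part of any particular library.

import numpy as np

# hypothetical sketch of a mini-batch training loop;
# train_x, train_y, update_weights are placeholder names
def train(train_x, train_y, update_weights, batch_size=25, max_epochs=10):
  n = len(train_x)
  for epoch in range(max_epochs):
    indices = np.random.permutation(n)  # scramble the item order
    for start in range(0, n - batch_size + 1, batch_size):
      rows = indices[start : start + batch_size]
      update_weights(train_x[rows], train_y[rows])  # one update per mini-batch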

I wrote a little mini-batcher class in Python. The idea is best explained with a concrete example. Suppose there are just 10 training items, indexed [0] through [9]. If the batch size is set to 3, then the first batch might be the training data items at indices (6, 2, 9). You want to scramble the order of the indices so you don't get stuck in a fixed pattern.

The second batch might be items (0, 7, 4) and the third batch might be items (5, 1, 3). At this point, there aren't enough unused items to make a batch of three items (only item 8 hasn't been used yet), so you'd reshuffle the indices and start pulling new batches of three indices.
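Here is that bookkeeping as a small standalone sketch. The exact scrambled order depends on the random seed, so the specific indices will differ from the example above.

import random

indices = list(range(10))   # [0, 1, 2, ..., 9]
rnd = random.Random(0)
rnd.shuffle(indices)        # scramble the order

b1 = indices[0:3]   # first batch of 3 indices
b2 = indices[3:6]   # second batch
b3 = indices[6:9]   # third batch
# only indices[9] is left, which isn't enough for a full batch
# of 3, so reshuffle and start pulling batches from the top again
rnd.shuffle(indices)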

My class is defined as:

import numpy as np
import random

class Batcher:
  def __init__(self, total_size, batch_size, seed=0):
    self.indices = np.arange(total_size)  # [0, 1, .., total_size-1]
    self.tot_size = total_size
    self.bat_size = batch_size
    self.rnd = random.Random(seed)
    self.rnd.shuffle(self.indices)  # scramble the index order
    self.curr = 0  # position of the next unused index

  def next_batch(self):
    # should never happen because of the reset logic below
    if self.curr + self.bat_size > self.tot_size:
      print("Fatal logic error in next_batch()")

    # copy the next bat_size indices into the result array
    result = np.zeros(shape=[self.bat_size], dtype=np.int64)
    for i in range(0, self.bat_size):
      result[i] = self.indices[i + self.curr]

    self.curr += self.bat_size
    # if there aren't enough unused indices left for another
    # full batch, reshuffle and start over
    if self.curr + self.bat_size > self.tot_size:
      self.rnd.shuffle(self.indices)  # reset
      self.curr = 0

    return result

The code is a bit more complicated than you might guess, but the core idea is simple: there's an array of scrambled indices, and each call to method next_batch() returns the next batch-size slice of those indices.
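In an actual training scenario, the returned index array would typically be used to pull the corresponding rows of the training data via NumPy fancy indexing. Here is a minimal sketch using made-up dummy data; the data_x and data_y names are just for illustration.

import numpy as np

data_x = np.random.randn(10, 4)   # 10 dummy items, 4 features each
data_y = np.random.randn(10, 1)   # 10 dummy target values

batcher = Batcher(total_size=10, batch_size=3)
for epoch in range(2):
  for _ in range(3):               # three full batches fit into 10 items
    rows = batcher.next_batch()    # e.g. array([6, 2, 9])
    x_batch = data_x[rows]         # fancy indexing pulls the selected rows
    y_batch = data_y[rows]
    # ... compute gradients and update weights here ...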

The demo calling code is:

def main():
  print("\nBegin Batcher demo \n")

  t_size = 10
  b_size = 3
  print("Setting total num items = " + str(t_size))
  print("Setting batch size = " + str(b_size))

  batcher = Batcher(t_size, b_size)
  for i in range(0, 12):
    if i % 3 == 0:
      print("")
    batch = batcher.next_batch()
    print(batch)

  print("\nEnd demo \n")

if __name__ == "__main__":
  main()

Helper code like this isn’t particularly interesting, but when you code machine learning systems you have to deal with a certain amount of plumbing code.
