Serving Up PyTorch Training Data Using The DataLoader collate_fn Parameter

When creating a deep neural network, writing code to prepare the training data and serve it up in batches to the network is almost always difficult and time-consuming. A regular PyTorch DataLoader works great for tabular-style data where all inputs have the same length. For variable-length inputs, one approach is to use a DataLoader with a custom collate_fn function.

I was looking at an example of natural language processing that uses an EmbeddingBag layer. Briefly, an EmbeddingBag treats a sentence as one whole unit and converts the sentence into a single vector of numeric values (by default, the mean of the embedding vectors of its words). This is different from a regular Embedding layer, which converts each word in a sentence to its own numeric vector.
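To make the difference concrete, here is a minimal sketch that assumes a made-up vocabulary size of 25 and an embedding dimension of 4 (both arbitrary). It feeds the indexes of the first review to a regular Embedding layer and to an EmbeddingBag layer that shares the same weights, using the EmbeddingBag default "mean" mode.

```python
import torch as T

T.manual_seed(1)
vocab_size, embed_dim = 25, 4  # arbitrary sizes, for illustration only

emb = T.nn.Embedding(vocab_size, embed_dim)
bag = T.nn.EmbeddingBag(vocab_size, embed_dim, mode="mean")
with T.no_grad():
  bag.weight.copy_(emb.weight)  # give both layers identical weights

review = T.tensor([4, 7, 3, 13, 6, 2], dtype=T.int64)  # this was a bad movie .
offsets = T.tensor([0], dtype=T.int64)  # one review, starting at [0]

per_word = emb(review)             # one vector per word: shape [6, 4]
per_review = bag(review, offsets)  # one vector per review: shape [1, 4]

# with mode="mean", the bag vector is the average of the word vectors
same = T.allclose(per_review[0], per_word.mean(dim=0))
print(per_word.shape, per_review.shape, same)
```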

To handle the training data I needed to use a custom DataLoader. A regular DataLoader accepts a PyTorch Dataset object, which must be implemented to fetch one item at a time. The custom version passes a plain list of tuples to the DataLoader and uses a program-defined collate_fn function to convert each batch of tuples into tensors.
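A Python list can stand in for a Dataset here because the DataLoader only needs len() and integer indexing. The short sketch below uses three of the reviews and a hypothetical pass-through function named show_batch to show what a collate_fn actually receives: a plain list of up to batch_size tuples, in file order when shuffle=False.

```python
import torch as T

data_lst = [("0", "This was a BAD movie."),
            ("1", "I liked this film! Highly recommended."),
            ("0", "Just awful")]

def show_batch(batch):
  # hypothetical collate_fn that returns the raw batch unchanged
  return batch

ldr = T.utils.data.DataLoader(data_lst, batch_size=2,
  shuffle=False, collate_fn=show_batch)

batches = list(ldr)
for b in batches:
  print(b)  # a list of (label, review) tuples
```

With the default collate_fn, PyTorch would regroup the tuples by position (all labels together, all reviews together) but would not tokenize the text or compute offsets, which is why a program-defined function is needed.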

The idea is quite tricky and is best explained by example. I started with 8 movie reviews:

0, This was a BAD movie.
1, I liked this film! Highly recommended.
0, Just awful
1, Good film, acting
0, Don't waste your time - A real dud
0, Terrible
1, Great movie.
0, This was a waste of talent.

A label of 0 means a negative review and 1 means a positive review. The first step is to create a Vocabulary object that maps each possible word and punctuation character to a numeric index. That’s a difficult problem in itself.
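The core idea of a Vocabulary can be sketched in plain Python. This is a minimal sketch, not the torchtext implementation: simple_tokenize is a hypothetical stand-in for the torchtext "basic_english" tokenizer (lower-case, with each punctuation character as its own token), and build_vocab reserves index 0 for <unk> and 1 for <pad>, then assigns the remaining indexes by descending frequency with ties broken alphabetically, similar to the ordering the legacy torchtext Vocab class uses.

```python
import collections
import re

def simple_tokenize(text):
  # hypothetical tokenizer: lower-case, letters and digits stay together,
  # every other non-space character becomes its own token
  return re.findall(r"[a-z0-9]+|[^\sa-z0-9]", text.lower())

def build_vocab(texts, specials=("<unk>", "<pad>")):
  counter = collections.Counter()
  for txt in texts:
    counter.update(simple_tokenize(txt))
  # most frequent first, ties broken alphabetically
  ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
  itos = list(specials) + [tok for (tok, _) in ordered]
  return {tok: i for (i, tok) in enumerate(itos)}

def to_indexes(txt, stoi):
  # unknown tokens map to index 0, the <unk> entry
  return [stoi.get(tok, 0) for tok in simple_tokenize(txt)]

stoi = build_vocab(["This was a BAD movie.", "Great movie.",
                    "This was a waste of talent."])
print(to_indexes("This was a terrible movie.", stoi))
```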

The next step is to read the data into memory as a list of tuples. The resulting list would be:

[(0, This was a BAD movie.),
 (1, I liked this film! Highly recommended.),
 . .
 (0, This was a waste of talent.)]
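One detail worth handling when reading the file: a review can itself contain a comma ("Good film, acting"), so splitting each line on every comma cuts that review short. Below is a minimal sketch that splits each line only at the first comma; the function name parse_lines is hypothetical, and it reads from a list of strings rather than a file so it is easy to try.

```python
def parse_lines(lines):
  # build a list of (label, review) tuples, splitting only at the first comma
  result = []
  for line in lines:
    line = line.strip()
    if line == "":
      continue  # skip blank lines
    (label, review) = line.split(",", 1)
    result.append((label, review.strip()))
  return result

rows = parse_lines(["0, This was a BAD movie.", "1, Good film, acting", ""])
print(rows)
```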

Next, the goal is to write code that batches up batch_size=3 reviews at a time and arranges them as a vector of labels, a vector of reviews converted to indexes, and a vector of offsets that indicate where each review starts. For example, the first batch would be:

labels : tensor([0, 1, 0])
reviews: tensor([ 4,  7,  3, 13,  6,  2,
                 19, 21,  4,  5,  9, 18, 24,  2,
                 20, 12])
offsets: tensor([ 0,  6, 14])

The first review is this=4, was=7, a=3, bad=13, movie=6, (period)=2. The three movie reviews are combined into a single tensor. The first review starts at [0], the second review starts at [6], and the third review starts at [14], because the first two reviews have 6 and 8 tokens and 6 + 8 = 14.
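The offsets are just the running total of the review lengths, with the last length dropped. A small plain-Python sketch, using the three index lists from the first batch above, rebuilds the reviews and offsets values; itertools.accumulate plays the role of the cumsum() call used in the demo code.

```python
import itertools

# index lists for the first three reviews, copied from the batch above
batch_idxs = [[4, 7, 3, 13, 6, 2],
              [19, 21, 4, 5, 9, 18, 24, 2],
              [20, 12]]

lengths = [len(idxs) for idxs in batch_idxs]               # [6, 8, 2]
offsets = list(itertools.accumulate([0] + lengths[:-1]))   # start of each review
reviews = [ix for idxs in batch_idxs for ix in idxs]       # one flat list

print(offsets)  # [0, 6, 14]
```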

The code to convert the list of tuples into the desired format must be implemented in a collate_fn function that is passed to the DataLoader. Writing the collate_fn function was very tricky and difficult, and took me several hours.
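Here is a compact, self-contained sketch of such a collate_fn. The tiny vocabulary dict, the tokenize function, and the three-item data list are hypothetical stand-ins for the real Vocab and tokenizer. Unlike the demo program, this version binds the vocabulary and tokenizer with functools.partial instead of reading globals, which is just one alternative design.

```python
import functools
import torch as T

def collate_reviews(batch, vocab, tokenize):
  # batch: list of (label_str, review_str) tuples supplied by the DataLoader
  labels, reviews, lengths = [], [], []
  for (lbl, rvw) in batch:
    labels.append(int(lbl))
    idxs = [vocab.get(tok, 0) for tok in tokenize(rvw)]  # 0 = <unk>
    reviews.append(T.tensor(idxs, dtype=T.int64))
    lengths.append(len(idxs))
  # each review starts where the previous ones end
  offsets = T.tensor([0] + lengths[:-1], dtype=T.int64).cumsum(dim=0)
  return (T.tensor(labels, dtype=T.int64), T.cat(reviews), offsets)

def tokenize(txt):
  # hypothetical tokenizer: lower-case and split off the period
  return txt.lower().replace(".", " .").split()

vocab = {"<unk>": 0, "<pad>": 1, ".": 2, "this": 3, "was": 4, "bad": 5}  # hypothetical
data = [("0", "This was bad."), ("1", "This was great."), ("0", "Bad.")]

ldr = T.utils.data.DataLoader(data, batch_size=2, shuffle=False,
  collate_fn=functools.partial(collate_reviews, vocab=vocab, tokenize=tokenize))

for (labels, reviews, offsets) in ldr:
  print(labels, reviews, offsets)
```

Either approach works, because the DataLoader only ever calls collate_fn with a single argument, the list of items in the batch.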

The moral of the story is that working with natural language processing training data is difficult.

One of the many reasons why natural language processing problems are difficult is that English language words can have multiple meanings. The Merriam-Webster dictionary lists 23 different meanings for the word “model”. Left: The Ford Model T, built from 1908 to 1927, was the first affordable automobile. Left center: A fashion model is often an aspiration for what an ideal woman might look like. Right center: A scale model of a Hobbit house. Right: This image popped up from an Internet search for “machine learning model”. I must be doing something wrong — my ML models do not manifest themselves as glowing balls of energy.

Code below. Long.


# import numpy as np
import torch as T
import torchtext as tt
import collections

device = T.device("cpu")

# data file looks like:
# 0, This was a BAD movie.
# 1, I liked this film! Highly recommended.
# 0, Don't waste your time - a real dud
# 1, Good film. Great acting.
# 0, This was a waste of talent.
# 1, Great movie. Good acting.

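def make_vocab(fn):
  # build a word-to-index Vocab object from the review text in file fn
  toker = tt.data.utils.get_tokenizer("basic_english")  # local
  counter_obj = collections.Counter()
  with open(fn, "r") as f:
    for line in f:
      line = line.strip()
      if line == "": continue  # skip blank lines
      # note: split(",") keeps only the text up to the second comma, so
      # "Good film, acting" becomes " Good film"; split(",", 1) keeps it all
      txt = line.split(",")[1]
      split_and_lowered = toker(txt)
      counter_obj.update(split_and_lowered)  # count each token
  # legacy torchtext (before 0.10) Vocab constructor: it reserves index 0
  # for <unk> and 1 for <pad>, then orders tokens by descending frequency
  result = tt.vocab.Vocab(counter_obj, min_freq=1)
  return result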

# globals are needed for the collate_fn() function
g_tokenizer = tt.data.utils.get_tokenizer("basic_english")  # global tokenizer
g_vocab = make_vocab(".\\Data\\reviews.txt")  # global vocabulary

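def make_data_list(fn):
  # get all data into one big list of (label, review) tuples
  # result will be passed to DataLoader
  result = []
  with open(fn, "r") as f:
    for line in f:
      line = line.strip()
      if line == "": continue  # skip blank lines
      parts = line.split(",")  # splits at every comma: a review with a comma is truncated
      tpl = (parts[0], parts[1])  # label and review, both still strings
      result.append(tpl)
  return result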

def collate_data(batch):
  # rearrange a batch and compute offsets too
  # needs a global vocab and tokenizer
  label_lst, review_lst, offset_lst = [], [], [0]
  for (_lbl, _rvw) in batch:  # batch = list of up to batch_size tuples
    label_lst.append(int(_lbl))  # string to int
    rvw_idxs = [g_vocab[tok] for tok in g_tokenizer(_rvw)]  # list of idxs
    rvw_idxs = T.tensor(rvw_idxs, dtype=T.int64)  # to tensor
    review_lst.append(rvw_idxs)
    offset_lst.append(rvw_idxs.size(0))  # length of this review

  label_lst = T.tensor(label_lst, dtype=T.int64).to(device)  # convert to tensor
  # offset_lst is [0, len1, len2, len3]; dropping the last length and taking
  # the running sum gives the start positions [0, len1, len1+len2]
  offset_lst = T.tensor(offset_lst[:-1]).cumsum(dim=0).to(device)
  review_lst = T.cat(review_lst).to(device)  # combine review tensors into 1

  return (label_lst, review_lst, offset_lst)

def main():
  print("\nBegin make PyTorch variable DataLoader demo ")

  print("\nLoading train data into meta-list: ")
  data_lst = make_data_list(".\\Data\\reviews.txt")

  print("\nCreating DataLoader from meta-list ")
  train_ldr = T.utils.data.DataLoader(data_lst, \
    batch_size=3, shuffle=False, collate_fn=collate_data)

  print("\nServing up batches (size = 3): ")
  for b_ix, (labels, reviews, offsets) in enumerate(train_ldr):
    print("batch  : ", b_ix)
    print("labels : ", end=""); print(labels)
    print("reviews: ", end=""); print(reviews)
    print("offsets: ", end=""); print(offsets)
    # input()
  print("\nEnd demo ")

if __name__ == "__main__":
  main()

This entry was posted in PyTorch.