Creating a Custom torchtext Dataset from a Text File

When working on natural language processing problems, preparing the data is always a nightmare. The PyTorch torchtext library has functions for text processing. But virtually every example on the Internet uses built-in datasets such as torchtext.datasets.WikiText2. In any realistic scenario, you need to create a Dataset from your own data. I decided to explore how to create a custom dataset using torchtext. When I searched the Internet, I found very few examples, which raised some red flags for me. These red flags were warranted, as I’ll explain.

The process for creating a custom dataset varies depending on the type of problem you’re working on. My demo assumes a sequence-to-value scenario. I made some dummy movie reviews, each of which has a label from 0 (very bad) to 4 (excellent). The goal would be to predict a movie rating from a review. The sequence-to-value scenario is the simplest, and you really need to understand it before tackling more complicated problems such as sequence-to-sequence (for example, English-to-German translation).



Warning: Before I go any further, a caveat: at the time I’m writing this blog post in January 2021, the torchtext library is under massive revision and almost all of the key classes are being deprecated. This means the code presented here will likely be completely out of date by the time you read this. So why am I wasting my time? I’m hoping that learning how to create a custom dataset using the old approach will help me understand the new approach when it becomes available in a few weeks or months.

At a high level, the idea is to start with one or more text files of source data, create TabularDataset objects from them, and then create iterators that can be used to batch and serve up the data.

First, I created a training data file:

001,Film was terrible,0
002,The movie was pretty good,3
003,Excellent experience in every way!,4
004,Not too bad at all,2
005,OK,1

I also created a validation file and a test file:

006,A great movie,4
007,Decent but not great,2
008,Worst movie in a long time,0
009,Terrible movie,0
010,Experience was good,3
011,An excellent film,4

Each file has three comma-delimited fields: an ID, a movie review, and a rating label from 0 to 4. In a non-demo scenario, I’d probably use tab-delimited fields so that the review text could contain comma characters if necessary, as sketched below.
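
For reference, here’s a minimal sketch of what the tab-delimited version of the dataset-creation call (explained later in this post) would look like. The train.tsv, validation.tsv, and test.tsv file names are hypothetical, and the RAW, TEXT, and LABEL Field objects are defined in the next step:

# a sketch, not part of the demo: reading hypothetical
# tab-delimited files instead of comma-delimited files
(train_obj, valid_obj, test_obj) = tt.data.TabularDataset.splits(
  path=".\\.data",
  train='train.tsv',
  validation='validation.tsv',
  test='test.tsv',
  format='tsv',  # the only substantive change from the csv version
  fields=[('id', RAW), ('review', TEXT), ('label', LABEL)])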

The first step in the demo program is to define RAW, TEXT, and LABEL Field objects that describe how to parse each field and how to store the field data in memory. It’s customary to capitalize Field object names, but the names themselves are arbitrary; I could have named the objects ID, REVIEW, and RATING, or something similar.

print("\nCreating RAW, TEXT, LABEL Field objects ")
RAW = tt.data.RawField()
TEXT = tt.data.Field(sequential=True,
  init_token='<sos>',  # start of sequence
  eos_token='<eos>',   # end of sequence
  lower=True,
  tokenize=tt.data.utils.get_tokenizer("basic_english"))
LABEL = tt.data.Field(sequential=False,
  use_vocab=False,
  unk_token=None,
  is_target=True)

The Field class is very complex in itself (it has over 20 parameters), and a full explanation would require several pages, but you can infer most of the key ideas by looking at the code.
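
As a taste, here’s a sketch of a Field that uses a few of the other commonly used parameters. The EXAMPLE object is hypothetical and isn’t part of the demo:

# a sketch of a few other Field parameters (not used by the demo)
EXAMPLE = tt.data.Field(
  sequential=True,
  lower=True,
  batch_first=True,      # serve batches as (batch, seq) instead of (seq, batch)
  fix_length=20,         # pad or truncate every sequence to exactly 20 tokens
  pad_token='<pad>',     # this is the default padding token
  include_lengths=True)  # batches become (data, lengths) tuples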

Next, the three source files are read into memory and three datasets are created:

print("\nSplitting data files into train, valid, test objects ")
(train_obj, valid_obj, test_obj) = tt.data.TabularDataset.splits(
  path=".\\.data",
  train='train.csv',
  validation='validation.csv',
  test='test.csv',
  format='csv',
  fields=[('id', RAW), ('review', TEXT), ('label', LABEL)])

The TabularDataset class is best for sequence-to-value problems. There are three other dataset types for other scenarios (and, like everything else here, all are deprecated). This demo creates all three TabularDataset objects (train_obj, valid_obj, test_obj) at the same time using the somewhat ugly splits() method. It would be clearer, but would require three times as much code, to create the three dataset objects one at a time, as sketched below. Just like the Field class, TabularDataset is very complex, but this example shows most of the key ideas.
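
For example, here’s a sketch of what the one-at-a-time approach would look like, using the same Field objects:

# a sketch, not part of the demo: creating the datasets one at a time
train_obj = tt.data.TabularDataset(
  path=".\\.data\\train.csv", format='csv',
  fields=[('id', RAW), ('review', TEXT), ('label', LABEL)])
valid_obj = tt.data.TabularDataset(
  path=".\\.data\\validation.csv", format='csv',
  fields=[('id', RAW), ('review', TEXT), ('label', LABEL)])
test_obj = tt.data.TabularDataset(
  path=".\\.data\\test.csv", format='csv',
  fields=[('id', RAW), ('review', TEXT), ('label', LABEL)])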

Next, the demo creates a single vocabulary, using just the training dataset object. The vocabulary allows raw words to be converted into index values.

print("\nCreating vocabulary dict from train data object ")
TEXT.build_vocab(train_obj)
print("The idx of \'good\' is ", TEXT.vocab.stoi['good'])  # 13
print("The string value of 8 is ", TEXT.vocab.itos[8])     # 'bad'

Right about this point in my exploration I started to realize why almost all of torchtext is being revised: the APIs are really bad. If this interface doesn’t make much sense to you, you are correct — the design is very weak.

Next, the demo creates a BucketIterator object for the training data and then uses it to serve up the training data in batches of size 2, in random order. The BucketIterator does a good job of grouping items with similar lengths into the same batch, which minimizes the amount of padding needed to make all items in a batch the same length.

print("\nCreating a BucketIterator on the train_object ")

train_iter = tt.data.BucketIterator(
  dataset=train_obj,
  batch_size=2,
  sort_key=lambda x: len(x.review),
  shuffle=True,
  device=device)

print("\nIterating train data (batch_size=2) ")
for item in train_iter:
  print("\n=====\n")
  print(item.id)
  print(item.review)
  print(item.label)

The third batch of two items looks like:

['001', '005']
tensor([[ 2,  2],
        [12, 17],
        [ 4,  3],
        [19,  1],
        [ 3,  1]])
tensor([0, 1])

The items are stored in columns, so the first item in the batch is ID 001: (2, 12, 4, 19, 3), where 2 = “start-of-sequence”, 12 = “film”, 4 = “was”, 19 = “terrible”, and 3 = “end-of-sequence”. The second item in the batch ends with two 1 values, which are padding.
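
One way to sanity-check a batch is to map the index values back to tokens using the itos list. A minimal sketch:

# a sketch of decoding the first item of a batch back to tokens
first_col = item.review[:, 0]  # items are stored as columns
print([TEXT.vocab.itos[int(i)] for i in first_col])
# e.g., ['<sos>', 'film', 'was', 'terrible', '<eos>']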

At this point, you’d probably have to do some reshaping to get the batch into the right shape for your neural network, as sketched below.
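
A batch arrives as a (seq_len, batch_size) tensor of index values, so a network whose first layer expects batch-first input needs a transpose. A minimal sketch, assuming a hypothetical Embedding layer and the import torch as T alias from the full program below (alternatively, setting batch_first=True in the TEXT Field avoids the transpose):

# a sketch, assuming a hypothetical Embedding layer
embed = T.nn.Embedding(num_embeddings=len(TEXT.vocab),
  embedding_dim=8, padding_idx=TEXT.vocab.stoi['<pad>'])
reviews = item.review.t()  # (seq_len, batch) to (batch, seq_len)
embedded = embed(reviews)  # shape: (batch_size, seq_len, 8)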

I’ve left out a ton of details. For example, there are many alternatives to the basic_english tokenizer used in the TEXT Field (see the sketch below). But this example should get you up and running if you want to create a custom torchtext dataset for a sequence-to-value problem. In a non-demo scenario, preparing data for NLP can take many days or weeks.
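
For instance, get_tokenizer() accepts other named tokenizers, and the tokenize parameter also accepts any plain callable. A sketch (the spacy variant assumes the spacy package and an English model are installed; TEXT2 is hypothetical):

# a sketch of two alternative tokenizers (not used by the demo)
spacy_toker = tt.data.utils.get_tokenizer("spacy")  # requires spacy
TEXT2 = tt.data.Field(sequential=True, lower=True,
  tokenize=str.split)  # naive split-on-whitespace tokenizer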



Three countries’ flags from an Internet search for “red flags”. I like flags, so when I saw these I knew which country each flag represented. From left to right: Denmark, China, Albania.


# demo_torchtext.py
# torchtext 0.8 - undergoing massive rearchitecting

import warnings
import random
import numpy as np
import torch as T
import torchtext as tt  # being significantly revamped

device = T.device("cpu")
warnings.filterwarnings("ignore")  # else warnings spew
random.seed(2)
T.manual_seed(1)
np.random.seed(1)

print("\nBegin torchtext from raw data demo ")

print("\nCreating RAW, TEXT, LABEL Field objects ")
RAW = tt.data.RawField()
TEXT = tt.data.Field(sequential=True,
  init_token='<sos>',  # start of sequence
  eos_token='<eos>',   # end of sequence
  lower=True,
  tokenize=tt.data.utils.get_tokenizer("basic_english"))
LABEL = tt.data.Field(sequential=False,
  use_vocab=False,
  unk_token=None,
  is_target=True)

print("\nSplitting into train, valid, test ")
(train_obj, valid_obj, test_obj) = \
  tt.data.TabularDataset.splits(
  path=".\\.data",
  train='train.csv',
  validation='validation.csv',
  test='test.csv',
  format='csv',
  fields=[('id', RAW), ('review', TEXT),
    ('label', LABEL)])

print("\nThe \'review\' field for item [2] is: ")
print(train_obj[2].review)

print("\nCreating vocabulary object ")
TEXT.build_vocab(train_obj)
print("The idx of \'good\' is ",
  TEXT.vocab.stoi['good'])  # 13
print("The string value of 8 is ",
  TEXT.vocab.itos[8])  # 'bad'
 
print("\nCreating a train BucketIterator ")

train_iter = tt.data.BucketIterator(
  dataset=train_obj,
  batch_size=2,
  sort_key=lambda x: len(x.review),
  shuffle=True,
  device=device)

print("\nIterating train data (batch_size=2) ")
for item in train_iter:
  print("\n=====\n")
  print(item.id)
  print(item.review)
  print(item.label)
  
print("\nEnd of demo ")