I Refactor the New TorchText Documentation AG News Demo

Version 0.9 of the PyTorch TorchText library was released a few days ago. The new version has many significant changes from versions 0.8 and earlier. It will take me many hours, spread out over several months, to master the new library.

But the longest journey begins with a single step. My first step was to take the documentation example and get it to run. (If you think getting a natural language processing demo program to run is easy, you've never tried it.) After the demo was running, my next step was to refactor the code. The third through nth steps are to carefully dissect the code to gain a full understanding of what it does.

The documentation example is a text classification problem. The source data is the AG News dataset. There are 120,000 training news articles and 7,600 test news articles. Each news article is labeled as one of four classes: 1 = "World", 2 = "Sports", 3 = "Business", 4 = "Sci/Tec".
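Note that the class labels are 1-based, but PyTorch loss functions expect 0-based class indices, so the demo maps each raw label with a simple subtract-one pipeline. A minimal sketch of that mapping (my own illustration, not code lifted from the demo's data loader):

```python
# AG News class labels are 1-4; training needs 0-based indices.
ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}
label_pipeline = lambda x: int(x) - 1

for raw in (1, 2, 3, 4):
    print(raw, "->", label_pipeline(raw), "=", ag_news_label[raw])
```

The same subtract-one lambda appears in the full program below, and the prediction function adds the 1 back at the end.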

The documentation example has code fragments and snippets extracted from a Jupyter notebook. I much prefer my computer programs to be programs. It took me about 6 hours, but I eventually refactored the fragments into a program and got the program to run.

Next, I spent another few hours on some more refactoring. The documentation demo has some very clever and good coding ideas, and a few pretty hideous coding patterns too.
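One of the clever ideas worth calling out is how the demo batches variable-length articles without padding: all the token IDs in a batch are concatenated into one flat tensor, and a list of starting offsets tells EmbeddingBag where each article begins. Here is a pure-Python sketch of that bookkeeping (hypothetical token IDs; the real collate function below uses torch tensors and cumsum):

```python
# Hypothetical batch of three tokenized news items (made-up token IDs).
batch = [[4, 9, 2], [7, 1], [3, 3, 3, 3]]

# Concatenate everything into one flat sequence ...
flat = [tok for item in batch for tok in item]

# ... and record where each item starts: a running sum of lengths.
offsets = [0]
for item in batch[:-1]:
    offsets.append(offsets[-1] + len(item))

print(flat)      # [4, 9, 2, 7, 1, 3, 3, 3, 3]
print(offsets)   # [0, 3, 5]
```

This avoids wasting computation on pad tokens, at the cost of slightly trickier batch-collation code.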

I’ve pasted my code below, but it is far from anything I’d use in production. In other words, don’t use this code because it will reflect badly on your programming skills. So, I have some not-so-good code that runs but which still needs a lot of work.

But it’s a good start to the journey.

From left to right: “Journey to the Center of the Earth” (1959), “Journey to the Seventh Planet” (1962), “Journey to the Center of Time” (1967), “Journey to the Far Side of the Sun” (1969).

Code below (long).

# new_torchtext_demo.py

import numpy as np
import torch as T
import torchtext as tt
import collections
import time

device = T.device("cpu")

# -----------------------------------------------------------

def make_vocab():
  train_iter, _ = tt.datasets.AG_NEWS()
  counter = collections.Counter()
  for (label, line) in train_iter:
    counter.update(tokenizer(line))  # count token frequencies
  result = tt.vocab.Vocab(counter, min_freq=1)
  lngth = len(result)
  return result, lngth

# -----------------------------------------------------------

tokenizer = tt.data.utils.get_tokenizer("basic_english")
vocab, vocab_size = make_vocab()

text_pipeline = lambda x: [vocab[token] \
  for token in tokenizer(x)]
label_pipeline = lambda x: int(x) - 1

# -----------------------------------------------------------

def collate_batch(batch):
  label_list, text_list, offsets = [], [], [0]
  for (_label, _text) in batch:
    label_list.append(label_pipeline(_label))
    processed_text = T.tensor(text_pipeline(_text), 
      dtype=T.int64)
    text_list.append(processed_text)
    offsets.append(processed_text.size(0))
  label_list = T.tensor(label_list, dtype=T.int64)
  offsets = T.tensor(offsets[:-1]).cumsum(dim=0)
  text_list = T.cat(text_list)
  return label_list.to(device), text_list.to(device), \
    offsets.to(device)

# -----------------------------------------------------------

class TextClassificationModel(T.nn.Module):

  def __init__(self, vocab_size, embed_dim, num_class):
    super(TextClassificationModel, self).__init__()
    self.embedding = T.nn.EmbeddingBag(vocab_size, 
      embed_dim, sparse=True)
    self.fc = T.nn.Linear(embed_dim, num_class)

  def init_weights(self):
    lim = 0.5
    self.embedding.weight.data.uniform_(-lim, lim)
    self.fc.weight.data.uniform_(-lim, lim)

  def forward(self, text, offsets):
    embedded = self.embedding(text, offsets)
    return self.fc(embedded)

# -----------------------------------------------------------

def train(model, dataloader, optimizer, criterion, epoch):
  total_acc, total_count = 0, 0
  log_interval = 500
  start_time = time.time()

  for idx, (label, text, offsets) in enumerate(dataloader):
    optimizer.zero_grad()
    predicted_label = model(text, offsets)
    loss = criterion(predicted_label, label)
    loss.backward()
    T.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
    optimizer.step()
    total_acc += (predicted_label.argmax(1) == \
      label).sum().item()
    total_count += label.size(0)
    if idx % log_interval == 0 and idx != 0:
      elapsed = time.time() - start_time
      print('| epoch {:3d} | {:5d}/{:5d} batches '
            '| accuracy {:8.3f}'.format(epoch, idx, \
                len(dataloader), total_acc/total_count))
      total_acc, total_count = 0, 0
      start_time = time.time()

# -----------------------------------------------------------

def evaluate(model, dataloader, criterion):
  total_acc, total_count = 0, 0

  with T.no_grad():
    for idx, (label, text, offsets) in enumerate(dataloader):
      predicted_label = model(text, offsets)
      loss = criterion(predicted_label, label)
      total_acc += (predicted_label.argmax(1) == \
        label).sum().item()
      total_count += label.size(0)
  return total_acc/total_count

# -----------------------------------------------------------

def predict(model, text, text_pipeline):
  with T.no_grad():
    text = T.tensor(text_pipeline(text))
    output = model(text, T.tensor([0]))
    return output.argmax(1).item() + 1

# -----------------------------------------------------------

def main():
  print("\nBegin AG News classification demo \n")


  num_classes = 4
  emsize = 64
  model = TextClassificationModel(vocab_size, \
    emsize, num_classes).to(device)

  # Hyperparameters
  EPOCHS = 3
  LR = 5.0         # learning rate
  BATCH_SIZE = 64  # for training

  criterion = T.nn.CrossEntropyLoss()
  optimizer = T.optim.SGD(model.parameters(), lr=LR)
  scheduler = T.optim.lr_scheduler.StepLR(optimizer,
    1.0, gamma=0.1)
  total_accu = None

  train_iter, test_iter = tt.datasets.AG_NEWS()  # reset

  train_dataset = list(train_iter)
  test_dataset = list(test_iter)
  num_train = int(len(train_dataset) * 0.95)
  split_train_, split_valid_ = \
    T.utils.data.random_split(train_dataset,
      [num_train, len(train_dataset) - num_train])

  train_dataloader = T.utils.data.DataLoader(split_train_, \
    batch_size=BATCH_SIZE, shuffle=True, 
    collate_fn=collate_batch)
  valid_dataloader = T.utils.data.DataLoader(split_valid_, \
    batch_size=BATCH_SIZE, shuffle=True, 
    collate_fn=collate_batch)
  test_dataloader = T.utils.data.DataLoader(test_dataset, \
    batch_size=BATCH_SIZE, shuffle=True, 
    collate_fn=collate_batch)

# -----------------------------------------------------------

  print("Starting training \n")

  for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(model, train_dataloader, optimizer, 
      criterion, epoch)
    accu_val = evaluate(model, valid_dataloader, criterion)
    # step the scheduler only when validation accuracy drops
    if total_accu is not None and total_accu > accu_val:
      scheduler.step()
    else:
      total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
           time.time() - epoch_start_time, accu_val))
    print('-' * 59)
  print("\nDone ")

# -----------------------------------------------------------

  print("\nComputing model accuracy on test dataset ")
  accu_test = evaluate(model, test_dataloader, criterion)
  print("test accuracy = {:8.3f}".format(accu_test)) 

  ag_news_label = {1: "World", 2: "Sports", 3: "Business",
    4: "Sci/Tec"}

  text_str = "Last night the Lakers beat the Rockets by " \
    + "a score of 100-95. John Smith scored 23 points."
  print("\nInput text: ")
  print(text_str)

  c = predict(model, text_str, text_pipeline)
  print("\nPredicted class: " + str(c))

  print("\nEnd demo \n")

# -----------------------------------------------------------

if __name__ == "__main__":
  main()

