A Naive Transformer Architecture for MNIST Classification Using PyTorch

Transformer architecture neural networks are very complex. Unlike some of my colleagues, I’m not a naturally brilliant guy. But my primary strength is persistence. I continue to probe the complexities of transformer systems, one example at a time, day after day, week after week.

I wondered if I could put together a system that uses a TransformerEncoder, in a completely naive way, for image classification. After a day of experiments, I got a system working. The classification accuracy wasn’t very good, but at least the system worked. And most importantly, I learned new things about transformer systems.

For my demo, I tackled the MNIST handwritten digits dataset. Each data item is a 28 by 28 (784 pixels) grayscale image (0 = white to 255 = black) of a digit from ‘0’ to ‘9’. The full MNIST dataset has 60,000 training images and 10,000 test images. I used a small subset of just 1,000 training images and 100 test images. See https://jamesmccaffrey.wordpress.com/2022/01/21/working-with-mnist-data/.
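
Each line of the data file holds 785 tab-delimited values: the 784 pixel values followed by the class label. A quick way to peek at the format is a sketch like this (the file path matches the demo program below):

import numpy as np
row = np.loadtxt(".\\Data\\mnist_train_1000.txt", delimiter="\t",
  comments="#", max_rows=1)      # read just the first line
pixels = row[0:784]              # 784 pixel values, 0 to 255
label = int(row[784])            # class label, 0 to 9
print(pixels.shape, label)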

My transformer system is naive because it accepts each set of 784 pixel values as a one-dimensional vector. Most transformer systems are designed for natural language processing and the input is a sequence of integer tokens where each integer represents a word. In my naive system, each integer is a pixel value. Therefore my “vocabulary size” is 256.

In NLP, each word/token/integer is mapped to a word embedding, typically about 100 to 500 values. In my naive system, I’m not sure if the idea of a pixel embedding makes sense, but I used an embed_dim of 4 just as a wild guess.
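
To make the pixel-embedding idea concrete, here is a minimal shape check, separate from the demo program:

import torch as T
embed = T.nn.Embedding(256, 4)        # vocab_size = 256 pixel values, embed_dim = 4
pixels = T.randint(0, 256, (1, 784))  # one fake image as a batch of one
z = embed(pixels)                     # each pixel value maps to a vector of 4 values
print(z.shape)                        # torch.Size([1, 784, 4])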

Transformer architecture systems have a hellacious number of hyperparameters. I used guesses for these — nhead, dropout, num_layers, etc., etc. The network definition I used is:

class TransformerNet(T.nn.Module):
  def __init__(self):
    # vocab_size = 256, embed = 4
    super(TransformerNet, self).__init__() 
    self.embed = T.nn.Embedding(256, 4)  # word embedding
    self.pos_enc = \
      PositionalEncoding(4, dropout=0.20)  # positional
    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=100, 
      batch_first=True)  # d_model divisible by nhead
    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=6)  # 6 layers default
    self.fc1 = T.nn.Linear(4*784, 10)  # 10 classes

  def forward(self, x):
    # x = 784 pixels. length = fixed.
    z = self.embed(x)  # pixels to embed vector
    z = z.reshape(-1, 784, 4)  # bat seq embed 
    z = self.pos_enc(z) 
    z = self.trans_enc(z) 
    z = z.reshape(-1, 4*784)  # torch.Size([bs, 3136])
    z = T.log_softmax(self.fc1(z), dim=1)  # NLLLoss()
    return z 
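
A quick way to verify the shapes is to push a dummy batch of pixel tokens through the network. This sketch assumes the imports and the PositionalEncoding class from the full listing below have already been defined:

net = TransformerNet()
dummy = T.randint(0, 256, (2, 784))   # batch of 2 fake images, pixel values 0-255
with T.no_grad():
  oupt = net(dummy)
print(oupt.shape)                     # torch.Size([2, 10]) log-probabilities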

In addition to all the architecture hyperparameters, there are all the usual training hyperparameters — optimization algorithm, its learning rate, batch size, etc., etc. Again, I used guesses for these values.
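
One detail worth noting: because forward() ends with log_softmax(), the demo uses NLLLoss(). That combination is mathematically equivalent to feeding raw logits to CrossEntropyLoss(), which a tiny check confirms:

import torch as T
logits = T.randn(3, 10)            # 3 fake items, 10 classes
targets = T.tensor([1, 7, 4])      # fake class labels
v1 = T.nn.NLLLoss()(T.log_softmax(logits, dim=1), targets)
v2 = T.nn.CrossEntropyLoss()(logits, targets)
print(v1.item(), v2.item())        # same value both ways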

In the end, the system scored 89.10 percent accuracy on the training data and 86.00 percent accuracy on the test data. This isn’t very good. A convolutional architecture can easily score close to 98 percent accuracy.

But at least the demo transformer system seemed to work. And I took one more step forward in understanding these super-prediction beasts.



One of the things I love about machine learning is that a person can take an idea and transform it into the reality of a working computer program. Fiction writing has the same characteristic — an idea transformed into the reality of a printed story.

“Super-Science Fiction” was a magazine published from 1956 to 1959. There were only 18 issues and most of the stories weren’t very good. But several famous authors and artists got their starts at the magazine. Left: Cover by artist F. Kelly Freas. Center: Also by Freas. Right: By artist Edmund Emshwiller.


Demo code.

# mnist_transformer.py
# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11

# naive Transformer architecture for MNIST

# reads MNIST subset data from text file rather than using
# built-in black box Dataset from torchvision

import numpy as np
import matplotlib.pyplot as plt
import torch as T

device = T.device('cpu')
T.set_num_threads(1)

# -----------------------------------------------------------

class MNIST_Dataset(T.utils.data.Dataset):
  # 784 tab-delim pixel values (0-255) then label (0-9)
  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(785),
      delimiter="\t", comments="#", dtype=np.float32)

    tmp_x = all_xy[:, 0:784]  # all rows, cols [0,783]
    # no pixel normalization or reshape
    tmp_y = all_xy[:, 784]    # 1-D required

    self.x_data = \
      T.tensor(tmp_x, dtype=T.int64).to(device)
    self.y_data = \
      T.tensor(tmp_y, dtype=T.int64).to(device) 

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    lbl = self.y_data[idx] 
    pixels = self.x_data[idx] 
    return (pixels, lbl)

# -----------------------------------------------------------

class TransformerNet(T.nn.Module):
  def __init__(self):
    # vocab_size = 256, embed = 4
    super(TransformerNet, self).__init__()  # old syntax

    self.embed = T.nn.Embedding(256, 4)  # word embedding

    self.pos_enc = \
      PositionalEncoding(4, dropout=0.20)  # positional

    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=100, 
      batch_first=True)  # d_model divisible by nhead

    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=6)  # 6 layers default

    self.fc1 = T.nn.Linear(4*784, 10)  # 10 classes

  def forward(self, x):
    # x = 784 pixels. length = fixed.
    z = self.embed(x)  # pixels to embed vector
    z = z.reshape(-1, 784, 4)  # bat seq embed 
    z = self.pos_enc(z) 
    z = self.trans_enc(z) 
    z = z.reshape(-1, 4*784)  # torch.Size([bs, 3136])
    z = T.log_softmax(self.fc1(z), dim=1)  # NLLLoss()
    return z 

# -----------------------------------------------------------

class PositionalEncoding(T.nn.Module):  # documentation code
  def __init__(self, d_model: int, dropout: float=0.1,
   max_len: int=5000):
    super(PositionalEncoding, self).__init__()  # old syntax
    self.dropout = T.nn.Dropout(p=dropout)
    pe = T.zeros(max_len, d_model)  # shape (max_len, d_model)
    position = \
      T.arange(0, max_len, dtype=T.float).unsqueeze(1)
    div_term = T.exp(T.arange(0, d_model, 2).float() * \
      (-np.log(10_000.0) / d_model))
    pe[:, 0::2] = T.sin(position * div_term)
    pe[:, 1::2] = T.cos(position * div_term)
    pe = pe.unsqueeze(0)  # shape (1, max_len, d_model) for batch-first input
    self.register_buffer('pe', pe)  # allows state-save

  def forward(self, x):
    # x has shape (batch, seq_len, d_model) because batch_first=True
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval() called
  ldr = T.utils.data.DataLoader(ds,
    batch_size=len(ds), shuffle=False)
  n_correct = 0
  for data in ldr:
    (pixels, labels) = data
    with T.no_grad():
      oupts = model(pixels)  # log_softmax values
    (_, predicteds) = T.max(oupts, dim=1)
    n_correct += (predicteds == labels).sum().item()

  acc = (n_correct * 1.0) / len(ds)
  return acc

# -----------------------------------------------------------

def main():
  # 0. setup
  print("\nBegin MNIST with Transformer demo ")
  np.random.seed(1)
  T.manual_seed(1)

  # 1. create Dataset
  print("\nCreating 1000-item train Dataset from text file ")
  train_file = ".\\Data\\mnist_train_1000.txt"
  train_ds = MNIST_Dataset(train_file)

  test_file = ".\\Data\\mnist_test_100.txt"
  test_ds = MNIST_Dataset(test_file)

  bat_size = 20
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating Transformer network ")
  net = TransformerNet().to(device)
  
# -----------------------------------------------------------

  # 3. train model
  max_epochs = 20  
  ep_log_interval = 2
  lrn_rate = 0.01
  
  loss_func = T.nn.NLLLoss()  # assumes log-softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)
  # optimizer = T.optim.Adam(net.parameters(), lr=lrn_rate)
  
  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("lrn_rate = %0.3f " % lrn_rate)
  print("max_epochs = %3d " % max_epochs)

  print("\nStarting training")
  net.train()  # set mode
  for epoch in range(0, max_epochs):
    ep_loss = 0.0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      (X, y) = batch  # X = pixels, y = target labels
      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, y)  # a tensor
      ep_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute grads
      optimizer.step()     # update weights
    if epoch % ep_log_interval == 0:
      print("epoch = %4d   |  loss = %9.4f" % (epoch, ep_loss))

  print("Done ") 

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # all at once
  print("Accuracy on training data = %0.4f" % acc_train)

  net.eval()
  acc_test = accuracy(net, test_ds)  # all at once
  print("Accuracy on test data = %0.4f" % acc_test)

# -----------------------------------------------------------

  # 5. use model
  print("\nMaking prediction for fake image: ")
  x = np.zeros(shape=(28,28), dtype=np.int64)
  for row in range(5,23):
    x[row][9] = 180  # vertical line
  for rc in range(9,19):
    x[rc][rc] = 250  # diagonal
  for col in range(5,15):  
    x[14][col] = 200  # horizontal

  plt.tight_layout()
  plt.imshow(x, cmap=plt.get_cmap('gray_r'))
  plt.show()

  x = x.reshape(1, 784)
  x = T.tensor(x, dtype=T.int64).to(device)
  with T.no_grad():
    oupt = net(x)  # 10 log-probabilities
  pred_probs = T.exp(oupt)
  print("\nPrediction probabilities: ")
  np.set_printoptions(precision=4, suppress=True)
  print(pred_probs.numpy())

  digits = ['zero', 'one', 'two', 'three', 'four', 'five', 
    'six', 'seven', 'eight', 'nine' ]
  am = T.argmax(oupt) # 0 to 9
  print("\nPredicted class is \'" + digits[am] + "\'")

# -----------------------------------------------------------

  # 6. save model
  print("\nSaving trained model state")
  fn = ".\\Models\\mnist_model.pt"
  T.save(net.state_dict(), fn)  

  print("\nEnd MNIST Transformer demo ")

if __name__ == "__main__":
  main()
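
To use the trained model later, the saved state dict can be loaded back into a fresh TransformerNet object. A sketch, assuming the classes and the torch-as-T import from the listing above:

model = TransformerNet().to(device)
model.load_state_dict(T.load(".\\Models\\mnist_model.pt"))
model.eval()  # set eval mode before making predictions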