The Simplest Possible PyTorch Transformer Sequence-to-Sequence Example

I’ve been looking at PyTorch transformer architecture (TA) networks. TA networks are among the most complex software components I’ve ever worked with, in terms of both conceptual complexity and engineering difficulty.

I set out to implement the simplest possible transformer sequence-to-sequence example I could make, and I discovered that even a simple example is extremely complicated. My simplifying tricks are: 1.) instead of natural language, all inputs and outputs are integer tokens, so I don’t have to deal with tokenization and vocabulary creation, and 2.) all input and output sequences have exactly 8 ordinary tokens (10 including the start and end tokens), so I don’t have to deal with source and target pad-masking.
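For contrast, here is roughly what the avoided pad-masking would involve. This is a minimal sketch with two made-up source sequences of different lengths, right-padded with PAD = 0. The Boolean mask marks the padded positions. A mask like this would be passed to Transformer.forward() as src_key_padding_mask, and tgt_key_padding_mask and memory_key_padding_mask work similarly. Because every sequence in this demo has the same length, none of this is needed here.

```python
import torch as T

PAD = 0
# two hypothetical variable-length source sequences, right-padded to length 6
src = T.tensor([[1, 4, 5, 6, 2, PAD],
                [1, 7, 2, PAD, PAD, PAD]])

# True marks positions that attention should ignore
src_key_padding_mask = (src == PAD)  # [bs=2, src_len=6], dtype bool
print(src_key_padding_mask)
```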

The Data

The first step was to create some synthetic training data. I set PAD = 0, SOS = 1, EOS = 2, and UNK = 3, and used 4 through 9 as the ordinary tokens. I generated 1,000 training items that look like:

1,4,5,8,9,8,5,6,8,2,1,5,6,9,4,9,6,7,9,2

The first 10 values are the input sequence, with a leading start-of-sequence 1 and a trailing end-of-sequence 2. The second 10 values are the target. Conceptually:

src = 1, 4,5,8,9,8,5,6,8, 2
tgt = 1, 5,6,9,4,9,6,7,9, 2

Because of my simplifications, there are no 0 (PAD) or 3 (UNK) tokens. I generated the data programmatically. The input sequence is 1, followed by eight random values between 4 and 9 inclusive, followed by 2. The output sequence is 1, followed by eight values where each value is 1 more than the corresponding input (an input of 9 wraps around to 4), followed by 2. The code for the Python program that generated the 1,000 training items is at the bottom of this blog post.
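The add-one-with-wraparound rule can be written as modular arithmetic over the six ordinary tokens. The sketch below is a compact NumPy alternative to the loop-based generator at the end of the post, not the program that produced the post's data. It uses NumPy's default_rng() generator, so it will not reproduce the same 1,000 items. The function names shift_tokens() and make_item() are hypothetical.

```python
import numpy as np

SOS, EOS = 1, 2

def shift_tokens(src):
  # add 1 to each ordinary token (4-9), wrapping 9 around to 4
  return (src - 4 + 1) % 6 + 4

def make_item(rng):
  src = rng.integers(4, 10, size=8)  # 8 tokens, 4-9 inclusive
  tgt = shift_tokens(src)
  # one 20-value row: sos, 8 src, eos, sos, 8 tgt, eos
  return np.concatenate([[SOS], src, [EOS], [SOS], tgt, [EOS]])

rng = np.random.default_rng(1)
item = make_item(rng)
print(",".join(str(v) for v in item))
```

Applied to the example item above, shift_tokens() maps 4,5,8,9,8,5,6,8 to 5,6,9,4,9,6,7,9.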

The Transformer Network

The transformer network is deceptively simple looking. I used a tiny embedding dim of 4. I set a vocabulary size of 12 even though there are only 10 tokens (0 through 9), because during debugging I wanted to distinguish the vocabulary size from the sequence length of 10 (when SOS and EOS are included).

class TransformerNet(T.nn.Module):
  def __init__(self):
    # vocab_size = 12, embed_dim = 4, seq_len = 9/10
    super(TransformerNet, self).__init__()   # classic syntax
    self.embed = T.nn.Embedding(12, 4)       # word embedding
    self.pos_enc = PositionalEncoding(4)     # positional
    self.trans = T.nn.Transformer(d_model=4, nhead=2, \
      dropout=0.0, batch_first=True)  # d_model div by nhead
    self.fc = T.nn.Linear(4, 12)  # embed_dim to vocab_size
    
  def forward(self, src, tgt, tgt_mask):
    s = self.embed(src)
    t = self.embed(tgt)

    s = self.pos_enc(s)  # [bs,seq=10,embed]
    t = self.pos_enc(t)  # [bs,seq=9,embed]

    z = self.trans(src=s, tgt=t, tgt_mask=tgt_mask)
    z = self.fc(z)     
    return z 

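I used the batch_first=True option so that every tensor is shaped [batch, seq, ...]; keeping track of the shapes of all the data (src, tgt, tgt_in, tgt_expect, etc.) was very difficult and time-consuming. I used a program-defined PositionalEncoding layer based on the one in the PyTorch documentation. The documentation version expects input shaped [seq, batch, embed], so with batch_first it has to add the encodings along dimension 1 rather than dimension 0.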
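The seq-first versus batch-first difference in positional encoding is easy to miss because the code runs without error either way. This sketch rebuilds the standard sine/cosine table in a hypothetical helper named sinusoid_table(). It then compares two approaches for a batch_first input of shape [10, 9, 4]:

- storing the table seq-first as [max_len, 1, d_model] and slicing by x.size(0), as the documentation version does;
- storing it as [1, max_len, d_model] and slicing by x.size(1).

```python
import math
import torch as T

def sinusoid_table(max_len, d_model):
  # hypothetical helper: the standard sin/cos table, shape [max_len, d_model]
  pe = T.zeros(max_len, d_model)
  position = T.arange(0, max_len, dtype=T.float).unsqueeze(1)
  div_term = T.exp(T.arange(0, d_model, 2).float() *
                   (-math.log(10000.0) / d_model))
  pe[:, 0::2] = T.sin(position * div_term)
  pe[:, 1::2] = T.cos(position * div_term)
  return pe

pe = sinusoid_table(50, 4)
x = T.zeros(10, 9, 4)  # batch_first input: [bs=10, seq=9, embed=4]

# seq-first recipe: table stored as [max_len, 1, d_model], sliced by dim 0
seq_first = pe.unsqueeze(1)
wrong = x + seq_first[:x.size(0)]  # slices by batch size (10), not seq len

# batch-first recipe: table stored as [1, max_len, d_model], sliced by dim 1
batch_first = pe.unsqueeze(0)
right = x + batch_first[:, :x.size(1)]  # one encoding per sequence position

print(T.equal(wrong[0, 0], wrong[0, 1]))  # True: every position of item 0 got pe[0]
print(T.equal(right[0, 0], right[0, 1]))  # False: positions 0 and 1 differ
```

With the seq-first slicing, item i in the batch receives position i's encoding at every sequence position. The network therefore gets no usable word-order information, and at inference with a batch of 1 every position receives pe[0].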


The PyTorch Transformer() class is made of a TransformerEncoder() and a TransformerDecoder(). Both are very complex and have a lot of parameters. My demo uses most of the default values to hide the complexity.
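One way to see how Transformer() is composed is to call its encoder and decoder submodules directly. This minimal sketch uses the same d_model=4, nhead=2, dropout=0.0, batch_first=True settings as the demo, with random token IDs and no positional encoding. It assumes nn.Transformer exposes encoder and decoder attributes, which current PyTorch versions do. It also shows that the output sequence length follows tgt, not src.

```python
import torch as T

T.manual_seed(0)
embed = T.nn.Embedding(12, 4)
trans = T.nn.Transformer(d_model=4, nhead=2, dropout=0.0, batch_first=True)
fc = T.nn.Linear(4, 12)

src = T.randint(4, 10, (10, 10))    # [bs=10, src_len=10] token ids
tgt_in = T.randint(4, 10, (10, 9))  # [bs=10, tgt_len=9] token ids
# causal mask: -inf above the diagonal blocks attention to later positions
t_mask = T.triu(T.full((9, 9), float('-inf')), diagonal=1)

with T.no_grad():
  s, t = embed(src), embed(tgt_in)                 # [10,10,4], [10,9,4]
  memory = trans.encoder(s)                        # [10,10,4]
  dec = trans.decoder(t, memory, tgt_mask=t_mask)  # [10,9,4]
  full = trans(s, t, tgt_mask=t_mask)              # same computation, one call
  logits = fc(dec)                                 # [10,9,12]

print(memory.shape, dec.shape, logits.shape)
print(T.allclose(dec, full, atol=1e-5))
```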


The Transformer class and its forward() method have a gazillion parameters. I used most of the default values, but set d_model to 4 (the default is 512), reduced nhead from the default 8 to 2, and set dropout to 0.0. For each target position, the output of the network is a set of 12 logits; applying softmax to them gives the pseudo-probabilities of each of the 12 tokens.
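Turning a position's 12 logits into pseudo-probabilities is a softmax over the last dimension. The logit values below are made up for illustration.

```python
import torch as T

# hypothetical logits for one target position: shape [bs=1, seq=1, vocab=12]
logits = T.tensor([[[0.1, -1.2, 0.3, -0.5, 0.2, 2.7,
                     0.0, -0.3, 0.4, 0.1, -2.0, -2.0]]])

probs = T.softmax(logits, dim=-1)           # each position's 12 values sum to 1
pred_token = T.argmax(probs[0, -1]).item()  # token with the largest probability
print(probs[0, -1])
print(pred_token)  # 5
```

Softmax preserves ordering, so taking argmax of the raw logits gives the same token; that is why the demo can call argmax on the logits directly. CrossEntropyLoss applies log-softmax internally, which is why forward() returns raw logits.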

Training

Training a Transformer network differs from training a simple architecture network in a couple of major ways. In a simple network you pass a batch of input values and get a batch of output values. But in a TA sequence-to-sequence network, you pass three things:

- an input sequence;
- a shifted target sequence (the target with its trailing EOS removed, so it begins with SOS);
- a target mask that prevents each target position from looking at later positions.

The output is then compared against the target with its leading SOS removed. These ideas are conceptually very tricky and a full explanation would take pages. I spent many days reading through the PyTorch documentation and dissecting a few of the examples I found on the Internet. There are a lot of details I don’t fully understand yet.
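Here is the shift applied to the example target item, along with a small causal mask. This is a minimal sketch: the mask is built with T.triu(), which produces the same 0 / -inf pattern that generate_square_subsequent_mask() returns.

```python
import torch as T

tgt = T.tensor([[1, 5, 6, 9, 4, 9, 6, 7, 9, 2]])  # the example target: [1, 10]
tgt_in = tgt[:, :-1]     # decoder input, drops trailing EOS: [1, 9]
tgt_expect = tgt[:, 1:]  # what each position should predict, drops SOS: [1, 9]

# position i may look at tgt_in[0, :i+1] and must predict tgt_expect[0, i]
for i in range(tgt_in.size(1)):
  print(tgt_in[0, :i+1].tolist(), "->", tgt_expect[0, i].item())

# the causal mask enforces "look only at positions up to i"
mask = T.triu(T.full((4, 4), float('-inf')), diagonal=1)
print(mask)
```

The first printed line is [1] -> 5 and the last is [1, 5, 6, 9, 4, 9, 6, 7, 9] -> 2. During training, all nine positions are computed in one forward pass. The mask is what keeps position i from seeing the tokens after it.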

The key training code is:

. . . 
for bix,batch in enumerate(train_ldr):
  src = batch[0]  # src  [bs,10] inc sos eos
  tgt = batch[1]  # tgt  [bs,10]

  tgt_in = tgt[:,:-1]  # [bs,9] remove trail eos
  tgt_expect = tgt[:,1:]  # [bs,9] remove lead sos
  t_mask = \
    T.nn.Transformer.generate_square_subsequent_mask(9)
      
  # no padding so no src_pad_mask, tgt_pad_mask

  preds = net(src, tgt_in, \
    tgt_mask=t_mask) # [bs,seq,vocab]

  # get preds shape to conform to tgt_expect
  preds = preds.permute(0,2,1)  # now [bs, vocab, seq] 

  loss_val = loss_func(preds, tgt_expect)
  epoch_loss += loss_val.item()

  opt.zero_grad()
  loss_val.backward()  # compute gradients
  opt.step()     # update weights
. . .

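If you’re reading this blog post to help you understand Transformer sequence-to-sequence, I’ll reiterate that this code is extraordinarily tricky and complex. For example, unlike a simple neural network, here the two tensors passed to the CrossEntropyLoss loss_func() function have different shapes. After the permute, preds is [bs, 12, 9] (batch, vocabulary, sequence), while tgt_expect is [bs, 9]. CrossEntropyLoss expects the class dimension (here, the vocabulary) to be the second dimension, which is why the permute is needed. Just wading through that issue alone took me a couple of days of reading documentation and experimentation.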
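CrossEntropyLoss accepts two input forms, where C is the number of classes:

- a 2-D input of shape [N, C] with targets of shape [N];
- a higher-dimensional input of shape [N, C, d1] with targets of shape [N, d1].

The demo uses the second form. The sketch below uses random stand-in values to check that flattening to the first form gives the same loss. Note that tgt_expect comes from slicing, so it is not contiguous; reshape() handles that, whereas view() would raise an error.

```python
import torch as T

T.manual_seed(0)
bs, seq, vocab = 10, 9, 12
preds = T.randn(bs, seq, vocab)        # stand-in net output: [bs, seq, vocab]
tgt = T.randint(4, 10, (bs, seq + 1))  # stand-in targets: [bs, 10]
tgt_expect = tgt[:, 1:]                # [bs, 9], a non-contiguous slice
loss_func = T.nn.CrossEntropyLoss()

# as in the demo: move the class (vocab) dim to position 1 -> [bs, vocab, seq]
loss_a = loss_func(preds.permute(0, 2, 1), tgt_expect)

# alternative: fold positions into the batch dim -> [bs*seq, vocab] vs [bs*seq]
loss_b = loss_func(preds.reshape(-1, vocab), tgt_expect.reshape(-1))

print(loss_a.item(), loss_b.item())  # both average over all bs*seq positions
```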

Using the Trained Model

Using a standard neural network is simple: feed it some input and capture the output prediction. Using a Transformer sequence-to-sequence trained model is a significant challenge in itself.

src = T.tensor([[1, 4,5,6,7,6,5,4, 2]], 
  dtype=T.int64).to(device)
# should predict 5,6,7,8,7,6,5
tgt_in = T.tensor([[1]], dtype=T.int64).to(device)
t_mask = \
  T.nn.Transformer.generate_square_subsequent_mask(1)

with T.no_grad():
  preds = model(src, tgt_in, tgt_mask=t_mask)
# result is 12 logits where largest is at the
# predicted token

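First I set up an arbitrary src sequence of 4,5,6,7,6,5,4. That is seven ordinary tokens, one fewer than in the training data; the network still accepts it because nothing in it fixes the source length. The predicted sequence should be 5,6,7,8,7,6,5. By feeding a tgt_in value of 1 (start-of-sequence) to the trained network, I figured the output should be the first token in the target, 5, which it was.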

To predict the second output token, you’d concatenate the predicted first token to the tgt_in giving [1,5] and then feed it to the trained model (and hopefully get a 6). You could continue this process until you get a prediction of EOS = 2.
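The feed-back loop described above is greedy autoregressive decoding. Here is a minimal sketch of it. greedy_decode() and stub_model() are hypothetical names. stub_model() is a stand-in that simply follows the data rule, so the loop can run without a trained network. To use the demo's trained model instead, pass model, which should be in eval mode.

```python
import torch as T

SOS, EOS, VOCAB = 1, 2, 12

def causal_mask(sz):
  # 0.0 on and below the diagonal, -inf above: no peeking at later tokens
  return T.triu(T.full((sz, sz), float('-inf')), diagonal=1)

def greedy_decode(model, src, max_len=12):
  # src: [1, src_len] int64; returns the generated tokens after SOS
  tgt_in = T.tensor([[SOS]], dtype=T.int64)
  with T.no_grad():
    for _ in range(max_len):
      logits = model(src, tgt_in, tgt_mask=causal_mask(tgt_in.size(1)))
      next_tok = T.argmax(logits[0, -1]).item()  # only the last position is new
      tgt_in = T.cat([tgt_in, T.tensor([[next_tok]], dtype=T.int64)], dim=1)
      if next_tok == EOS:
        break
  return tgt_in[0, 1:].tolist()

def stub_model(src, tgt_in, tgt_mask=None):
  # hypothetical stand-in for a trained TransformerNet: returns one-hot
  # logits [1, tgt_len, VOCAB] that follow the add-1-and-wrap rule
  n_ordinary = src.size(1) - 2  # tokens between SOS and EOS
  logits = T.zeros(1, tgt_in.size(1), VOCAB)
  for i in range(tgt_in.size(1)):
    if i < n_ordinary:
      v = src[0, i + 1].item() + 1
      tok = 4 if v > 9 else v
    else:
      tok = EOS
    logits[0, i, tok] = 1.0
  return logits

src = T.tensor([[1, 4, 5, 6, 7, 6, 5, 4, 2]], dtype=T.int64)
print(greedy_decode(stub_model, src))  # [5, 6, 7, 8, 7, 6, 5, 2]
```

This version reruns the whole decoder on the growing tgt_in at every step. That is simple but does redundant work; for sequences this short, the cost does not matter.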

Note: I wrote a follow-up post about using a trained sequence-to-sequence model at jamesmccaffrey.wordpress.com/2022/09/12/using-the-simplest-possible-transformer-sequence-to-sequence-example/

In Conclusion

Because Transformer Architecture systems are so fantastically complex, I’m nearly certain my demo example has some conceptual errors and some engineering errors. But it’s a step in the direction of ultimately understanding these beasts.



I have always enjoyed the Tintin sequence of books. Left: “King Ottokar’s Sceptre” (#8, first published 1939, 1947 edition). Center: “The Blue Lotus” (#5, first published 1936, 1946 edition). Right: “Cigars of the Pharaoh” (#4, first published 1934, 1955 edition).


Demo code. Replace any “lt”, “gt”, “lte”, “gte” placeholders with the comparison operators <, >, <=, >=.

# seq2seq.py
# Transformer seq-to-seq example

# PyTorch 1.10.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11

import numpy as np
import torch as T
import math

device = T.device('cpu')
T.set_num_threads(1)

# -----------------------------------------------------------

class DummySeq_Dataset(T.utils.data.Dataset):
  # one inpt = sos + 8 ints (4-9) + eos = (10 ints)
  # pad = 0 (not used), sos = 1, eos = 2

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,20),
      delimiter=",", comments="#", dtype=np.int64)
    tmp_x = all_xy[:,0:10]   # cols [0,9] sos 8 vals eos
    tmp_y = all_xy[:,10:20]  # cols [10,19] sos 8 vals eos
    self.x_data = T.tensor(tmp_x, dtype=T.int64).to(device) 
    self.y_data = T.tensor(tmp_y, dtype=T.int64).to(device) 

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    src_seq = self.x_data[idx]
    tgt_seq = self.y_data[idx] 
    return (src_seq, tgt_seq)  # as a tuple

# -----------------------------------------------------------

class TransformerNet(T.nn.Module):
  # a Transformer class has an internal TransformerEncoder
  # connected with an internal TransformerDecoder

  # nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6,
  #   num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
  #   activation={function relu}, custom_encoder=None,
  #   custom_decoder=None, layer_norm_eps=1e-05,
  #   batch_first=False, norm_first=False,
  #   device=None, dtype=None)
  # Note: d_model = embed_dim must be divisible by nhead

  # Transformer.forward(src, tgt, src_mask=None, tgt_mask=None,
  #   memory_mask=None, src_key_padding_mask=None,
  #   tgt_key_padding_mask=None, memory_key_padding_mask=None)

  def __init__(self):
    # vocab_size = 12, embed_dim = d_model = 4, seq_len = 9/10
    super(TransformerNet, self).__init__()  # classic syntax
    self.embed = T.nn.Embedding(12, 4)       # word embedding
    self.pos_enc = PositionalEncoding(4)    # positional
    self.trans = T.nn.Transformer(d_model=4, nhead=2, \
      dropout=0.0, batch_first=True)  # d_model div by nhead
    self.fc = T.nn.Linear(4, 12)  # embed_dim to vocab_size
    
  def forward(self, src, tgt, tgt_mask):
    s = self.embed(src)
    t = self.embed(tgt)

    s = self.pos_enc(s)  # [bs,seq=10,embed]
    t = self.pos_enc(t)  # [bs,seq=9,embed]

    z = self.trans(src=s, tgt=t, tgt_mask=tgt_mask)
    z = self.fc(z)     
    return z 

# -----------------------------------------------------------

class PositionalEncoding(T.nn.Module):  # documentation code
  def __init__(self, d_model: int, dropout: float=0.0,
   max_len: int=5000):
    super(PositionalEncoding, self).__init__()  # old syntax
    self.dropout = T.nn.Dropout(p=dropout)
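    pe = T.zeros(max_len, d_model)  # [max_len, d_model]
    position = \
      T.arange(0, max_len, dtype=T.float).unsqueeze(1)
    div_term = T.exp(T.arange(0, d_model, 2).float() * \
      (-np.log(10_000.0) / d_model))
    pe[:, 0::2] = T.sin(position * div_term)
    pe[:, 1::2] = T.cos(position * div_term)
    # the documentation version stores pe as [max_len, 1, d_model]
    # for seq-first input; this net uses batch_first=True, so store
    # it as [1, max_len, d_model] to broadcast over the batch dim
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)  # allows state-save

  def forward(self, x):
    # x is [bs, seq, embed]: add the encoding for each position
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)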

# -----------------------------------------------------------

# deprecated: 
# use Transformer.generate_square_subsequent_mask() instead

def make_mask(sz):
  mask = T.zeros((sz,sz), dtype=T.float32).to(device)
  for i in range(sz):
    for j in range(sz):
      if j > i: mask[i][j] = float('-inf')  # block future positions
  return mask

  # if sz = 4
  # [[0.0, -inf, -inf, -inf],
  #  [0.0,  0.0, -inf, -inf],
  #  [0.0,  0.0,  0.0, -inf],
  #  [0.0,  0.0,  0.0,  0.0]])

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin PyTorch Transformer seq-to-seq demo ")
  T.manual_seed(1)  
  np.random.seed(1)

  # 1. load data 
  print("\nLoading synthetic int-token train data ")
  train_file = ".\\Data\\train_data2_1000.txt"
  train_ds = DummySeq_Dataset(train_file) 

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True, drop_last=True)

  # 2. create Transformer network
  print("\nCreating batch-first Transformer network ")
  net = TransformerNet().to(device)
  net.train()

# -----------------------------------------------------------

  # 3. train the network
  loss_func = T.nn.CrossEntropyLoss()
  opt = T.optim.SGD(net.parameters(), lr=0.01)
  max_epochs = 200
  log_interval = 20  # display progress 

  print("\nStarting training ")
  for epoch in range(max_epochs):
    epoch_loss = 0.0  # loss for one full epoch
    for bix,batch in enumerate(train_ldr):
      src = batch[0]  # src  [bs,10] inc sos eos
      tgt = batch[1]  # tgt  [bs,10]

      tgt_in = tgt[:,:-1]  # [bs,9] remove trail eos
      tgt_expect = tgt[:,1:]  # [bs,9] remove lead sos
      t_mask = \
        T.nn.Transformer.generate_square_subsequent_mask(9)
      
      # no padding so no src_pad_mask, tgt_pad_mask

      preds = net(src, tgt_in, \
        tgt_mask=t_mask) # [bs,seq,vocab]

      # get preds shape to conform to tgt_expect
      preds = preds.permute(0,2,1)  # now [bs, vocab, seq] 

      loss_val = loss_func(preds, tgt_expect) # [bs,12,9] [bs,9]
      epoch_loss += loss_val.item()

      opt.zero_grad()
      loss_val.backward()  # compute gradients
      # T.nn.utils.clip_grad_value_(net.parameters(), 0.5)
      opt.step()     # update weights

    if epoch % log_interval == 0:
      print("epoch = %4d  |" % epoch, end="")
      print("   loss = %12.6f  |" % epoch_loss)

  print("Done ")

# -----------------------------------------------------------

  # 4. save trained model
  print("\nSaving trained model state")
  fn = ".\\Models\\transformer_seq_model.pt"
  net.eval()
  T.save(net.state_dict(), fn)

  # 5. use model
  print("\nCreating new Transformer seq-to-seq network ")
  model = TransformerNet().to(device)
  model.eval()  

  print("\nLoading saved model weights and biases ")
  fn = ".\\Models\\transformer_seq_model.pt"
  model.load_state_dict(T.load(fn))  # otherwise weights are untrained

  src = T.tensor([[1, 4,5,6,7,6,5,4, 2]], 
    dtype=T.int64).to(device)
  # should predict 5,6,7,8,7,6,5
  tgt_in = T.tensor([[1]], dtype=T.int64).to(device)
  t_mask = \
    T.nn.Transformer.generate_square_subsequent_mask(1)

  with T.no_grad():
    preds = model(src, tgt_in, tgt_mask=t_mask)
  print("\nInput: ")
  print(src)
  print("\npredicted logits: ")
  print(preds)  # [1,1,12]; first output token should be 5
  pred_token = T.argmax(preds[0, -1]).item()  # index of largest logit
  print("\nfirst pred output token: " + str(pred_token))

  print("\nEnd PyTorch Transformer seq-to-seq demo ")

if __name__ == "__main__":
  main()

Program to generate training data:

# make_data.py

# make dummy data for Transformer seq2seq experiments
# each input seq is 8 ints from 4-9 inclusive
# the target seq vals are 1 greater

# PAD = 0, SOS = 1, EOS = 2, UNK = 3
# regular: 4,5,6,7,8,9

# ex:
# inpt = [1, 5,9,6,4,4,7,8,7, 2]
# oupt = [1, 6,4,7,5,5,8,9,8, 2]
# values greater than 9 wrap around to 4

import numpy as np

np.random.seed(1)

num_items = 1000
fout = open(".\\train_data2_1000.txt", "w")

for i in range(num_items):
  inpt = np.zeros(8, dtype=np.int64)
  for j in range(8):  # 8 random ordinary tokens
    inpt[j] = np.random.randint(4,10)  # 4-9 inclusive

  oupt = np.zeros(8, dtype=np.int64)
  for j in range(8):
    oupt[j] = inpt[j] + 1

    if oupt[j] >= 10:  # wrap 10 around to 4
      oupt[j] = 4

  fout.write("1,")  # sos
  for j in range(8):
    fout.write(str(inpt[j]) + ",")
  fout.write("2,") # eos

  fout.write("1,")  # sos
  for j in range(8):
    fout.write(str(oupt[j]) + ",")
  fout.write("2")  # last val no trail comma
  fout.write("\n")
 
fout.close()
This entry was posted in PyTorch, Transformers.
