PyTorch Training Checkpoint Exact Recovery Reproducibility

About a year ago I spent many days figuring out how to save a PyTorch training checkpoint so that I could load the saved information and resume training with results identical to those of an uninterrupted run. It was a difficult problem. Briefly, I discovered that because there’s no way to save the state of a DataLoader object, you must manually reset the global PyTorch random number seed at the start of each training epoch.

One of the challenges of working with the PyTorch neural network library is that the library is constantly being updated. A new version of PyTorch is released about every 12-16 weeks, and a new version can introduce changes that break existing code written for the previous version.

I recently updated from PyTorch v1.9 to v1.10 and decided to see if my old checkpoint-saving code still worked. It did.

Here’s the key code (with many details left out) to train a neural network and save checkpoints so that training is exactly reproducible:

  print("Starting training")
  net.train() 
  for epoch in range(0, max_epochs):
    T.manual_seed(1 + epoch)  # for recovery reproducibility
    epoch_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch['predictors']  # inputs
      Y = batch['targets']     # correct class/label/job
      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %0.4f" % (epoch, epoch_loss))

      dt = time.strftime("%Y_%m_%d-%H_%M_%S")
      fn = ".\\Log\\" + str(dt) + str("-") + \
        str(epoch) + "_checkpoint.pt"

      info_dict = { 
        'epoch' : epoch,
        'numpy_random_state' : np.random.get_state(),
        'torch_random_state' : T.random.get_rng_state(),
        'net_state' : net.state_dict(),
        'optimizer_state' : optimizer.state_dict()
      }
      T.save(info_dict, fn)
  print("Done training ")

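The key statement is T.manual_seed(1 + epoch). It resets the global PyTorch random number seed at the start of every epoch. A shuffling DataLoader that hasn’t been given its own generator draws its shuffle order from the global generator, so the order in which training items are served in a given epoch depends only on the epoch number and not on anything that happened earlier.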
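To make the effect of the per-epoch seed concrete, here is a minimal sketch, not taken from the original program. It assumes a toy dataset of the integers 0 through 9 and a DataLoader with shuffle=True and no explicit generator; the helper name epoch_order is hypothetical. It records the item order for epochs 0 to 3, then replays only epochs 2 and 3 as a resumed run would.

```python
import torch as T
from torch.utils.data import DataLoader, TensorDataset

def epoch_order(loader, epoch):
    # Hypothetical helper: reseed the global generator exactly as the
    # training loop does, then list the items in the order served.
    T.manual_seed(1 + epoch)
    return [int(i) for (batch,) in loader for i in batch]

# Toy data: the items are just the integers 0..9 (an assumption for the demo).
ds = TensorDataset(T.arange(10))
ldr = DataLoader(ds, batch_size=3, shuffle=True)

# An uninterrupted run over epochs 0..3.
first_run = [epoch_order(ldr, ep) for ep in range(4)]

# A "resumed" run that starts at epoch 2, with no knowledge of epochs 0..1.
resumed = [epoch_order(ldr, ep) for ep in range(2, 4)]

print(first_run[2:] == resumed)
```

Because the order for an epoch is fixed by the seed 1 + epoch, the resumed run serves the same batches as the original run did for those epochs, no matter how many epochs came before.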

To use the saved checkpoint from a different program, the key code is:

  fn = ".\\Log\\2021_12_13-10_04_42-500_checkpoint.pt"
  chkpt = T.load(fn)

  np.random.set_state(chkpt['numpy_random_state'])
  T.random.set_rng_state(chkpt['torch_random_state'])

  print("Resuming training from checkpoint")
  net = Net().to(device)  # create the network before using it
  net.load_state_dict(chkpt['net_state'])
  net.train()  # or net = net.train()

  . . . 

  epoch_saved = chkpt['epoch'] + 1
  for epoch in range(epoch_saved, max_epochs):
    T.manual_seed(1 + epoch)  # for recovery reproducibility
    epoch_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch['predictors']  # inputs
      Y = batch['targets']     # correct class/label/job
      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %0.4f" % (epoch, epoch_loss))
      dt = time.strftime("%Y_%m_%d-%H_%M_%S")
      fn = ".\\Log\\" + str(dt) + str("-") + \
        str(epoch) + "_checkpoint.pt"
      info_dict = { 
        'epoch' : epoch,
        'numpy_random_state' : np.random.get_state(),
        'torch_random_state' : T.random.get_rng_state(),
        'net_state' : net.state_dict(),
        'optimizer_state' : optimizer.state_dict() 
      }
      T.save(info_dict, fn)
  print("Done training ")

There are many details I’ve left out; it would take several pages to fully explain everything.
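The whole save-and-resume round trip can be checked end to end on a toy problem. The sketch below is an illustration under stated assumptions, not the original program: it uses a tiny linear model, random data, SGD with momentum, and an in-memory buffer instead of a checkpoint file, and the names make_parts and train are hypothetical. It also restores the optimizer state with optimizer.load_state_dict(), a step hidden in the elided part of the listing above. The NumPy random state is left out here to keep the checkpoint to tensors and plain Python values; in recent PyTorch releases, where T.load() defaults to weights_only=True, a checkpoint holding NumPy objects may need extra handling.

```python
import io
import torch as T
from torch.utils.data import DataLoader, TensorDataset

def make_parts():
    # Hypothetical helper: build identical data, model, and optimizer each time.
    T.manual_seed(0)
    X = T.randn(12, 4)
    Y = T.randn(12, 1)
    ldr = DataLoader(TensorDataset(X, Y), batch_size=4, shuffle=True)
    net = T.nn.Linear(4, 1)
    opt = T.optim.SGD(net.parameters(), lr=0.05, momentum=0.9)
    return net, opt, ldr

def train(net, opt, ldr, start, stop, save_at=None):
    # Train for epochs start..stop-1; optionally checkpoint after one epoch.
    saved = None
    net.train()
    for epoch in range(start, stop):
        T.manual_seed(1 + epoch)  # for recovery reproducibility
        for (X, Y) in ldr:
            opt.zero_grad()
            loss = T.nn.functional.mse_loss(net(X), Y)
            loss.backward()
            opt.step()
        if epoch == save_at:
            saved = io.BytesIO()  # stands in for a checkpoint file
            T.save({'epoch': epoch,
                    'net_state': net.state_dict(),
                    'optimizer_state': opt.state_dict()}, saved)
    return saved

# Run A: six uninterrupted epochs, checkpointing after epoch 2.
net_a, opt_a, ldr_a = make_parts()
buf = train(net_a, opt_a, ldr_a, 0, 6, save_at=2)

# Run B: a fresh start that loads the checkpoint and resumes at epoch 3.
net_b, opt_b, ldr_b = make_parts()
buf.seek(0)
chkpt = T.load(buf)
net_b.load_state_dict(chkpt['net_state'])
opt_b.load_state_dict(chkpt['optimizer_state'])
train(net_b, opt_b, ldr_b, chkpt['epoch'] + 1, 6)

# Compare the final weights of both runs element by element.
same = all(T.equal(p, q)
           for p, q in zip(net_a.parameters(), net_b.parameters()))
print("resumed run matches uninterrupted run:", same)
```

The comparison uses T.equal() rather than an approximate check because the goal is exact recovery, not merely close results. On a CPU with a model this small the two runs should agree exactly; larger models on a GPU can involve nondeterministic kernels, which is a separate issue from the DataLoader order.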

In some neural network scenarios it’s not necessary to get exactly reproducible results. If you don’t manually reset the PyTorch global seed on each epoch, when you reload a saved checkpoint, the resumed training will be close to, but not exactly the same as, the training that occurs without saving the checkpoint. This is because the DataLoader object will have slightly different state and therefore will resume training using a different batch of training items.

Very complicated. Very interesting.



Getting reproducible checkpoint training is mildly difficult. Clones are reproducible humans. Left: In “The Sixth Day” (2000), Adam Gibson (played by actor Arnold Schwarzenegger) discovers he is an illegal clone of himself. Has a happy ending. Center: In “Cloud Atlas” (2012), one of the six story lines is about Sonmi-451 (Bae Doona), a “fabricant” who is a fast food worker and prostitute in the future. Does not turn out well for her. Right: In “Oblivion” (2013), Jack Harper (played by Tom Cruise) discovers that he is a clone created by aliens who have conquered Earth. Does not turn out well for the aliens.

