An Example of Warm-Start Training for PyTorch

When you train a neural network from scratch, the weights and biases of the network are initialized with random values. In warm-start training, instead of using random values for initialization, you use the weights and biases of a different trained network. This assumes the existing trained network has the same architecture as the network you are training.
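In PyTorch, warm-start initialization usually amounts to copying one network's state_dict() into a new instance of the same class. The following is a minimal sketch. `TinyNet` and its layer sizes are hypothetical stand-ins for any architecture, and the "trained" network here is not actually trained.

```python
import torch as T

class TinyNet(T.nn.Module):
    # hypothetical architecture, for illustration only
    def __init__(self):
        super().__init__()
        self.hid = T.nn.Linear(4, 8)
        self.out = T.nn.Linear(8, 3)

    def forward(self, x):
        return self.out(T.tanh(self.hid(x)))

T.manual_seed(1)
trained = TinyNet()   # stands in for a network that has already been trained
fresh = TinyNet()     # new network: random initial weights and biases

# warm start: overwrite the random values with the trained values
fresh.load_state_dict(trained.state_dict())

same = all(T.equal(a, b) for a, b in
           zip(trained.state_dict().values(), fresh.state_dict().values()))
print(same)  # True
```

Because load_state_dict() copies the values rather than sharing tensors, training `fresh` afterward does not change `trained`.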

Example of warm-start training. A network for Fashion-MNIST data was initialized using the weights and biases from a network that was trained on MNIST data. The model wasn’t significantly better than one trained from scratch.

I put together a complete end-to-end demo of warm-start training.

1. I created a network for the MNIST digits dataset and trained it from scratch. I saved the model's weights and biases.
2. I created a network for the Fashion-MNIST dataset and trained it from scratch. The test accuracy was 78.00%.
3. I created a network for Fashion-MNIST and warm-start initialized it with the weights and biases of the trained MNIST network. After training, the test accuracy was 79.00%.
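The three steps can be sketched end to end as follows. This is not the demo code: it assumes a small hypothetical fully connected `Net` and uses random synthetic tensors in place of MNIST and Fashion-MNIST, so the printed accuracies (measured on the training data, since there is no test split) carry no meaning beyond showing the workflow.

```python
import os
import tempfile
import torch as T

class Net(T.nn.Module):
    # hypothetical small classifier; the demo's real Net is not shown here
    def __init__(self):
        super().__init__()
        self.fc1 = T.nn.Linear(20, 16)
        self.fc2 = T.nn.Linear(16, 10)

    def forward(self, x):
        return self.fc2(T.relu(self.fc1(x)))

def train(net, X, y, epochs=50):
    opt = T.optim.SGD(net.parameters(), lr=0.1)
    loss_fn = T.nn.CrossEntropyLoss()
    net.train()
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(net(X), y)
        loss.backward()
        opt.step()
    return loss.item()

def accuracy(net, X, y):
    net.eval()
    with T.no_grad():
        return (net(X).argmax(dim=1) == y).float().mean().item()

T.manual_seed(0)
# synthetic stand-ins for the two datasets (random, NOT MNIST / Fashion-MNIST)
Xa, ya = T.randn(200, 20), T.randint(0, 10, (200,))
Xb, yb = T.randn(200, 20), T.randint(0, 10, (200,))
path = os.path.join(tempfile.mkdtemp(), "source_model.pt")

# step 1: train from scratch on dataset A, then save weights and biases
net_a = Net()
train(net_a, Xa, ya)
T.save(net_a.state_dict(), path)

# step 2: train from scratch on dataset B (the baseline)
net_scratch = Net()
train(net_scratch, Xb, yb)

# step 3: warm start on dataset B from the saved dataset-A weights
net_warm = Net()
net_warm.load_state_dict(T.load(path))
train(net_warm, Xb, yb)

print(f"scratch: {accuracy(net_scratch, Xb, yb):.4f}  "
      f"warm: {accuracy(net_warm, Xb, yb):.4f}")
```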

So in this example, warm-start training helped a little (79.00% versus 78.00% test accuracy), but not much.

Based on my experience with warm-start training, this result is typical. Sometimes warm-start helps and sometimes it doesn’t. In some examples, warm-start training gives worse results than training from scratch.

For most scenarios, warm-start training isn’t useful. The only scenario where you might want to consider it is when all three of these conditions hold: you have a huge network that needs to be trained, you have a pre-trained network with the same (or nearly the same) architecture, and that pre-trained network was trained using data that is highly similar to the new data.

If you don’t have a huge network, just train it from scratch. If the pre-trained network has a different architecture than your new network, you can’t copy the pre-trained weights to the new network. If the pre-trained network was trained using data that’s significantly different from the new data, there’s no reason to believe that warm-start training will help.
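The architecture requirement can be checked mechanically: the pre-trained weights can be copied only if every parameter name and tensor shape matches. A minimal sketch, where `same_architecture` is a hypothetical helper and the two models are made up for illustration (same layer names, different hidden-layer width):

```python
import torch as T

def same_architecture(src: T.nn.Module, dst: T.nn.Module) -> bool:
    """True if every parameter/buffer name and shape matches, so the
    weights of src can be copied into dst with load_state_dict()."""
    s, d = src.state_dict(), dst.state_dict()
    return s.keys() == d.keys() and all(s[k].shape == d[k].shape for k in s)

# hypothetical models: identical layer names, different hidden-layer width
old = T.nn.Sequential(T.nn.Linear(784, 100), T.nn.ReLU(), T.nn.Linear(100, 10))
new = T.nn.Sequential(T.nn.Linear(784, 64), T.nn.ReLU(), T.nn.Linear(64, 10))

print(same_architecture(old, new))  # False

try:
    new.load_state_dict(old.state_dict())
except RuntimeError as e:
    # PyTorch refuses to copy tensors whose shapes do not match
    print("cannot warm-start:", type(e).__name__)
```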

From a coding perspective, warm-start training is simple. The key code in my demo is:

import torch as T
device = T.device('cpu')
. . .

# 2a. create Fashion-MNIST network
print("\nCreating network with 2 conv and 3 linear ")
net = Net().to(device)

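# 2b. init network with trained MNIST weights and biases
print("Loading pre-trained MNIST wts and biases ")
fn = ".\\Models\\"  # must be the full path of the saved MNIST state_dict file
net.load_state_dict(T.load(fn, map_location=device))  # copy wts into net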

# 3. train model as usual
. . .
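The Net class is elided in the code above. Because MNIST and Fashion-MNIST both consist of 28x28 grayscale images in 10 classes, a single class definition can serve both tasks, and that is what makes the weight copy possible. Here is one hypothetical way to write a network with 2 convolutional and 3 linear layers; the channel counts, kernel sizes, and hidden widths are assumptions, not the demo's actual values.

```python
import torch as T

class Net(T.nn.Module):
    # hypothetical 2-conv, 3-linear design; layer sizes are assumptions
    def __init__(self):
        super().__init__()
        self.conv1 = T.nn.Conv2d(1, 32, 5)   # 1x28x28 -> 32x24x24
        self.conv2 = T.nn.Conv2d(32, 64, 5)  # 32x12x12 -> 64x8x8
        self.pool = T.nn.MaxPool2d(2, 2)     # halves height and width
        self.fc1 = T.nn.Linear(64 * 4 * 4, 512)
        self.fc2 = T.nn.Linear(512, 256)
        self.fc3 = T.nn.Linear(256, 10)

    def forward(self, x):
        z = self.pool(T.relu(self.conv1(x)))  # 32x12x12
        z = self.pool(T.relu(self.conv2(z)))  # 64x4x4
        z = z.reshape(-1, 64 * 4 * 4)
        z = T.relu(self.fc1(z))
        z = T.relu(self.fc2(z))
        return self.fc3(z)  # raw logits for 10 classes

mnist_net = Net()    # imagine this one has been trained on MNIST
fashion_net = Net()  # same class, so state_dict keys and shapes match
fashion_net.load_state_dict(mnist_net.state_dict())
out = fashion_net(T.randn(3, 1, 28, 28))
print(out.shape)  # torch.Size([3, 10])
```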

In practice, warm-start training usually pops up in situations where you have new training data arriving repeatedly. It would seem to make sense to train with the new data starting with the existing weights and biases, rather than re-training the network from scratch with all the old data plus the new data. However, that approach doesn’t work well. In most cases you get a better trained model by starting over from scratch.
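The two options for handling newly arrived data can be sketched side by side. This assumes a hypothetical small model and random synthetic data; it only shows the mechanics of each option and does not demonstrate which one produces the better model.

```python
import torch as T

def make_net():
    # hypothetical small model, for illustration only
    return T.nn.Sequential(T.nn.Linear(20, 16), T.nn.ReLU(), T.nn.Linear(16, 10))

def fit(net, X, y, epochs=100, lr=0.1):
    opt = T.optim.SGD(net.parameters(), lr=lr)
    loss_fn = T.nn.CrossEntropyLoss()
    for _ in range(epochs):
        opt.zero_grad()
        loss_fn(net(X), y).backward()
        opt.step()
    return net

T.manual_seed(0)
X_old, y_old = T.randn(300, 20), T.randint(0, 10, (300,))
X_new, y_new = T.randn(100, 20), T.randint(0, 10, (100,))
current = fit(make_net(), X_old, y_old)  # the existing trained model

# option 1: warm start, continue from the current weights using only new data
warm = make_net()
warm.load_state_dict(current.state_dict())
fit(warm, X_new, y_new)

# option 2: from scratch, random init, trained on all old data plus new data
scratch = fit(make_net(), T.cat([X_old, X_new]), T.cat([y_old, y_new]))
```

Option 1 is cheaper because it touches only the new data, but in the author's experience option 2 usually yields the better model.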

Spacesuit fashion in science fiction movies. Left: “Voyage to the Prehistoric Planet” (1965) – Explorers go to Venus where they find various monsters. Center: “Sunshine” (2007) – A team of astronauts try to reignite the sun using a fission bomb. Right: “The Fifth Element” (1997) – The Mondoshawans are friendly aliens with unusual spacesuits who guard the Earth from the Evil.

This entry was posted in PyTorch.
