Another Example of PyTorch Teacher-Student Training

The neural network teacher-student technique is designed to take a large network and reduce its size. This is most often called knowledge distillation, or sometimes model compression. (Weight pruning and the lottery ticket hypothesis are related size-reduction ideas, but they work differently.)

The idea is simple. Start with a large neural network (the teacher) and train it using training data as usual. Then create a second smaller network (the student) and train it to reproduce the results of the teacher. For example, training the teacher looks like:

  for (batch_idx, batch) in enumerate(train_ldr):
    X = batch[0]  # the predictors / inputs
    Y = batch[1]  # the targets 
    oupt = teacher(X) 
. . . 

But training the student looks like:

  for (batch_idx, batch) in enumerate(train_ldr):
    X = batch[0]    # the predictors / inputs
    Y = teacher(X)  # outputs from the teacher
    oupt = student(X)
. . . 
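
One implementation detail: only the student's weights should be updated, so the teacher's outputs should be cut out of the gradient graph, either by wrapping the teacher call in a no_grad() block or by calling detach():

  Y = teacher(X).detach()  # soft targets; no gradient flows to teacher

The demo code below uses the detach() approach.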

The teacher-student technique is a general idea rather than a specific algorithm, so there are many different ways to implement it. I've looked at teacher-student before, but I wanted to revisit the ideas. I used one of my standard multi-class classification examples, where the goal is to predict a person's political leaning (conservative, moderate, liberal) from sex, age, state (Michigan, Nebraska, Oklahoma), and income. The normalized and encoded data looks like:

# sex  age  state       income  politics
 1   0.24   1   0   0   0.2950   2
-1   0.39   0   0   1   0.5120   1
 1   0.63   0   1   0   0.7580   0
-1   0.36   1   0   0   0.4450   1
. . .
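
For reference, here's a sketch of how a raw data row could be encoded. The exact normalization constants are my assumption (age divided by 100, income divided by 100,000); the sex, state, and politics encodings follow the comments in the demo code:

  def encode_row(sex, age, state, income, politics):
    # sex: minus-one-plus-one encoding: 'M' = -1, 'F' = +1
    sx = -1.0 if sex == 'M' else 1.0
    # age: divide by 100 (assumed normalization)
    ag = age / 100.0
    # state: one-hot in the order michigan, nebraska, oklahoma
    st = {'michigan': [1,0,0], 'nebraska': [0,1,0],
          'oklahoma': [0,0,1]}[state]
    # income: divide by 100,000 (assumed normalization)
    inc = income / 100_000.0
    # politics: conservative = 0, moderate = 1, liberal = 2
    return [sx, ag] + st + [inc, politics]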

I created a large teacher network with a 6-(10-10)-3 architecture and trained it using NLLLoss(). Then I created a small student network with a 6-8-3 architecture and trained it with MSELoss() so that the student's raw outputs match the teacher's log-softmax outputs. Both networks had similar classification accuracy, which indicates the teacher-student technique succeeded in finding a condensed version of the original large network.
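
To quantify the size reduction: the 6-(10-10)-3 teacher has (6*10 + 10) + (10*10 + 10) + (10*3 + 3) = 213 trainable weights and biases, while the 6-8-3 student has only (6*8 + 8) + (8*3 + 3) = 83. A quick way to verify the counts in PyTorch:

  n_teacher = sum(p.numel() for p in teacher.parameters())  # 213
  n_student = sum(p.numel() for p in student.parameters())  # 83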

Good fun!



Three images from a stock photo search for "teacher-student". The equations on the chalkboards in the background are hilarious. I classify the teacher on the left as too happy, the teacher in the center as too angry, and the teacher on the right as too ecstatic. Thank you stock photos.


Demo code. The train and test data can be found at https://jamesmccaffrey.wordpress.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.

# people_teacher_student.py
# predict politics from sex, age, state, income
# use teacher-student technique to create a smaller network
# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex age   state    income  politics
  # -1  0.27  0  1  0  0.7610  2
  # +1  0.19  0  0  1  0.6550  0
  # sex: -1 = male, +1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    tmp_x = np.loadtxt(src_file, usecols=range(0,6),
      delimiter="\t", dtype=np.float32)
    tmp_y = np.loadtxt(src_file, usecols=6,
      delimiter="\t", dtype=np.int64)   # 1d required

    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y, dtype=T.int64).to(device) 

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return (preds, trgts)  # as Tuple

# -----------------------------------------------------------

class TeacherNet(T.nn.Module):
  def __init__(self):
    super(TeacherNet, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss() 
    return z

# -----------------------------------------------------------

class StudentNet(T.nn.Module):
  def __init__(self):
    super(StudentNet, self).__init__()
    self.hid1 = T.nn.Linear(6, 8)  # 6-8-3
    self.oupt = T.nn.Linear(8, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = self.oupt(z)  # no activation for MSELoss() 
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)     # 0 1 or 2
    with T.no_grad():
      oupt = model(X)  # log-softmax (teacher) or raw logits (student)

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc
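
# A vectorized alternative sketch (same accuracy value, computed in one
# forward pass using the Dataset's x_data and y_data tensors directly):
#
#   with T.no_grad():
#     all_oupt = model(ds.x_data)
#   acc = (T.argmax(all_oupt, dim=1) == ds.y_data).float().mean().item()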

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin Teacher-Student NN demo ")
  T.manual_seed(0)
  np.random.seed(0)
  
  # 1. create Dataset objects
  print("\nCreating teacher network Datasets ")

  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)  # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating 6-(10-10)-3 teacher network ")
  teacher = TeacherNet().to(device)
  teacher.train()  # set mode

  # 3. train the teacher NN
  max_epochs = 2000
  ep_log_interval = 500
  lrn_rate = 0.005

  loss_func = T.nn.NLLLoss()  # assumes log-softmax()
  optimizer = T.optim.SGD(teacher.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training the teacher NN")
  for epoch in range(0, max_epochs):
    epoch_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0] 
      Y = batch[1] 
      optimizer.zero_grad()
      oupt = teacher(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %0.4f" % (epoch, epoch_loss))
  print("Done ")

  # 4. evaluate model accuracy
  print("\nComputing teacher model accuracy")
  teacher.eval()
  acc_train = accuracy(teacher, train_ds)  # item-by-item
  print("Teacher accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(teacher, test_ds)  # item-by-item
  print("Teacher accuracy on test data = %0.4f" % acc_test)

  # 5. create and train Student NN
  print("\nCreating  6-8-3 student NN")
  student = StudentNet()
  student.train()  # set mode

  # 6. recreate Dataset and DataLoader (reuses the same data files)
  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)  # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 7. train student NN
  max_epochs = 2000
  ep_log_interval = 500
  lrn_rate = 0.005

  loss_func = T.nn.MSELoss()  # match teacher's log-softmax outputs
  optimizer = T.optim.SGD(student.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training the student NN ")
  for epoch in range(0, max_epochs):
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0] 
      # Y = batch[1] 
      Y = teacher(X).detach()  # log-softmax outputs from teacher

      optimizer.zero_grad()
      oupt = student(X)  # outputs from Student
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %0.4f" % (epoch, epoch_loss))
  print("Done ")

  # 8. evaluate student model accuracy
  print("\nComputing student model accuracy")
  student.eval()
  acc_train = accuracy(student, train_ds)  # item-by-item
  print("Student accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(student, test_ds)  # item-by-item
  print("Student accuracy on test data = %0.4f" % acc_test)

  # 9. TODO: save trained student model
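  # A minimal sketch (the Models directory path is my assumption):
  #   T.save(student.state_dict(), ".\\Models\\student_sd.pt")
  # and to reload later:
  #   student = StudentNet().to(device)
  #   student.load_state_dict(T.load(".\\Models\\student_sd.pt"))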

  print("\nEnd Teacher-Student NN demo")

if __name__ == "__main__":
  main()