The neural network teacher-student technique is designed to take a large trained network and produce a much smaller network that behaves almost the same. This is usually called knowledge distillation or model compression. It's related to, but distinct from, other size-reduction ideas such as weight pruning and the lottery-ticket hypothesis.
The idea is simple. Start with a large neural network (the teacher) and train it using training data as usual. Then create a second smaller network (the student) and train it to reproduce the results of the teacher. For example, training the teacher looks like:
for (batch_idx, batch) in enumerate(train_ldr):
  X = batch[0]  # the predictors / inputs
  Y = batch[1]  # the targets
  oupt = teacher(X)
  . . .
But training the student looks like:
for (batch_idx, batch) in enumerate(train_ldr):
  X = batch[0]    # the predictors / inputs
  Y = teacher(X)  # outputs from the teacher
  oupt = student(X)
  . . .
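One implementation detail: the Y targets for the student come from a forward pass of the teacher, so it's best to compute them inside a no_grad() block (or detach them) so that gradients don't flow back into the already-trained teacher. A minimal sketch, assuming the "import torch as T" alias used in the demo code below:

for (batch_idx, batch) in enumerate(train_ldr):
  X = batch[0]       # the predictors / inputs
  with T.no_grad():  # don't track gradients through the teacher
    Y = teacher(X)   # soft targets from the teacher
  oupt = student(X)
  . . .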
The teacher-student technique is a general idea rather than a specific algorithm, so there are many different ways to implement it; one well-known variant is sketched after the data listing below. I've looked at teacher-student before, but I wanted to revisit the ideas. I used one of my standard multi-class classification examples, where the goal is to predict a person's political leaning (conservative, moderate, liberal) from sex, age, state (Michigan, Nebraska, Oklahoma), and income. The normalized and encoded data looks like:
# sex  age    state    income   politics
   1   0.24   1 0 0    0.2950   2
  -1   0.39   0 0 1    0.5120   1
   1   0.63   0 1 0    0.7580   0
  -1   0.36   1 0 0    0.4450   1
. . .
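The well-known variant I mentioned above, from the 2015 "Distilling the Knowledge in a Neural Network" paper by Hinton et al., trains the student on a weighted blend of the ordinary hard-label loss and a KL divergence between temperature-softened teacher and student distributions. My demo doesn't use this approach, but a minimal sketch looks like the code below. The sketch assumes both networks emit raw logits (no log-softmax on the output layer), and the temp and alpha values are just placeholder choices:

import torch as T
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets,
                      temp=2.0, alpha=0.7):
  # soft-target part: KL divergence between temperature-softened
  # teacher and student distributions, scaled by temp^2 as in the paper
  soft = F.kl_div(F.log_softmax(student_logits / temp, dim=1),
                  F.softmax(teacher_logits / temp, dim=1),
                  reduction="batchmean") * (temp * temp)
  # hard-target part: ordinary cross entropy against the true labels
  hard = F.cross_entropy(student_logits, targets)
  return (alpha * soft) + ((1.0 - alpha) * hard)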
I created a large teacher network with 6-(10-10)-3 architecture and trained it using NLLLoss(). Then I created a small student network with a 6-8-3 architecture and trained it using MSELoss(). Both networks had similar classification accuracy, which indicates the teacher-student technique succeeded in finding a condensed version of the original large network.
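The size reduction is easy to quantify. The 6-(10-10)-3 teacher has (6*10 + 10) + (10*10 + 10) + (10*3 + 3) = 213 trainable weights and biases, while the 6-8-3 student has (6*8 + 8) + (8*3 + 3) = 83, a reduction of roughly 60 percent. A quick way to verify the counts, using the network classes from the demo code below:

def num_params(model):
  # count all trainable weights and biases
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(num_params(TeacherNet()))  # 213
print(num_params(StudentNet()))  # 83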
Good fun!
Three images from a stock photo search for “teacher-student”. The equations on the chalkboards in the background are hilarious. I classify the teacher on the left as too happy, the teacher in the center as too angry, and the teacher on the right as too ecstatic. Thank you, stock photos.
Demo code. The train and test data can be found at https://jamesmccaffrey.wordpress.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.
# people_teacher_student.py
# predict politics from sex, age, state, income
# use teacher-student technique to create a smaller network
# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0 1 0    0.7610   2
  # +1   0.19   0 0 1    0.6550   0
  # sex: -1 = male, +1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    tmp_x = np.loadtxt(src_file, usecols=range(0,6),
      delimiter="\t", dtype=np.float32)
    tmp_y = np.loadtxt(src_file, usecols=6,
      delimiter="\t", dtype=np.int64)  # 1d required

    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y, dtype=T.int64).to(device)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx]
    return (preds, trgts)  # as Tuple

# -----------------------------------------------------------

class TeacherNet(T.nn.Module):
  def __init__(self):
    super(TeacherNet, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss()
    return z

# -----------------------------------------------------------

class StudentNet(T.nn.Module):
  def __init__(self):
    super(StudentNet, self).__init__()
    self.hid1 = T.nn.Linear(6, 8)  # 6-8-3
    self.oupt = T.nn.Linear(8, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = self.oupt(z)  # no activation for MSELoss()
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)     # 0 1 or 2
    with T.no_grad():
      oupt = model(X)  # raw outputs (logits or log-softmax)
    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin Teacher-Student NN demo ")
  T.manual_seed(0)
  np.random.seed(0)

  # 1. create Dataset objects
  print("\nCreating teacher network Datasets ")
  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows
  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)    # 40 rows
  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating 6-(10-10)-3 teacher network ")
  teacher = TeacherNet().to(device)
  teacher.train()  # set mode

  # 3. train the teacher NN
  max_epochs = 2000
  ep_log_interval = 500
  lrn_rate = 0.005

  # max_epochs = 20
  # ep_log_interval = 2
  # lrn_rate = 0.005

  loss_func = T.nn.NLLLoss()  # assumes log-softmax()
  optimizer = T.optim.SGD(teacher.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training the teacher NN")
  for epoch in range(0, max_epochs):
    epoch_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]
      Y = batch[1]
      optimizer.zero_grad()
      oupt = teacher(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()
    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %0.4f" % (epoch, epoch_loss))
  print("Done ")

  # 4. evaluate model accuracy
  print("\nComputing teacher model accuracy")
  teacher.eval()
  acc_train = accuracy(teacher, train_ds)  # item-by-item
  print("Teacher accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(teacher, test_ds)  # item-by-item
  print("Teacher accuracy on test data = %0.4f" % acc_test)

  # 5. create and train student NN
  print("\nCreating 6-8-3 student NN")
  student = StudentNet().to(device)
  student.train()  # set mode

  # 6. recreate Dataset and DataLoader
  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows
  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)    # 40 rows
  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 7. train student NN
  max_epochs = 2000
  ep_log_interval = 500
  lrn_rate = 0.005
  loss_func = T.nn.MSELoss()  # student output has no activation
  optimizer = T.optim.SGD(student.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training the student NN ")
  for epoch in range(0, max_epochs):
    epoch_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]
      # Y = batch[1]
      with T.no_grad():  # don't backprop into the trained teacher
        Y = teacher(X)   # log_softmax outputs from teacher
      optimizer.zero_grad()
      oupt = student(X)  # outputs from student
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()
    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %0.4f" % (epoch, epoch_loss))
  print("Done ")

  # 8. evaluate student model accuracy
  print("\nComputing student model accuracy")
  student.eval()
  acc_train = accuracy(student, train_ds)  # item-by-item
  print("Student accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(student, test_ds)  # item-by-item
  print("Student accuracy on test data = %0.4f" % acc_test)

  # 9. TODO: save trained student model

  print("\nEnd Teacher-Student NN demo")

if __name__ == "__main__":
  main()
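The demo leaves step 9 as a TODO. A minimal sketch of what that step might look like, saving the trained student's state_dict and reloading it later (the file path is just a placeholder):

# 9. save trained student model (sketch; the path is a placeholder)
fn = ".\\Models\\student_sd.pt"
T.save(student.state_dict(), fn)

# later, to use the saved model:
model = StudentNet()
model.load_state_dict(T.load(fn))
model.eval()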