I put together a PyTorch neural network multi-class classification demo. For the neural architecture, I dropped in a TransformerEncoder hidden layer. To analyze model accuracy, I implemented a custom confusion matrix. For interpretability, I implemented custom input-gradient monitoring.
I used one of my standard synthetic datasets. The data looks like:
 1  0.24  1  0  0  0.2950  2
-1  0.39  0  0  1  0.5120  1
 1  0.63  0  1  0  0.7580  0
-1  0.36  1  0  0  0.4450  1
. . .
The fields are sex (male = -1, female = +1), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by $100,000), and political leaning (conservative = 0, moderate = 1, liberal = 2). The goal is to predict political leaning from sex, age, State, and income. There are 200 training items and 40 test items.
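The encoding scheme is easy to express as a helper function. Here is a minimal sketch; the encode() function and its arguments are hypothetical and for illustration only, because the demo reads data that has already been encoded:

def encode(sex, age, state, income, politics):
  # sex is "M" or "F"; age in years; income in dollars
  sex_enc = -1.0 if sex == "M" else 1.0
  state_enc = {"michigan": [1, 0, 0],
               "nebraska": [0, 1, 0],
               "oklahoma": [0, 0, 1]}[state]
  politics_enc = {"conservative": 0, "moderate": 1,
                  "liberal": 2}[politics]
  return [sex_enc, age / 100.0] + state_enc + \
         [income / 100_000.0, politics_enc]

print(encode("F", 24, "michigan", 29_500.00, "liberal"))
# [1.0, 0.24, 1, 0, 0, 0.295, 2] -- the first training item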
Adding a TransformerEncoder layer to a multi-class classification network is not simple. A shortcut notation for my demo architecture is (6-24)-T-10-3. Each of the 6 inputs is fed to a custom pseudo-embedding layer with dim = 4, resulting in 24 values. Then, positional encoding is added, even though the predictor variables don't have any implicit ordering, other than the one-hot encoded State of residence variable. The output of the Transformer layer (24 values) is fed to a standard hidden layer with 10 nodes. The output of that hidden layer is sent to an output layer with 3 nodes, one per class.
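One way to see how the pieces fit together is to trace tensor shapes through the network. This is a minimal sketch that assumes the SkipLinear, PositionalEncoding, and TransformerNet classes from the demo code below, and an arbitrary batch size of 10:

import torch as T

net = TransformerNet()  # defined in the demo code below
X = T.randn(10, 6)      # batch of 10 items, 6 predictors each
z = net.embed(X)        # (10, 24) pseudo-embedding
z = z.reshape(-1, 6, 4) # (10, 6, 4) = (batch, seq, d_model)
z = net.pos_enc(z)      # (10, 6, 4) positional info added
z = net.trans_enc(z)    # (10, 6, 4) TransformerEncoder output
z = z.reshape(-1, 4*6)  # (10, 24) flattened
z = T.tanh(net.fc1(z))  # (10, 10) hidden layer
z = T.log_softmax(net.fc2(z), dim=1)  # (10, 3) log-probabilities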
For interpretability, I capture and accumulate the absolute gradients of the 6 input variables every 200 training epochs, then average the 6 accumulated sums after training completes. The results showed that changes in age and income have the greatest effect on the predicted political leaning:
   sex     age     state1  state2  state3  income
[0.0574  0.4335  0.0302  0.0307  0.0304  0.4178]
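The values displayed are the normalized form. In the demo, the accumulated absolute gradients are averaged over the batch dimension, and then the six per-variable averages are divided by their sum so they total 1.0:

raw_avg_grads = np.mean(acc_batch_grads, axis=0)        # 6 values
norm_avg_grads = raw_avg_grads / np.sum(raw_avg_grads)  # sum to 1.0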
I could just capture the input gradient values once, after training completes, but by that point the gradients could be quite small. This is something I need to explore further when I get some free time.
I experimented with several different approaches for monitoring the gradients of the input variables. For example, I modified the network architecture to include an explicit input layer, but that added a lot of complexity without any benefit. Monitoring input gradients is a surprisingly tricky task. In the end, I just set the requires_grad attribute to True on each batch of input items.
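The core of the technique is small. Here is a minimal self-contained sketch; the tiny stand-in classifier and the random batch are hypothetical, not the demo network:

import numpy as np
import torch as T

net = T.nn.Sequential(T.nn.Linear(6, 3),
  T.nn.LogSoftmax(dim=1))   # stand-in classifier
loss_func = T.nn.NLLLoss()

X = T.randn(10, 6)          # a batch of 10 items, 6 predictors
y = T.randint(0, 3, (10,))  # dummy class labels
X.requires_grad = True      # ask autograd to track the inputs
oupt = net(X)
loss_val = loss_func(oupt, y)
loss_val.backward()         # fills X.grad as well as weight grads
print(X.grad.shape)         # torch.Size([10, 6])
grad_mags = np.abs(X.grad.numpy())  # accumulate these over training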
Good fun.
For reasons unknown to me, an Internet image search for “interpretability” served up all kinds of images of cyborgs — cyborg animals, cyborg geishas, cyborg whatever. Thank you Internet.
Demo code.
# people_transformer_interpret.py
# PyTorch 2.0.0-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11

# Transformer component for political leaning classification

import numpy as np
import torch as T
device = T.device('cpu')
T.set_num_threads(1)

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   2
  # +1   0.19   0  0  1   0.6550   0
  # sex: -1 = male, +1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter="\t", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:6]  # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6]    # 1-D
    self.x_data = T.tensor(tmp_x,
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx]
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class SkipLinear(T.nn.Module):
  # -----
  class Core(T.nn.Module):
    def __init__(self, n):
      super().__init__()
      # 1 node to n nodes, n >= 2
      self.weights = T.nn.Parameter(T.zeros((n,1),
        dtype=T.float32))
      self.biases = T.nn.Parameter(T.zeros(n,
        dtype=T.float32))
      lim = 0.01
      T.nn.init.uniform_(self.weights, -lim, lim)
      T.nn.init.zeros_(self.biases)

    def forward(self, x):
      wx = T.mm(x, self.weights.t())
      v = T.add(wx, self.biases)
      return v
  # -----

  def __init__(self, n_in, n_out):
    super().__init__()
    self.n_in = n_in; self.n_out = n_out
    if n_out % n_in != 0:
      print("FATAL: n_out must be divisible by n_in")
    n = n_out // n_in  # num nodes per input

    self.lst_modules = \
      T.nn.ModuleList([SkipLinear.Core(n) for \
        i in range(n_in)])

  def forward(self, x):
    lst_nodes = []
    for i in range(self.n_in):
      xi = x[:,i].reshape(-1,1)
      oupt = self.lst_modules[i](xi)
      lst_nodes.append(oupt)
    result = T.cat((lst_nodes[0], lst_nodes[1]), 1)
    for i in range(2,self.n_in):
      result = T.cat((result, lst_nodes[i]), 1)
    result = result.reshape(-1, self.n_out)
    return result

# -----------------------------------------------------------

class TransformerNet(T.nn.Module):
  # (6-24)-T-10-3
  def __init__(self):
    super(TransformerNet, self).__init__()  # old syntax

    # numeric pseudo-embedding, dim=4
    self.embed = SkipLinear(6, 24)  # 6 inputs, each goes to 4

    self.pos_enc = \
      PositionalEncoding(4, dropout=0.00)  # positional

    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=10,
      batch_first=True)  # d_model divisible by nhead

    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=2)  # 6 layers is the default

    # People dataset has 6 inputs
    self.fc1 = T.nn.Linear(4*6, 10)  # 10 hidden nodes
    self.fc2 = T.nn.Linear(10, 3)    # 3 classes

  def forward(self, x):
    # x = 6 inputs, fixed length
    z = self.embed(x)        # 6 inputs to 24 embed
    z = z.reshape(-1, 6, 4)  # bat seq embed
    z = self.pos_enc(z)
    z = self.trans_enc(z)
    z = z.reshape(-1, 4*6)   # torch.Size([bs, 24])
    z = T.tanh(self.fc1(z))
    z = T.log_softmax(self.fc2(z), dim=1)  # NLLLoss()
    return z

# -----------------------------------------------------------

class PositionalEncoding(T.nn.Module):
  # documentation code, adapted for batch-first input
  def __init__(self, d_model: int, dropout: float=0.1,
      max_len: int=5000):
    super(PositionalEncoding, self).__init__()  # old syntax
    self.dropout = T.nn.Dropout(p=dropout)
    pe = T.zeros(max_len, d_model)  # like 5000x4
    position = \
      T.arange(0, max_len, dtype=T.float).unsqueeze(1)
    div_term = T.exp(T.arange(0, d_model, 2).float() * \
      (-np.log(10_000.0) / d_model))
    pe[:, 0::2] = T.sin(position * div_term)
    pe[:, 1::2] = T.cos(position * div_term)
    pe = pe.unsqueeze(0)  # (1, max_len, d_model)
    self.register_buffer('pe', pe)  # allows state-save

  def forward(self, x):
    # x is (bat, seq, d_model) because batch_first=True
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)     # 0, 1 or 2, 1-D
    with T.no_grad():
      oupt = model(X)  # logits form

    big_idx = T.argmax(oupt)  # 0, 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def confusion_matrix_multi(model, ds, n_classes):
  if n_classes <= 2:
    print("ERROR: n_classes must be 3 or greater ")
    return None

  cm = np.zeros((n_classes,n_classes), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)     # actual class 0, 1 or 2, 1-D
    with T.no_grad():
      oupt = model(X)  # logits form
    pred_class = T.argmax(oupt)  # 0, 1 or 2
    cm[Y][pred_class] += 1
  return cm

# -----------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

  # accuracy for each class
  row_sums = np.sum(cm, axis=1)
  accs = np.zeros(dim, dtype=np.float32)
  for i in range(dim):
    accs[i] = cm[i][i] / row_sums[i]
  print("\naccuracy by class: ")
  print(accs)

# -----------------------------------------------------------

def main():
  # 0. setup
  print("\nBegin Transformer demo ")
  np.random.seed(1)
  T.manual_seed(1)
  np.set_printoptions(precision=4, suppress=True,
    floatmode='fixed')
  T.set_printoptions(precision=4, sci_mode=False)

  # 1. create Dataset
  print("\nCreating 200-item train Dataset from text file ")
  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating Transformer network ")
  net = TransformerNet().to(device)

  # ---------------------------------------------------------

  # 3. train model
  max_epochs = 2000
  ep_log_interval = 400
  lrn_rate = 0.025
  grad_accum_interval = 200  # for interpretability

  loss_func = T.nn.NLLLoss()  # assumes log-softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)
  # optimizer = T.optim.Adam(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("lrn_rate = %0.3f " % lrn_rate)
  print("max_epochs = %3d " % max_epochs)

  acc_batch_grads = np.zeros((bat_size, 6),
    dtype=np.float32)  # accumulated input gradients

  print("\nStarting training")
  net.train()  # set mode
  for epoch in range(0, max_epochs):
    ep_loss = 0.0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      (X, y) = batch  # X = predictors, y = target labels
      X.requires_grad = True  # monitor input gradients
      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, y)  # a tensor
      ep_loss += loss_val.item()     # accumulate
      loss_val.backward()            # compute grads
      optimizer.step()               # update weights
    if epoch % ep_log_interval == 0:
      print("epoch = %4d  |  loss = %9.4f" % \
        (epoch, ep_loss))
    if epoch % grad_accum_interval == 0:
      curr_batch_grads = X.grad  # last batch, [bs, 6]
      acc_batch_grads += np.abs(curr_batch_grads.numpy())
  print("Done ")

  # ---------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds)
  print("Accuracy on test data = %0.4f" % acc_test)

  print("\nConfusion matrix test data: ")
  cm = confusion_matrix_multi(net, test_ds, n_classes=3)
  show_confusion(cm)

  # ---------------------------------------------------------

  # 4b. show interpretability info
  print("\nAvg gradients sex, age, (s1, s2, s3), income: ")
  raw_avg_grads = np.mean(acc_batch_grads, axis=0)
  norm_avg_grads = raw_avg_grads / np.sum(raw_avg_grads)
  print("raw:        ", end=""); print(raw_avg_grads)
  print("normalized: ", end=""); print(norm_avg_grads)

  # ---------------------------------------------------------

  # 5. use model
  # print("\nPredicting politics for M 30 oklahoma $50,000: ")
  # X = np.array([[-1, 0.30, 0,0,1, 0.5000]],
  #   dtype=np.float32)
  # X = T.tensor(X, dtype=T.float32).to(device)
  # with T.no_grad():
  #   logits = net(X)  # do not sum to 1.0
  # probs = T.exp(logits)  # sum to 1.0
  # probs = probs.numpy()  # numpy vector prints better
  # np.set_printoptions(precision=4, suppress=True)
  # print(probs)

  # ---------------------------------------------------------

  # 6. save model
  # print("\nSaving trained model state")
  # fn = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), fn)

  print("\nEnd Transformer demo ")

if __name__ == "__main__":
  main()
Training data. Replace space-space with tab characters.
# people_train.txt
# sex (M=-1, F=1)  age  state (michigan,
#  nebraska, oklahoma)  income
# politics (conservative, moderate, liberal)
#
1  0.24  1  0  0  0.2950  2
-1  0.39  0  0  1  0.5120  1
1  0.63  0  1  0  0.7580  0
-1  0.36  1  0  0  0.4450  1
1  0.27  0  1  0  0.2860  2
1  0.50  0  1  0  0.5650  1
1  0.50  0  0  1  0.5500  1
-1  0.19  0  0  1  0.3270  0
1  0.22  0  1  0  0.2770  1
-1  0.39  0  0  1  0.4710  2
1  0.34  1  0  0  0.3940  1
-1  0.22  1  0  0  0.3350  0
1  0.35  0  0  1  0.3520  2
-1  0.33  0  1  0  0.4640  1
1  0.45  0  1  0  0.5410  1
1  0.42  0  1  0  0.5070  1
-1  0.33  0  1  0  0.4680  1
1  0.25  0  0  1  0.3000  1
-1  0.31  0  1  0  0.4640  0
1  0.27  1  0  0  0.3250  2
1  0.48  1  0  0  0.5400  1
-1  0.64  0  1  0  0.7130  2
1  0.61  0  1  0  0.7240  0
1  0.54  0  0  1  0.6100  0
1  0.29  1  0  0  0.3630  0
1  0.50  0  0  1  0.5500  1
1  0.55  0  0  1  0.6250  0
1  0.40  1  0  0  0.5240  0
1  0.22  1  0  0  0.2360  2
1  0.68  0  1  0  0.7840  0
-1  0.60  1  0  0  0.7170  2
-1  0.34  0  0  1  0.4650  1
-1  0.25  0  0  1  0.3710  0
-1  0.31  0  1  0  0.4890  1
1  0.43  0  0  1  0.4800  1
1  0.58  0  1  0  0.6540  2
-1  0.55  0  1  0  0.6070  2
-1  0.43  0  1  0  0.5110  1
-1  0.43  0  0  1  0.5320  1
-1  0.21  1  0  0  0.3720  0
1  0.55  0  0  1  0.6460  0
1  0.64  0  1  0  0.7480  0
-1  0.41  1  0  0  0.5880  1
1  0.64  0  0  1  0.7270  0
-1  0.56  0  0  1  0.6660  2
1  0.31  0  0  1  0.3600  1
-1  0.65  0  0  1  0.7010  2
1  0.55  0  0  1  0.6430  0
-1  0.25  1  0  0  0.4030  0
1  0.46  0  0  1  0.5100  1
-1  0.36  1  0  0  0.5350  0
1  0.52  0  1  0  0.5810  1
1  0.61  0  0  1  0.6790  0
1  0.57  0  0  1  0.6570  0
-1  0.46  0  1  0  0.5260  1
-1  0.62  1  0  0  0.6680  2
1  0.55  0  0  1  0.6270  0
-1  0.22  0  0  1  0.2770  1
-1  0.50  1  0  0  0.6290  0
-1  0.32  0  1  0  0.4180  1
-1  0.21  0  0  1  0.3560  0
1  0.44  0  1  0  0.5200  1
1  0.46  0  1  0  0.5170  1
1  0.62  0  1  0  0.6970  0
1  0.57  0  1  0  0.6640  0
-1  0.67  0  0  1  0.7580  2
1  0.29  1  0  0  0.3430  2
1  0.53  1  0  0  0.6010  0
-1  0.44  1  0  0  0.5480  1
1  0.46  0  1  0  0.5230  1
-1  0.20  0  1  0  0.3010  1
-1  0.38  1  0  0  0.5350  1
1  0.50  0  1  0  0.5860  1
1  0.33  0  1  0  0.4250  1
-1  0.33  0  1  0  0.3930  1
1  0.26  0  1  0  0.4040  0
1  0.58  1  0  0  0.7070  0
1  0.43  0  0  1  0.4800  1
-1  0.46  1  0  0  0.6440  0
1  0.60  1  0  0  0.7170  0
-1  0.42  1  0  0  0.4890  1
-1  0.56  0  0  1  0.5640  2
-1  0.62  0  1  0  0.6630  2
-1  0.50  1  0  0  0.6480  1
1  0.47  0  0  1  0.5200  1
-1  0.67  0  1  0  0.8040  2
-1  0.40  0  0  1  0.5040  1
1  0.42  0  1  0  0.4840  1
1  0.64  1  0  0  0.7200  0
-1  0.47  1  0  0  0.5870  2
1  0.45  0  1  0  0.5280  1
-1  0.25  0  0  1  0.4090  0
1  0.38  1  0  0  0.4840  0
1  0.55  0  0  1  0.6000  1
-1  0.44  1  0  0  0.6060  1
1  0.33  1  0  0  0.4100  1
1  0.34  0  0  1  0.3900  1
1  0.27  0  1  0  0.3370  2
1  0.32  0  1  0  0.4070  1
1  0.42  0  0  1  0.4700  1
-1  0.24  0  0  1  0.4030  0
1  0.42  0  1  0  0.5030  1
1  0.25  0  0  1  0.2800  2
1  0.51  0  1  0  0.5800  1
-1  0.55  0  1  0  0.6350  2
1  0.44  1  0  0  0.4780  2
-1  0.18  1  0  0  0.3980  0
-1  0.67  0  1  0  0.7160  2
1  0.45  0  0  1  0.5000  1
1  0.48  1  0  0  0.5580  1
-1  0.25  0  1  0  0.3900  1
-1  0.67  1  0  0  0.7830  1
1  0.37  0  0  1  0.4200  1
-1  0.32  1  0  0  0.4270  1
1  0.48  1  0  0  0.5700  1
-1  0.66  0  0  1  0.7500  2
1  0.61  1  0  0  0.7000  0
-1  0.58  0  0  1  0.6890  1
1  0.19  1  0  0  0.2400  2
1  0.38  0  0  1  0.4300  1
-1  0.27  1  0  0  0.3640  1
1  0.42  1  0  0  0.4800  1
1  0.60  1  0  0  0.7130  0
-1  0.27  0  0  1  0.3480  0
1  0.29  0  1  0  0.3710  0
-1  0.43  1  0  0  0.5670  1
1  0.48  1  0  0  0.5670  1
1  0.27  0  0  1  0.2940  2
-1  0.44  1  0  0  0.5520  0
1  0.23  0  1  0  0.2630  2
-1  0.36  0  1  0  0.5300  2
1  0.64  0  0  1  0.7250  0
1  0.29  0  0  1  0.3000  2
-1  0.33  1  0  0  0.4930  1
-1  0.66  0  1  0  0.7500  2
-1  0.21  0  0  1  0.3430  0
1  0.27  1  0  0  0.3270  2
1  0.29  1  0  0  0.3180  2
-1  0.31  1  0  0  0.4860  1
1  0.36  0  0  1  0.4100  1
1  0.49  0  1  0  0.5570  1
-1  0.28  1  0  0  0.3840  0
-1  0.43  0  0  1  0.5660  1
-1  0.46  0  1  0  0.5880  1
1  0.57  1  0  0  0.6980  0
-1  0.52  0  0  1  0.5940  1
-1  0.31  0  0  1  0.4350  1
-1  0.55  1  0  0  0.6200  2
1  0.50  1  0  0  0.5640  1
1  0.48  0  1  0  0.5590  1
-1  0.22  0  0  1  0.3450  0
1  0.59  0  0  1  0.6670  0
1  0.34  1  0  0  0.4280  2
-1  0.64  1  0  0  0.7720  2
1  0.29  0  0  1  0.3350  2
-1  0.34  0  1  0  0.4320  1
-1  0.61  1  0  0  0.7500  2
1  0.64  0  0  1  0.7110  0
-1  0.29  1  0  0  0.4130  0
1  0.63  0  1  0  0.7060  0
-1  0.29  0  1  0  0.4000  0
-1  0.51  1  0  0  0.6270  1
-1  0.24  0  0  1  0.3770  0
1  0.48  0  1  0  0.5750  1
1  0.18  1  0  0  0.2740  0
1  0.18  1  0  0  0.2030  2
1  0.33  0  1  0  0.3820  2
-1  0.20  0  0  1  0.3480  0
1  0.29  0  0  1  0.3300  2
-1  0.44  0  0  1  0.6300  0
-1  0.65  0  0  1  0.8180  0
-1  0.56  1  0  0  0.6370  2
-1  0.52  0  0  1  0.5840  1
-1  0.29  0  1  0  0.4860  0
-1  0.47  0  1  0  0.5890  1
1  0.68  1  0  0  0.7260  2
1  0.31  0  0  1  0.3600  1
1  0.61  0  1  0  0.6250  2
1  0.19  0  1  0  0.2150  2
1  0.38  0  0  1  0.4300  1
-1  0.26  1  0  0  0.4230  0
1  0.61  0  1  0  0.6740  0
1  0.40  1  0  0  0.4650  1
-1  0.49  1  0  0  0.6520  1
1  0.56  1  0  0  0.6750  0
-1  0.48  0  1  0  0.6600  1
1  0.52  1  0  0  0.5630  2
-1  0.18  1  0  0  0.2980  0
-1  0.56  0  0  1  0.5930  2
-1  0.52  0  1  0  0.6440  1
-1  0.18  0  1  0  0.2860  1
-1  0.58  1  0  0  0.6620  2
-1  0.39  0  1  0  0.5510  1
-1  0.46  1  0  0  0.6290  1
-1  0.40  0  1  0  0.4620  1
-1  0.60  1  0  0  0.7270  2
1  0.36  0  1  0  0.4070  2
1  0.44  1  0  0  0.5230  1
1  0.28  1  0  0  0.3130  2
1  0.54  0  0  1  0.6260  0
Test data.
-1  0.51  1  0  0  0.6120  1
-1  0.32  0  1  0  0.4610  1
1  0.55  1  0  0  0.6270  0
1  0.25  0  0  1  0.2620  2
1  0.33  0  0  1  0.3730  2
-1  0.29  0  1  0  0.4620  0
1  0.65  1  0  0  0.7270  0
-1  0.43  0  1  0  0.5140  1
-1  0.54  0  1  0  0.6480  2
1  0.61  0  1  0  0.7270  0
1  0.52  0  1  0  0.6360  0
1  0.30  0  1  0  0.3350  2
1  0.29  1  0  0  0.3140  2
-1  0.47  0  0  1  0.5940  1
1  0.39  0  1  0  0.4780  1
1  0.47  0  0  1  0.5200  1
-1  0.49  1  0  0  0.5860  1
-1  0.63  0  0  1  0.6740  2
-1  0.30  1  0  0  0.3920  0
-1  0.61  0  0  1  0.6960  2
-1  0.47  0  0  1  0.5870  1
1  0.30  0  0  1  0.3450  2
-1  0.51  0  0  1  0.5800  1
-1  0.24  1  0  0  0.3880  1
-1  0.49  1  0  0  0.6450  1
1  0.66  0  0  1  0.7450  0
-1  0.65  1  0  0  0.7690  0
-1  0.46  0  1  0  0.5800  0
-1  0.45  0  0  1  0.5180  1
-1  0.47  1  0  0  0.6360  0
-1  0.29  1  0  0  0.4480  0
-1  0.57  0  0  1  0.6930  2
-1  0.20  1  0  0  0.2870  2
-1  0.35  1  0  0  0.4340  1
-1  0.61  0  0  1  0.6700  2
-1  0.31  0  0  1  0.3730  1
1  0.18  1  0  0  0.2080  2
1  0.26  0  0  1  0.2920  2
-1  0.28  1  0  0  0.3640  2
-1  0.59  0  0  1  0.6940  2