PyTorch Multi-Class Classification Using a Transformer with Custom Accuracy and Interpretability

I put together a PyTorch neural network multi-class classification demo. For the neural architecture, I dropped in a TransformerEncoder hidden layer. For model accuracy, I implemented a custom confusion matrix. For interpretability, I implemented custom input gradient monitoring.

I used one of my standard synthetic datasets. The data looks like:

 1   0.24   1  0  0   0.2950   2
-1   0.39   0  0  1   0.5120   1
 1   0.63   0  1  0   0.7580   0
-1   0.36   1  0  0   0.4450   1
. . .

The fields are sex (male = -1, female = +1), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by $100,000), and political leaning (conservative = 0, moderate = 1, liberal = 2). The goal is to predict political leaning from sex, age, State, and income. There are 200 training items and 40 test items.
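
For clarity, a hypothetical helper function (not part of the demo program) that maps a raw person record to the encoded form would look something like this:

import numpy as np

def encode_person(sex, age, state, income):
  # sex is 'M' or 'F', age in years, state name, income in dollars
  sex_val = -1.0 if sex == 'M' else 1.0
  one_hot = { 'michigan': [1,0,0], 'nebraska': [0,1,0],
    'oklahoma': [0,0,1] }[state]
  return np.array([sex_val, age / 100.0] + one_hot +
    [income / 100_000.0], dtype=np.float32)

print(encode_person('M', 30, 'oklahoma', 50_000))
# [-1.    0.3   0.    0.    1.    0.5 ]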

Adding a TransformerEncoder layer to a multi-class classification network is not simple. A shortcut notation for my demo architecture is (6-24)-T-10-3. Each of the 6 input values is fed to a custom pseudo-embedding layer with dim = 4, resulting in 24 values. The 24 values are reshaped into a sequence of 6 tokens with embedding dim 4, and positional encoding is added, even though the predictor variables don't have any implicit ordering other than the three one-hot encoded State of residence values. The encoded sequence is fed to a TransformerEncoder layer. The output of the Transformer layer is flattened and fed to a standard hidden layer with 10 nodes, and the output of that hidden layer is sent to an output layer with 3 nodes.
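
Here is a minimal shape sketch of the data flow, using random data and stand-in built-in layers rather than the custom demo classes shown below (positional encoding omitted):

import torch as T

bs = 10                             # hypothetical batch size
x = T.randn(bs, 6)                  # 6 raw predictor values per item
embed = T.nn.Linear(1, 4)           # stand-in for the pseudo-embedding
z = embed(x.reshape(bs, 6, 1))      # (bs, 6, 4): seq len 6, d_model 4
enc_layer = T.nn.TransformerEncoderLayer(d_model=4, nhead=2,
  dim_feedforward=10, batch_first=True)
trans_enc = T.nn.TransformerEncoder(enc_layer, num_layers=2)
z = trans_enc(z)                    # (bs, 6, 4), shape unchanged
z = z.reshape(bs, 24)               # flatten for the dense layers
z = T.tanh(T.nn.Linear(24, 10)(z))  # hidden layer with 10 nodes
z = T.log_softmax(T.nn.Linear(10, 3)(z), dim=1)  # 3 output classes
print(z.shape)                      # torch.Size([10, 3])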

For interpretability, I capture the gradients of the 6 input variables every 200 training epochs, accumulate their absolute values, and average and normalize the 6 sums after training. The results showed that changes in age and income have the greatest effect:

 sex,    age,    state1, state2, state3, income
[0.0574, 0.4335, 0.0302, 0.0307, 0.0304, 0.4178]

I could just capture the values of the input gradients once, after training completes, but by that point the gradients could be very small. This is something I need to explore further when I get some free time.

I experimented with several different approaches for monitoring the gradients of the input variables. For example, I modified the network architecture to include an explicit input layer, but that added a lot of complexity without much benefit. Monitoring input gradients is a surprisingly tricky task. In the end, I just set requires_grad=True on each batch of input items.
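
Here is a minimal sketch of the technique, using a plain linear model and made-up data rather than the demo network, just to show the mechanics:

import torch as T

model = T.nn.Linear(6, 3)            # stand-in for the trained network
loss_func = T.nn.CrossEntropyLoss()  # raw-logits analog of NLLLoss
X = T.randn(10, 6)                   # a batch of 10 items, 6 predictors
y = T.randint(0, 3, (10,))           # made-up target classes
X.requires_grad = True               # ask autograd to track the inputs
loss_val = loss_func(model(X), y)
loss_val.backward()                  # populates X.grad
grads = T.abs(X.grad)                # (10, 6) per-item input gradients
avg_grads = grads.mean(dim=0)        # average over the batch
print(avg_grads / avg_grads.sum())   # normalized relative influence

In the full demo program below, the same idea is applied to the last batch of every 200th epoch, the absolute gradients are accumulated, and the accumulated values are averaged and normalized after training.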

Good fun.



For reasons unknown to me, an Internet image search for “interpretability” served up all kinds of images of cyborgs — cyborg animals, cyborg geishas, cyborg whatever. Thank you Internet.


Demo code.

# people_transformer_interpret.py
# PyTorch 2.0.0-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11

# Transformer component for political leaning classification

import numpy as np
import torch as T

device = T.device('cpu')
T.set_num_threads(1)

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   2
  # +1   0.19   0  0  1   0.6550   0
  # sex: -1 = male, +1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter="\t", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6]     # 1-D

    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class SkipLinear(T.nn.Module):

  # -----

  class Core(T.nn.Module):
    def __init__(self, n):
      super().__init__()
      # 1 node to n nodes, n gte 2
      self.weights = T.nn.Parameter(T.zeros((n,1),
        dtype=T.float32))
      self.biases = T.nn.Parameter(T.zeros(n,
        dtype=T.float32))
      lim = 0.01
      T.nn.init.uniform_(self.weights, -lim, lim)
      T.nn.init.zeros_(self.biases)

    def forward(self, x):
      wx = T.mm(x, self.weights.t())
      v = T.add(wx, self.biases)
      return v

  # -----

  def __init__(self, n_in, n_out):
    super().__init__()
    self.n_in = n_in; self.n_out = n_out
    if n_out % n_in != 0:
      print("FATAL: n_out must be divisible by n_in")
    n = n_out // n_in  # num nodes per input

    self.lst_modules = \
      T.nn.ModuleList([SkipLinear.Core(n) for \
        i in range(n_in)])

  def forward(self, x):
    lst_nodes = []
    for i in range(self.n_in):
      xi = x[:,i].reshape(-1,1)
      oupt = self.lst_modules[i](xi)
      lst_nodes.append(oupt)
    result = T.cat((lst_nodes[0], lst_nodes[1]), 1)
    for i in range(2,self.n_in):
      result = T.cat((result, lst_nodes[i]), 1)
    result = result.reshape(-1, self.n_out)
    return result

# -----------------------------------------------------------

class TransformerNet(T.nn.Module):  # (6--24)-T-10-3
  def __init__(self):
    super(TransformerNet, self).__init__()  # old syntax

    # numeric pseudo-embedding, dim=4
    self.embed = SkipLinear(6, 24)  # 6 inputs, each goes to 4 

    self.pos_enc = \
      PositionalEncoding(4, dropout=0.00)  # positional

    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=10, 
      batch_first=True)  # d_model divisible by nhead

    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=2)  # 6 layers default

    # People dataset has 6 inputs
    self.fc1 = T.nn.Linear(4*6, 10)  # 10 hidden nodes
    self.fc2 = T.nn.Linear(10, 3)    # 3 classes

  def forward(self, x):
    # x = 6 inputs, fixed length
    z = self.embed(x)  # 6 inputs to 24 pseudo-embedding values
    z = z.reshape(-1, 6, 4)  # (batch, seq=6, d_model=4)
    z = self.pos_enc(z) 
    z = self.trans_enc(z) 
    z = z.reshape(-1, 4*6)  # (batch, 24)
    z = T.tanh(self.fc1(z))
    z = T.log_softmax(self.fc2(z), dim=1)  # NLLLoss()
    return z 

# -----------------------------------------------------------

class PositionalEncoding(T.nn.Module):  # from documentation, adapted for batch-first
  def __init__(self, d_model: int, dropout: float=0.1,
   max_len: int=5000):
    super(PositionalEncoding, self).__init__()  # old syntax
    self.dropout = T.nn.Dropout(p=dropout)
    pe = T.zeros(max_len, d_model)  # like 10x4
    position = \
      T.arange(0, max_len, dtype=T.float).unsqueeze(1)
    div_term = T.exp(T.arange(0, d_model, 2).float() * \
      (-np.log(10_000.0) / d_model))
    pe[:, 0::2] = T.sin(position * div_term)
    pe[:, 1::2] = T.cos(position * div_term)
    pe = pe.unsqueeze(0)  # (1, max_len, d_model) for batch-first input
    self.register_buffer('pe', pe)  # allows state-save

  def forward(self, x):
    x = x + self.pe[:, :x.size(1), :]  # add encodings along the seq dimension
    return self.dropout(x)

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # logits form

    big_idx = T.argmax(oupt)  # 0, 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def confusion_matrix_multi(model, ds, n_classes):
  if n_classes "lte" 2:  # less-than-or-equal
    print("ERROR: n_classes must be 3 or greater ")
    return None

  cm = np.zeros((n_classes,n_classes), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # actual class 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # logits form
    pred_class = T.argmax(oupt)  # 0,1,2
    cm[Y][pred_class] += 1
  return cm

# -----------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

  # accuracy for each class
  row_sums = np.sum(cm, axis=1)
  accs = np.zeros(dim, dtype=np.float32)
  for i in range(dim):
    accs[i] = cm[i][i] / row_sums[i]
  print("\naccuracy by class: ")
  print(accs)

# -----------------------------------------------------------

def main():
  # 0. setup
  print("\nBegin Transformer demo ")
  np.random.seed(1) 
  T.manual_seed(1)
  np.set_printoptions(precision=4, suppress=True,
    floatmode='fixed')
  T.set_printoptions(precision=4, sci_mode=False)

  # 1. create Dataset
  print("\nCreating 200-item train Dataset from text file ")
  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating Transformer network ")
  net = TransformerNet().to(device)
  
# -----------------------------------------------------------

  # 3. train model
  max_epochs = 2000
  ep_log_interval = 400
  lrn_rate = 0.025

  grad_accum_interval = 200  # for interpretability
  
  loss_func = T.nn.NLLLoss()  # assumes log-softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)
  # optimizer = T.optim.Adam(net.parameters(), lr=lrn_rate)
  
  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("lrn_rate = %0.3f " % lrn_rate)
  print("max_epochs = %3d " % max_epochs)

  acc_batch_grads = np.zeros((bat_size, 6),
    dtype=np.float32)  # accumulated gradients

  print("\nStarting training")
  net.train()  # set mode
  for epoch in range(0, max_epochs):
    ep_loss = 0.0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      (X, y) = batch  # X = predictor values, y = target labels

      X.requires_grad = True

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, y)  # a tensor
      ep_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute grads
      optimizer.step()     # update weights

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   |  loss = %9.4f" % (epoch, ep_loss))

    if epoch % grad_accum_interval == 0:
      curr_batch_grads = X.grad  # [bs, 6] grads of last batch in epoch
      acc_batch_grads += np.abs(curr_batch_grads.numpy())

  print("Done ") 

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)

  net.eval()
  acc_test = accuracy(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  print("\nConfusion matrix test data: ")
  cm = confusion_matrix_multi(net, test_ds, n_classes=3)
  show_confusion(cm)

# -----------------------------------------------------------
  
  # 4b. show interpretability info
  print("\nAvg gradients sex, age, (s1, s2, s3), income: ")
  raw_avg_grads = np.mean(acc_batch_grads, axis=0)
  norm_avg_grads = raw_avg_grads / np.sum(raw_avg_grads)
  
  print("raw:        ", end=""); print(raw_avg_grads)
  print("normalized: ", end=""); print(norm_avg_grads)
  
# -----------------------------------------------------------

  # 5. use model
  # print("\nPredicting politics for M  30  oklahoma  $50,000: ")
  # X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
  # X = T.tensor(X, dtype=T.float32).to(device) 

  # with T.no_grad():
  #   logits = net(X)  # do not sum to 1.0
  # probs = T.exp(logits)  # sum to 1.0
  # probs = probs.numpy()  # numpy vector prints better
  # np.set_printoptions(precision=4, suppress=True)
  # print(probs)

# -----------------------------------------------------------

  # 6. save model
  # print("\nSaving trained model state")
  # fn = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), fn)  

  print("\nEnd Transformer demo ")

if __name__ == "__main__":
  main()

Training data. Replace space-space with tab characters.
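
One way to do the replacement is a tiny one-off script (the file names here are hypothetical):

# convert double-space delimiters to tab characters
with open("people_train_spaces.txt") as fin, \
     open("people_train.txt", "w") as fout:
  for line in fin:
    fout.write(line.replace("  ", "\t"))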

# people_train.txt
# sex (M=-1, F=1)  age  state (michigan, 
# nebraska, oklahoma) income
# politics (conservative, moderate, liberal)
#
1  0.24  1  0  0  0.2950  2
-1  0.39  0  0  1  0.5120  1
1  0.63  0  1  0  0.7580  0
-1  0.36  1  0  0  0.4450  1
1  0.27  0  1  0  0.2860  2
1  0.50  0  1  0  0.5650  1
1  0.50  0  0  1  0.5500  1
-1  0.19  0  0  1  0.3270  0
1  0.22  0  1  0  0.2770  1
-1  0.39  0  0  1  0.4710  2
1  0.34  1  0  0  0.3940  1
-1  0.22  1  0  0  0.3350  0
1  0.35  0  0  1  0.3520  2
-1  0.33  0  1  0  0.4640  1
1  0.45  0  1  0  0.5410  1
1  0.42  0  1  0  0.5070  1
-1  0.33  0  1  0  0.4680  1
1  0.25  0  0  1  0.3000  1
-1  0.31  0  1  0  0.4640  0
1  0.27  1  0  0  0.3250  2
1  0.48  1  0  0  0.5400  1
-1  0.64  0  1  0  0.7130  2
1  0.61  0  1  0  0.7240  0
1  0.54  0  0  1  0.6100  0
1  0.29  1  0  0  0.3630  0
1  0.50  0  0  1  0.5500  1
1  0.55  0  0  1  0.6250  0
1  0.40  1  0  0  0.5240  0
1  0.22  1  0  0  0.2360  2
1  0.68  0  1  0  0.7840  0
-1  0.60  1  0  0  0.7170  2
-1  0.34  0  0  1  0.4650  1
-1  0.25  0  0  1  0.3710  0
-1  0.31  0  1  0  0.4890  1
1  0.43  0  0  1  0.4800  1
1  0.58  0  1  0  0.6540  2
-1  0.55  0  1  0  0.6070  2
-1  0.43  0  1  0  0.5110  1
-1  0.43  0  0  1  0.5320  1
-1  0.21  1  0  0  0.3720  0
1  0.55  0  0  1  0.6460  0
1  0.64  0  1  0  0.7480  0
-1  0.41  1  0  0  0.5880  1
1  0.64  0  0  1  0.7270  0
-1  0.56  0  0  1  0.6660  2
1  0.31  0  0  1  0.3600  1
-1  0.65  0  0  1  0.7010  2
1  0.55  0  0  1  0.6430  0
-1  0.25  1  0  0  0.4030  0
1  0.46  0  0  1  0.5100  1
-1  0.36  1  0  0  0.5350  0
1  0.52  0  1  0  0.5810  1
1  0.61  0  0  1  0.6790  0
1  0.57  0  0  1  0.6570  0
-1  0.46  0  1  0  0.5260  1
-1  0.62  1  0  0  0.6680  2
1  0.55  0  0  1  0.6270  0
-1  0.22  0  0  1  0.2770  1
-1  0.50  1  0  0  0.6290  0
-1  0.32  0  1  0  0.4180  1
-1  0.21  0  0  1  0.3560  0
1  0.44  0  1  0  0.5200  1
1  0.46  0  1  0  0.5170  1
1  0.62  0  1  0  0.6970  0
1  0.57  0  1  0  0.6640  0
-1  0.67  0  0  1  0.7580  2
1  0.29  1  0  0  0.3430  2
1  0.53  1  0  0  0.6010  0
-1  0.44  1  0  0  0.5480  1
1  0.46  0  1  0  0.5230  1
-1  0.20  0  1  0  0.3010  1
-1  0.38  1  0  0  0.5350  1
1  0.50  0  1  0  0.5860  1
1  0.33  0  1  0  0.4250  1
-1  0.33  0  1  0  0.3930  1
1  0.26  0  1  0  0.4040  0
1  0.58  1  0  0  0.7070  0
1  0.43  0  0  1  0.4800  1
-1  0.46  1  0  0  0.6440  0
1  0.60  1  0  0  0.7170  0
-1  0.42  1  0  0  0.4890  1
-1  0.56  0  0  1  0.5640  2
-1  0.62  0  1  0  0.6630  2
-1  0.50  1  0  0  0.6480  1
1  0.47  0  0  1  0.5200  1
-1  0.67  0  1  0  0.8040  2
-1  0.40  0  0  1  0.5040  1
1  0.42  0  1  0  0.4840  1
1  0.64  1  0  0  0.7200  0
-1  0.47  1  0  0  0.5870  2
1  0.45  0  1  0  0.5280  1
-1  0.25  0  0  1  0.4090  0
1  0.38  1  0  0  0.4840  0
1  0.55  0  0  1  0.6000  1
-1  0.44  1  0  0  0.6060  1
1  0.33  1  0  0  0.4100  1
1  0.34  0  0  1  0.3900  1
1  0.27  0  1  0  0.3370  2
1  0.32  0  1  0  0.4070  1
1  0.42  0  0  1  0.4700  1
-1  0.24  0  0  1  0.4030  0
1  0.42  0  1  0  0.5030  1
1  0.25  0  0  1  0.2800  2
1  0.51  0  1  0  0.5800  1
-1  0.55  0  1  0  0.6350  2
1  0.44  1  0  0  0.4780  2
-1  0.18  1  0  0  0.3980  0
-1  0.67  0  1  0  0.7160  2
1  0.45  0  0  1  0.5000  1
1  0.48  1  0  0  0.5580  1
-1  0.25  0  1  0  0.3900  1
-1  0.67  1  0  0  0.7830  1
1  0.37  0  0  1  0.4200  1
-1  0.32  1  0  0  0.4270  1
1  0.48  1  0  0  0.5700  1
-1  0.66  0  0  1  0.7500  2
1  0.61  1  0  0  0.7000  0
-1  0.58  0  0  1  0.6890  1
1  0.19  1  0  0  0.2400  2
1  0.38  0  0  1  0.4300  1
-1  0.27  1  0  0  0.3640  1
1  0.42  1  0  0  0.4800  1
1  0.60  1  0  0  0.7130  0
-1  0.27  0  0  1  0.3480  0
1  0.29  0  1  0  0.3710  0
-1  0.43  1  0  0  0.5670  1
1  0.48  1  0  0  0.5670  1
1  0.27  0  0  1  0.2940  2
-1  0.44  1  0  0  0.5520  0
1  0.23  0  1  0  0.2630  2
-1  0.36  0  1  0  0.5300  2
1  0.64  0  0  1  0.7250  0
1  0.29  0  0  1  0.3000  2
-1  0.33  1  0  0  0.4930  1
-1  0.66  0  1  0  0.7500  2
-1  0.21  0  0  1  0.3430  0
1  0.27  1  0  0  0.3270  2
1  0.29  1  0  0  0.3180  2
-1  0.31  1  0  0  0.4860  1
1  0.36  0  0  1  0.4100  1
1  0.49  0  1  0  0.5570  1
-1  0.28  1  0  0  0.3840  0
-1  0.43  0  0  1  0.5660  1
-1  0.46  0  1  0  0.5880  1
1  0.57  1  0  0  0.6980  0
-1  0.52  0  0  1  0.5940  1
-1  0.31  0  0  1  0.4350  1
-1  0.55  1  0  0  0.6200  2
1  0.50  1  0  0  0.5640  1
1  0.48  0  1  0  0.5590  1
-1  0.22  0  0  1  0.3450  0
1  0.59  0  0  1  0.6670  0
1  0.34  1  0  0  0.4280  2
-1  0.64  1  0  0  0.7720  2
1  0.29  0  0  1  0.3350  2
-1  0.34  0  1  0  0.4320  1
-1  0.61  1  0  0  0.7500  2
1  0.64  0  0  1  0.7110  0
-1  0.29  1  0  0  0.4130  0
1  0.63  0  1  0  0.7060  0
-1  0.29  0  1  0  0.4000  0
-1  0.51  1  0  0  0.6270  1
-1  0.24  0  0  1  0.3770  0
1  0.48  0  1  0  0.5750  1
1  0.18  1  0  0  0.2740  0
1  0.18  1  0  0  0.2030  2
1  0.33  0  1  0  0.3820  2
-1  0.20  0  0  1  0.3480  0
1  0.29  0  0  1  0.3300  2
-1  0.44  0  0  1  0.6300  0
-1  0.65  0  0  1  0.8180  0
-1  0.56  1  0  0  0.6370  2
-1  0.52  0  0  1  0.5840  1
-1  0.29  0  1  0  0.4860  0
-1  0.47  0  1  0  0.5890  1
1  0.68  1  0  0  0.7260  2
1  0.31  0  0  1  0.3600  1
1  0.61  0  1  0  0.6250  2
1  0.19  0  1  0  0.2150  2
1  0.38  0  0  1  0.4300  1
-1  0.26  1  0  0  0.4230  0
1  0.61  0  1  0  0.6740  0
1  0.40  1  0  0  0.4650  1
-1  0.49  1  0  0  0.6520  1
1  0.56  1  0  0  0.6750  0
-1  0.48  0  1  0  0.6600  1
1  0.52  1  0  0  0.5630  2
-1  0.18  1  0  0  0.2980  0
-1  0.56  0  0  1  0.5930  2
-1  0.52  0  1  0  0.6440  1
-1  0.18  0  1  0  0.2860  1
-1  0.58  1  0  0  0.6620  2
-1  0.39  0  1  0  0.5510  1
-1  0.46  1  0  0  0.6290  1
-1  0.40  0  1  0  0.4620  1
-1  0.60  1  0  0  0.7270  2
1  0.36  0  1  0  0.4070  2
1  0.44  1  0  0  0.5230  1
1  0.28  1  0  0  0.3130  2
1  0.54  0  0  1  0.6260  0

Test data.

-1  0.51  1  0  0  0.6120  1
-1  0.32  0  1  0  0.4610  1
1  0.55  1  0  0  0.6270  0
1  0.25  0  0  1  0.2620  2
1  0.33  0  0  1  0.3730  2
-1  0.29  0  1  0  0.4620  0
1  0.65  1  0  0  0.7270  0
-1  0.43  0  1  0  0.5140  1
-1  0.54  0  1  0  0.6480  2
1  0.61  0  1  0  0.7270  0
1  0.52  0  1  0  0.6360  0
1  0.30  0  1  0  0.3350  2
1  0.29  1  0  0  0.3140  2
-1  0.47  0  0  1  0.5940  1
1  0.39  0  1  0  0.4780  1
1  0.47  0  0  1  0.5200  1
-1  0.49  1  0  0  0.5860  1
-1  0.63  0  0  1  0.6740  2
-1  0.30  1  0  0  0.3920  0
-1  0.61  0  0  1  0.6960  2
-1  0.47  0  0  1  0.5870  1
1  0.30  0  0  1  0.3450  2
-1  0.51  0  0  1  0.5800  1
-1  0.24  1  0  0  0.3880  1
-1  0.49  1  0  0  0.6450  1
1  0.66  0  0  1  0.7450  0
-1  0.65  1  0  0  0.7690  0
-1  0.46  0  1  0  0.5800  0
-1  0.45  0  0  1  0.5180  1
-1  0.47  1  0  0  0.6360  0
-1  0.29  1  0  0  0.4480  0
-1  0.57  0  0  1  0.6930  2
-1  0.20  1  0  0  0.2870  2
-1  0.35  1  0  0  0.4340  1
-1  0.61  0  0  1  0.6700  2
-1  0.31  0  0  1  0.3730  1
1  0.18  1  0  0  0.2080  2
1  0.26  0  0  1  0.2920  2
-1  0.28  1  0  0  0.3640  2
-1  0.59  0  0  1  0.6940  2