PyTorch Multi-Class Classification Using MSELoss Instead of NLLLoss

I was looking at PyTorch model calibration and wondered if using mean squared error loss, instead of the standard negative log-likelihood loss, would have any effect.

So, to run an experiment, I needed to refactor one of my standard PyTorch multi-class classification demo programs from NLLLoss() to MSELoss(). Bottom line: it was much trickier than I expected, and far more changes were needed than I had anticipated.

I used one of my standard synthetic datasets. The data looks like:

 1, 0.24, 1, 0, 0, 0.2950, 2
-1, 0.39, 0, 0, 1, 0.5120, 1
 1, 0.63, 0, 1, 0, 0.7580, 0
-1, 0.36, 1, 0, 0, 0.4450, 1
 1, 0.27, 0, 1, 0, 0.2860, 2
. . .

The fields are sex (M = -1, F = 1), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by $100,000), and political leaning (conservative = 0, moderate = 1, liberal = 2). The goal is to predict political leaning from sex, age, State, and income. There are 200 training items and 40 test items.
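As a rough illustration of the normalization just described, the hypothetical helper encode_person() below turns one raw record into the six predictor values used in the data files. The function name and the string keys for sex and State are my assumptions for illustration only. The demo program reads files that are already encoded and never performs this step.

```python
# Hypothetical helper (not part of the demo program): encode one raw
# record into the six predictor values used in people_train.txt.
STATE_ONEHOT = {
    "michigan": [1, 0, 0],
    "nebraska": [0, 1, 0],
    "oklahoma": [0, 0, 1],
}

def encode_person(sex, age, state, income):
    """sex is 'M' or 'F', age in years, state name, income in dollars."""
    sex_val = -1 if sex == "M" else 1          # M = -1, F = +1
    age_val = age / 100                        # age divided by 100
    state_vals = STATE_ONEHOT[state.lower()]   # one-hot State
    income_val = income / 100_000              # income divided by $100,000
    return [sex_val, age_val] + state_vals + [income_val]

print(encode_person("M", 39, "Oklahoma", 51200))
# prints [-1, 0.39, 0, 0, 1, 0.512]
```

The result matches the second row of the data snapshot, minus the politics label.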



When using NLLLoss(), the target class is expected to be ordinal encoded (0, 1, or 2), as in the data snapshot above. But MSELoss() compares each output value against a corresponding target value, so the target class must be one-hot encoded (for example, 2 becomes 0, 0, 1). There are two choices: 1.) encode the target data in a preprocessing step, or 2.) programmatically encode the target data when reading it into memory. Both approaches are OK, but I chose to encode programmatically so that I didn't need an extra pair of train and test data files.
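For reference, here is a minimal NumPy sketch of the ordinal-to-one-hot conversion. It uses identity-matrix row indexing instead of the explicit loop in the demo's vec_to_onehot() function; the names onehot() and ordinal() are hypothetical. The reverse direction, argmax() along each row, is what the MSELoss() accuracy functions rely on to recover a class index from a one-hot target.

```python
import numpy as np

def onehot(labels, n_classes):
    # Row i of the identity matrix is the one-hot vector for class i,
    # so indexing with the label array builds all rows at once.
    labels = np.asarray(labels, dtype=np.int64).reshape(-1)
    return np.eye(n_classes, dtype=np.float32)[labels]

def ordinal(onehot_rows):
    # The position of the 1.0 in each row is the ordinal class label.
    return np.argmax(onehot_rows, axis=1)

y = np.array([2, 1, 0, 1, 2])      # politics column from the snapshot
y_oh = onehot(y, n_classes=3)
print(y_oh)
print(ordinal(y_oh))               # [2 1 0 1 2]
```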

Almost every part of my baseline NLLLoss version program needed some modifications.

1. Target data must be converted from ordinal to one-hot. I wrote a short custom function, vec_to_onehot(), to avoid the large overhead of using something like the scikit-learn sklearn.preprocessing.OneHotEncoder class.

2. The neural network class must use torch.nn.functional.softmax() output activation instead of torch.log_softmax() activation. This means the output values for the MSELoss() version are pseudo-probabilities that sum to 1, instead of log-softmax values (log-probabilities, which are negative and don't sum to 1).

3. A custom accuracy() function for the MSELoss() version must compare the index of the largest computed pseudo-probability against the index of the 1 in the one-hot target vector, instead of comparing the index of the largest value in the log-softmax output vector against the ordinal-encoded target value.

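4. When training the MSELoss() version, you typically need about twice as many training epochs and a learning rate about five times as large. (If you know how the two loss functions work, this makes sense: the gradient of MSE through softmax is much smaller than the gradient of NLL through log-softmax, especially when the model is confidently wrong, and MSELoss(reduction='mean') averages over all batch-size times number-of-classes elements, while NLLLoss() averages over the batch items only.)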

5. When using the trained model to make a prediction with the MSELoss() version, the model emits pseudo-probabilities directly, instead of emitting log-softmax values that need the exp() function applied, as in the NLLLoss() version.
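To make points 2, 4, and 5 concrete, here is a small NumPy sketch for a single item with three classes. It is my own re-derivation of the math, not PyTorch's implementation. It computes the NLLLoss() value from log-softmax outputs, the MSELoss() value (mean over the three elements) from softmax outputs, and the gradient of each loss with respect to the raw output-layer values z. For NLL through log-softmax the gradient is p - y; for MSE through softmax it is the softmax Jacobian applied to (2/3)(p - y), which shrinks toward zero when the softmax saturates. The example values of z are arbitrary.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))          # shift for numerical stability
    return e / e.sum()

def log_softmax(z):
    s = z - np.max(z)
    return s - np.log(np.exp(s).sum())

def nll_loss(z, target_idx):
    # NLL on log-softmax output: negative log-prob of the true class
    return -log_softmax(z)[target_idx]

def mse_loss(z, y_onehot):
    # mean-reduced MSE on softmax output: average over the C elements
    return np.mean((softmax(z) - y_onehot) ** 2)

def grad_nll(z, y_onehot):
    # d(NLL)/dz through log-softmax is simply p - y
    return softmax(z) - y_onehot

def grad_mse(z, y_onehot):
    # d(MSE)/dp = (2/C)(p - y), then chain through the symmetric
    # softmax Jacobian J = diag(p) - p p^T, so J @ g = p*g - p*(p . g)
    p = softmax(z)
    g = 2.0 * (p - y_onehot) / len(z)
    return p * g - p * np.dot(p, g)

z = np.array([3.0, 0.0, 0.0])          # raw outputs favor class 0
y = np.array([0.0, 0.0, 1.0])          # but the true class is 2

print("NLL loss:", nll_loss(z, 2), " MSE loss:", mse_loss(z, y))
print("|grad NLL| =", np.linalg.norm(grad_nll(z, y)))
print("|grad MSE| =", np.linalg.norm(grad_mse(z, y)))
# exp() of log-softmax gives the same pseudo-probabilities as softmax
print(np.allclose(np.exp(log_softmax(z)), softmax(z)))   # True
```

For this confidently wrong item, the MSE gradient norm comes out roughly 0.09 versus roughly 1.3 for NLL, which is consistent with point 4: with smaller gradients, the MSELoss() version needs a larger learning rate and more epochs to make comparable progress.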

Anyway, after about three hours of futzing about, I was able to compare a standard PyTorch NLLLoss() multi-class classifier with a roughly equivalent MSELoss() classifier. The calibration errors of the two trained models were roughly similar, so there was no indication that the choice of loss function affects calibration, at least in this one inconclusive experiment.
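For readers unfamiliar with the metric: reading the calibration_error() function in the demo program, the calibration error appears to be the common expected calibration error (ECE) with ten equal-width bins on each item's largest pseudo-probability. With N items in total, n_b items in bin b, acc(b) the fraction of bin-b items predicted correctly, and conf(b) their average largest pseudo-probability:

```latex
\mathrm{CE} = \sum_{b=1}^{10} \frac{n_b}{N}\,\bigl|\,\mathrm{acc}(b) - \mathrm{conf}(b)\,\bigr|
```

A perfectly calibrated model, for example one whose 0.8-confidence predictions are right 80% of the time, would score 0.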

But good fun!



I often create several different versions of PyTorch programs so that I can investigate the effects of the changes. Book covers for popular novels often have many different versions, and each has quite a different feel.

The James Bond novel “Thunderball” was first published in the UK in March 1961 as the ninth in the series. It was intended to be the basis of the first Bond movie, but legal issues held things up and the first movie was “Dr. No” in 1962. The movie “Thunderball” was eventually released in 1965 as the fourth Bond film (following “Dr. No”, “From Russia With Love”, and “Goldfinger”).

Left: This is the cover of the first UK edition (March, 1961), published by the Jonathan Cape company. The artist is Richard Chopping (1917-2008), who became quite famous for doing the UK first edition covers for the first nine novels.

Center Left: This is the cover of the first US hardcover edition (April, 1961), published by Viking Press. The artwork was done by an anonymous artist who worked for a company called SA Summit, Inc. This cover is somewhat bland in my opinion.

Center Right: The 1964 US softcover edition by Signet Publishing. Art by Barye Phillips (1924-1969). This is the version I read as a young man. The cover has sort of a menacing, dream-like quality.

Right: The 1965 US softcover edition by Signet Publishing. It features art by the well-known Robert McGinnis (b. 1926). This art also served as the basis for the official movie posters. This cover has a very different feel: sort of a 1960s cool and whimsical vibe.


Demo program. If any comparisons appear as “lt” (less than), “gt”, “lte”, or “gte”, replace them with the corresponding comparison operator symbols. (My lame blog editor often chokes on those symbols.)

# people_politics_calibration_error_MSE.py
# predict politics type from sex, age, state, income
# PyTorch 2.2.1-CPU Anaconda3-2023.09  Python 3.11.5
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   0 0 1 (or 2)
  # +1   0.19   0  0  1   0.6550   0 1 0 (or 1)
  # sex: -1 = male, +1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter=",", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6].astype(np.int64)
    tmp_y = vec_to_onehot(tmp_y, n_classes=3)   

    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.float32).to(device) 

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

def vec_to_onehot(vec, n_classes):
  # vec is a 1D array or an r x 1 2D array, float or int
  # result is an r x n_classes float32 array
  vc = vec.reshape(-1).astype(np.int64)
  result = np.zeros((len(vc), n_classes), dtype=np.float32)
  for i in range(len(result)):
    result[i][vc[i]] = 1.0
  return result

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    # z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss() 
    z = T.nn.functional.softmax(self.oupt(z),
      dim=1) # MSELoss 
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1]  # one-hot target like (0, 1, 0)
    with T.no_grad():
      oupt = model(X)  # softmax form

    computed_idx = T.argmax(oupt)
    target_idx = T.argmax(Y)
    if computed_idx == target_idx:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def accuracy_quick(model, dataset):
  # assumes model.eval()
  X = dataset[0:len(dataset)][0]
  Y = dataset[0:len(dataset)][1]
  with T.no_grad():
    oupt = model(X)  #  [40,3]  pseudo-probs

  
  comp_maxs = T.argmax(oupt, dim=1)  # argmax() is new
  target_maxs = T.argmax(Y, dim=1)
  num_correct = T.sum(target_maxs==comp_maxs)
  acc = (num_correct * 1.0 / len(dataset))
  return acc.item()

# -----------------------------------------------------------

def acc_by_class(model, dataset, n_classes):
  n_corrects = np.zeros(n_classes, dtype=np.int64)
  n_wrongs = np.zeros(n_classes, dtype=np.int64)
  counts = np.zeros(n_classes, dtype=np.int64)

  for i in range(len(dataset)):
    X = dataset[i][0].reshape(1,-1)  # make it a batch
    Y = dataset[i][1]  # like (0, 1, 0)
    with T.no_grad():
      oupt = model(X)  # pseudo-probs (softmax form)

    computed_idx = T.argmax(oupt)
    target_idx = T.argmax(Y)

    counts[target_idx.item()] += 1

    if computed_idx == target_idx:
      n_corrects[target_idx.item()] += 1
    else:
      n_wrongs[target_idx.item()] += 1

  print("Counts     : ", end="")
  for c in range(n_classes):
    print("%8d" % counts[c], end="")
  print("")

  print("Correct    : ", end="")
  for c in range(n_classes): 
    print("%8d" % n_corrects[c], end="")
  print("")

  print("Wrong      : ", end="")
  for c in range(n_classes): 
    print("%8d" % n_wrongs[c], end="")
  print("")

  accuracies = n_corrects / counts
  print("Accuracies : ", end="")
  for c in range(n_classes): 
    print("%8.4f" % accuracies[c], end="")
  print("")
    
# -----------------------------------------------------------

def calibration_error(model, ds):
  counts = np.zeros(10, dtype=np.int64)  # of PPs each bin
  sums = np.zeros(10, dtype=np.float32)  # of PPs each bin
  n_corrects = np.zeros(10, dtype=np.int64)  # for each bin
  n_wrongs = np.zeros(10, dtype=np.int64)  # not needed
  accuracies = np.zeros(10, dtype=np.float32)  # each bin
  avg_pps = np.zeros(10, dtype=np.float32)
  abs_diffs = np.zeros(10, dtype=np.float32)

  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1]  # like (0, 1, 0)
    with T.no_grad():
      oupt = model(X)  # pps like (0.2, 0.7, 0.1)

    probs = oupt # probs like (0.25, 0.60, 0.15)
    p_max = T.max(probs)  # largest PP like tensor[0.60]
    pp = p_max.item()  # scalar like 0.60

    correct = False
    computed_idx = T.argmax(oupt)  # 0 or 1 or 2
    target_idx = T.argmax(Y)
    if computed_idx == target_idx:
      correct = True
 
    if 0.0 <= pp < 0.1: bin = 0
    elif 0.1 <= pp < 0.2: bin = 1
    elif 0.2 <= pp < 0.3: bin = 2
    elif 0.3 <= pp < 0.4: bin = 3
    elif 0.4 <= pp < 0.5: bin = 4
    elif 0.5 <= pp < 0.6: bin = 5
    elif 0.6 <= pp < 0.7: bin = 6
    elif 0.7 <= pp < 0.8: bin = 7
    elif 0.8 <= pp < 0.9: bin = 8
    elif 0.9 <= pp <= 1.0: bin = 9

    counts[bin] += 1
    sums[bin] += pp
    if correct: n_corrects[bin] += 1
    else: n_wrongs[bin] += 1  # check

  for bin in range(10):
    if counts[bin] == 0: accuracies[bin] = 0.0
    else: accuracies[bin] = n_corrects[bin] / counts[bin]

  for bin in range(10):
    if counts[bin] == 0: avg_pps[bin] = 0.0
    else: avg_pps[bin] = sums[bin] / counts[bin]

  for bin in range(10):
    abs_diffs[bin] = \
      np.abs(avg_pps[bin] - accuracies[bin]) 

  cal_err = 0.0
  for bin in range(10):
    cal_err += counts[bin] * abs_diffs[bin]  # weighted
  cal_err /= len(ds)
  return cal_err

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin People predict politics using MSE loss ")
  T.manual_seed(0)
  np.random.seed(0)
  
  # 1. create DataLoader objects
  print("\nCreating People Datasets ")

  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)    # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

# -----------------------------------------------------------

  # 2. create network
  print("\nCreating 6-(10-10)-3 neural network ")
  net = Net().to(device)
  net.train()  # set mode

# -----------------------------------------------------------

  # 3. train model
  max_epochs = 2000
  ep_log_interval = 400
  lrn_rate = 0.05

  # loss_func = T.nn.NLLLoss()  # assumes log_softmax()
  loss_func = T.nn.MSELoss(reduction='mean') # assume softmax

  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training")
  for epoch in range(0, max_epochs):
    # T.manual_seed(epoch+1)  # checkpoint reproducibility
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # inputs
      Y = batch[1]  # correct class/label/politics

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %5d  |  loss = %10.4f" % \
        (epoch, epoch_loss))

  print("Training done ")

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nEvaluating model ")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy_quick(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  print("\nAccuracy on test data by class: ")
  acc_by_class(net, test_ds, 3)

  ce_train = calibration_error(net, train_ds)
  print("\nCalibration error on train data = %0.4f " % ce_train)
  ce_test = calibration_error(net, test_ds)
  print("Calibration error on test data = %0.4f " % ce_test)

  # 5. make a prediction
  print("\nPredicting politics for M  30  oklahoma  $50,000: ")
  X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device) 

  with T.no_grad():
    probs = net(X)  # do sum to 1.0
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True)
  print(probs)

  # 6. save model (state_dict approach)
  print("\nSaving trained model state")
  # fn = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), fn)

  # saved_model = Net()  # requires class definition
  # saved_model.load_state_dict(T.load(fn))
  # use saved_model to make prediction(s)

  print("\nEnd People predict politics demo")

if __name__ == "__main__":
  main()
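
The calibration_error() function above uses an explicit loop and an if/elif chain, which is easy to follow. As an alternative, here is a NumPy sketch that computes the same ten-bin quantity in vectorized form from an (N, 3) array of pseudo-probabilities and an array of ordinal targets. The name ece_numpy() is hypothetical, and this is a sketch rather than a drop-in replacement; I use np.digitize() with inner edges 0.1 through 0.9 so that values on a bin edge land in the same bin as with the if/elif comparisons.

```python
import numpy as np

def ece_numpy(probs, targets, n_bins=10):
    """Calibration error with equal-width bins on the max pseudo-prob.

    probs   : (N, C) array of pseudo-probabilities (softmax outputs)
    targets : (N,) array of ordinal class labels
    """
    probs = np.asarray(probs, dtype=np.float64)
    targets = np.asarray(targets).reshape(-1)
    conf = probs.max(axis=1)                        # largest pseudo-prob
    correct = (probs.argmax(axis=1) == targets).astype(np.float64)

    # Inner edges 0.1, 0.2, ..., 0.9: digitize returns bin i when
    # edges[i-1] <= conf < edges[i], and bin 9 for conf >= 0.9.
    edges = np.arange(1, n_bins) / n_bins
    bins = np.digitize(conf, edges)

    # n_b * |avg_conf_b - acc_b| equals |sum_conf_b - n_correct_b|,
    # so empty bins contribute zero and nothing is divided by zero.
    conf_sums = np.bincount(bins, weights=conf, minlength=n_bins)
    hit_sums = np.bincount(bins, weights=correct, minlength=n_bins)
    return np.abs(conf_sums - hit_sums).sum() / len(conf)

probs = np.array([[0.95, 0.03, 0.02],
                  [0.02, 0.95, 0.03],
                  [0.55, 0.45, 0.00],
                  [0.55, 0.45, 0.00]])
print(ece_numpy(probs, [0, 1, 0, 1]))   # about 0.05

# Possible use with the demo's objects (assumes net.eval() was called):
#   with T.no_grad():
#     p = net(test_ds.x_data).numpy()
#   y = test_ds.y_data.numpy().argmax(axis=1)   # one-hot -> ordinal
#   print(ece_numpy(p, y))
```

The demo accumulates sums in float32 while this sketch uses float64, so results may differ in the last few decimal places.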

Training data:

# people_train.txt
# sex (M=-1, F=1)  age  state (michigan, 
# nebraska, oklahoma) income
# politics (conservative, moderate, liberal)
#
 1, 0.24, 1, 0, 0, 0.2950, 2
-1, 0.39, 0, 0, 1, 0.5120, 1
 1, 0.63, 0, 1, 0, 0.7580, 0
-1, 0.36, 1, 0, 0, 0.4450, 1
 1, 0.27, 0, 1, 0, 0.2860, 2
 1, 0.50, 0, 1, 0, 0.5650, 1
 1, 0.50, 0, 0, 1, 0.5500, 1
-1, 0.19, 0, 0, 1, 0.3270, 0
 1, 0.22, 0, 1, 0, 0.2770, 1
-1, 0.39, 0, 0, 1, 0.4710, 2
 1, 0.34, 1, 0, 0, 0.3940, 1
-1, 0.22, 1, 0, 0, 0.3350, 0
 1, 0.35, 0, 0, 1, 0.3520, 2
-1, 0.33, 0, 1, 0, 0.4640, 1
 1, 0.45, 0, 1, 0, 0.5410, 1
 1, 0.42, 0, 1, 0, 0.5070, 1
-1, 0.33, 0, 1, 0, 0.4680, 1
 1, 0.25, 0, 0, 1, 0.3000, 1
-1, 0.31, 0, 1, 0, 0.4640, 0
 1, 0.27, 1, 0, 0, 0.3250, 2
 1, 0.48, 1, 0, 0, 0.5400, 1
-1, 0.64, 0, 1, 0, 0.7130, 2
 1, 0.61, 0, 1, 0, 0.7240, 0
 1, 0.54, 0, 0, 1, 0.6100, 0
 1, 0.29, 1, 0, 0, 0.3630, 0
 1, 0.50, 0, 0, 1, 0.5500, 1
 1, 0.55, 0, 0, 1, 0.6250, 0
 1, 0.40, 1, 0, 0, 0.5240, 0
 1, 0.22, 1, 0, 0, 0.2360, 2
 1, 0.68, 0, 1, 0, 0.7840, 0
-1, 0.60, 1, 0, 0, 0.7170, 2
-1, 0.34, 0, 0, 1, 0.4650, 1
-1, 0.25, 0, 0, 1, 0.3710, 0
-1, 0.31, 0, 1, 0, 0.4890, 1
 1, 0.43, 0, 0, 1, 0.4800, 1
 1, 0.58, 0, 1, 0, 0.6540, 2
-1, 0.55, 0, 1, 0, 0.6070, 2
-1, 0.43, 0, 1, 0, 0.5110, 1
-1, 0.43, 0, 0, 1, 0.5320, 1
-1, 0.21, 1, 0, 0, 0.3720, 0
 1, 0.55, 0, 0, 1, 0.6460, 0
 1, 0.64, 0, 1, 0, 0.7480, 0
-1, 0.41, 1, 0, 0, 0.5880, 1
 1, 0.64, 0, 0, 1, 0.7270, 0
-1, 0.56, 0, 0, 1, 0.6660, 2
 1, 0.31, 0, 0, 1, 0.3600, 1
-1, 0.65, 0, 0, 1, 0.7010, 2
 1, 0.55, 0, 0, 1, 0.6430, 0
-1, 0.25, 1, 0, 0, 0.4030, 0
 1, 0.46, 0, 0, 1, 0.5100, 1
-1, 0.36, 1, 0, 0, 0.5350, 0
 1, 0.52, 0, 1, 0, 0.5810, 1
 1, 0.61, 0, 0, 1, 0.6790, 0
 1, 0.57, 0, 0, 1, 0.6570, 0
-1, 0.46, 0, 1, 0, 0.5260, 1
-1, 0.62, 1, 0, 0, 0.6680, 2
 1, 0.55, 0, 0, 1, 0.6270, 0
-1, 0.22, 0, 0, 1, 0.2770, 1
-1, 0.50, 1, 0, 0, 0.6290, 0
-1, 0.32, 0, 1, 0, 0.4180, 1
-1, 0.21, 0, 0, 1, 0.3560, 0
 1, 0.44, 0, 1, 0, 0.5200, 1
 1, 0.46, 0, 1, 0, 0.5170, 1
 1, 0.62, 0, 1, 0, 0.6970, 0
 1, 0.57, 0, 1, 0, 0.6640, 0
-1, 0.67, 0, 0, 1, 0.7580, 2
 1, 0.29, 1, 0, 0, 0.3430, 2
 1, 0.53, 1, 0, 0, 0.6010, 0
-1, 0.44, 1, 0, 0, 0.5480, 1
 1, 0.46, 0, 1, 0, 0.5230, 1
-1, 0.20, 0, 1, 0, 0.3010, 1
-1, 0.38, 1, 0, 0, 0.5350, 1
 1, 0.50, 0, 1, 0, 0.5860, 1
 1, 0.33, 0, 1, 0, 0.4250, 1
-1, 0.33, 0, 1, 0, 0.3930, 1
 1, 0.26, 0, 1, 0, 0.4040, 0
 1, 0.58, 1, 0, 0, 0.7070, 0
 1, 0.43, 0, 0, 1, 0.4800, 1
-1, 0.46, 1, 0, 0, 0.6440, 0
 1, 0.60, 1, 0, 0, 0.7170, 0
-1, 0.42, 1, 0, 0, 0.4890, 1
-1, 0.56, 0, 0, 1, 0.5640, 2
-1, 0.62, 0, 1, 0, 0.6630, 2
-1, 0.50, 1, 0, 0, 0.6480, 1
 1, 0.47, 0, 0, 1, 0.5200, 1
-1, 0.67, 0, 1, 0, 0.8040, 2
-1, 0.40, 0, 0, 1, 0.5040, 1
 1, 0.42, 0, 1, 0, 0.4840, 1
 1, 0.64, 1, 0, 0, 0.7200, 0
-1, 0.47, 1, 0, 0, 0.5870, 2
 1, 0.45, 0, 1, 0, 0.5280, 1
-1, 0.25, 0, 0, 1, 0.4090, 0
 1, 0.38, 1, 0, 0, 0.4840, 0
 1, 0.55, 0, 0, 1, 0.6000, 1
-1, 0.44, 1, 0, 0, 0.6060, 1
 1, 0.33, 1, 0, 0, 0.4100, 1
 1, 0.34, 0, 0, 1, 0.3900, 1
 1, 0.27, 0, 1, 0, 0.3370, 2
 1, 0.32, 0, 1, 0, 0.4070, 1
 1, 0.42, 0, 0, 1, 0.4700, 1
-1, 0.24, 0, 0, 1, 0.4030, 0
 1, 0.42, 0, 1, 0, 0.5030, 1
 1, 0.25, 0, 0, 1, 0.2800, 2
 1, 0.51, 0, 1, 0, 0.5800, 1
-1, 0.55, 0, 1, 0, 0.6350, 2
 1, 0.44, 1, 0, 0, 0.4780, 2
-1, 0.18, 1, 0, 0, 0.3980, 0
-1, 0.67, 0, 1, 0, 0.7160, 2
 1, 0.45, 0, 0, 1, 0.5000, 1
 1, 0.48, 1, 0, 0, 0.5580, 1
-1, 0.25, 0, 1, 0, 0.3900, 1
-1, 0.67, 1, 0, 0, 0.7830, 1
 1, 0.37, 0, 0, 1, 0.4200, 1
-1, 0.32, 1, 0, 0, 0.4270, 1
 1, 0.48, 1, 0, 0, 0.5700, 1
-1, 0.66, 0, 0, 1, 0.7500, 2
 1, 0.61, 1, 0, 0, 0.7000, 0
-1, 0.58, 0, 0, 1, 0.6890, 1
 1, 0.19, 1, 0, 0, 0.2400, 2
 1, 0.38, 0, 0, 1, 0.4300, 1
-1, 0.27, 1, 0, 0, 0.3640, 1
 1, 0.42, 1, 0, 0, 0.4800, 1
 1, 0.60, 1, 0, 0, 0.7130, 0
-1, 0.27, 0, 0, 1, 0.3480, 0
 1, 0.29, 0, 1, 0, 0.3710, 0
-1, 0.43, 1, 0, 0, 0.5670, 1
 1, 0.48, 1, 0, 0, 0.5670, 1
 1, 0.27, 0, 0, 1, 0.2940, 2
-1, 0.44, 1, 0, 0, 0.5520, 0
 1, 0.23, 0, 1, 0, 0.2630, 2
-1, 0.36, 0, 1, 0, 0.5300, 2
 1, 0.64, 0, 0, 1, 0.7250, 0
 1, 0.29, 0, 0, 1, 0.3000, 2
-1, 0.33, 1, 0, 0, 0.4930, 1
-1, 0.66, 0, 1, 0, 0.7500, 2
-1, 0.21, 0, 0, 1, 0.3430, 0
 1, 0.27, 1, 0, 0, 0.3270, 2
 1, 0.29, 1, 0, 0, 0.3180, 2
-1, 0.31, 1, 0, 0, 0.4860, 1
 1, 0.36, 0, 0, 1, 0.4100, 1
 1, 0.49, 0, 1, 0, 0.5570, 1
-1, 0.28, 1, 0, 0, 0.3840, 0
-1, 0.43, 0, 0, 1, 0.5660, 1
-1, 0.46, 0, 1, 0, 0.5880, 1
 1, 0.57, 1, 0, 0, 0.6980, 0
-1, 0.52, 0, 0, 1, 0.5940, 1
-1, 0.31, 0, 0, 1, 0.4350, 1
-1, 0.55, 1, 0, 0, 0.6200, 2
 1, 0.50, 1, 0, 0, 0.5640, 1
 1, 0.48, 0, 1, 0, 0.5590, 1
-1, 0.22, 0, 0, 1, 0.3450, 0
 1, 0.59, 0, 0, 1, 0.6670, 0
 1, 0.34, 1, 0, 0, 0.4280, 2
-1, 0.64, 1, 0, 0, 0.7720, 2
 1, 0.29, 0, 0, 1, 0.3350, 2
-1, 0.34, 0, 1, 0, 0.4320, 1
-1, 0.61, 1, 0, 0, 0.7500, 2
 1, 0.64, 0, 0, 1, 0.7110, 0
-1, 0.29, 1, 0, 0, 0.4130, 0
 1, 0.63, 0, 1, 0, 0.7060, 0
-1, 0.29, 0, 1, 0, 0.4000, 0
-1, 0.51, 1, 0, 0, 0.6270, 1
-1, 0.24, 0, 0, 1, 0.3770, 0
 1, 0.48, 0, 1, 0, 0.5750, 1
 1, 0.18, 1, 0, 0, 0.2740, 0
 1, 0.18, 1, 0, 0, 0.2030, 2
 1, 0.33, 0, 1, 0, 0.3820, 2
-1, 0.20, 0, 0, 1, 0.3480, 0
 1, 0.29, 0, 0, 1, 0.3300, 2
-1, 0.44, 0, 0, 1, 0.6300, 0
-1, 0.65, 0, 0, 1, 0.8180, 0
-1, 0.56, 1, 0, 0, 0.6370, 2
-1, 0.52, 0, 0, 1, 0.5840, 1
-1, 0.29, 0, 1, 0, 0.4860, 0
-1, 0.47, 0, 1, 0, 0.5890, 1
 1, 0.68, 1, 0, 0, 0.7260, 2
 1, 0.31, 0, 0, 1, 0.3600, 1
 1, 0.61, 0, 1, 0, 0.6250, 2
 1, 0.19, 0, 1, 0, 0.2150, 2
 1, 0.38, 0, 0, 1, 0.4300, 1
-1, 0.26, 1, 0, 0, 0.4230, 0
 1, 0.61, 0, 1, 0, 0.6740, 0
 1, 0.40, 1, 0, 0, 0.4650, 1
-1, 0.49, 1, 0, 0, 0.6520, 1
 1, 0.56, 1, 0, 0, 0.6750, 0
-1, 0.48, 0, 1, 0, 0.6600, 1
 1, 0.52, 1, 0, 0, 0.5630, 2
-1, 0.18, 1, 0, 0, 0.2980, 0
-1, 0.56, 0, 0, 1, 0.5930, 2
-1, 0.52, 0, 1, 0, 0.6440, 1
-1, 0.18, 0, 1, 0, 0.2860, 1
-1, 0.58, 1, 0, 0, 0.6620, 2
-1, 0.39, 0, 1, 0, 0.5510, 1
-1, 0.46, 1, 0, 0, 0.6290, 1
-1, 0.40, 0, 1, 0, 0.4620, 1
-1, 0.60, 1, 0, 0, 0.7270, 2
 1, 0.36, 0, 1, 0, 0.4070, 2
 1, 0.44, 1, 0, 0, 0.5230, 1
 1, 0.28, 1, 0, 0, 0.3130, 2
 1, 0.54, 0, 0, 1, 0.6260, 0

Test data:

# people_test.txt
#
-1, 0.51, 1, 0, 0, 0.6120, 1
-1, 0.32, 0, 1, 0, 0.4610, 1
 1, 0.55, 1, 0, 0, 0.6270, 0
 1, 0.25, 0, 0, 1, 0.2620, 2
 1, 0.33, 0, 0, 1, 0.3730, 2
-1, 0.29, 0, 1, 0, 0.4620, 0
 1, 0.65, 1, 0, 0, 0.7270, 0
-1, 0.43, 0, 1, 0, 0.5140, 1
-1, 0.54, 0, 1, 0, 0.6480, 2
 1, 0.61, 0, 1, 0, 0.7270, 0
 1, 0.52, 0, 1, 0, 0.6360, 0
 1, 0.30, 0, 1, 0, 0.3350, 2
 1, 0.29, 1, 0, 0, 0.3140, 2
-1, 0.47, 0, 0, 1, 0.5940, 1
 1, 0.39, 0, 1, 0, 0.4780, 1
 1, 0.47, 0, 0, 1, 0.5200, 1
-1, 0.49, 1, 0, 0, 0.5860, 1
-1, 0.63, 0, 0, 1, 0.6740, 2
-1, 0.30, 1, 0, 0, 0.3920, 0
-1, 0.61, 0, 0, 1, 0.6960, 2
-1, 0.47, 0, 0, 1, 0.5870, 1
 1, 0.30, 0, 0, 1, 0.3450, 2
-1, 0.51, 0, 0, 1, 0.5800, 1
-1, 0.24, 1, 0, 0, 0.3880, 1
-1, 0.49, 1, 0, 0, 0.6450, 1
 1, 0.66, 0, 0, 1, 0.7450, 0
-1, 0.65, 1, 0, 0, 0.7690, 0
-1, 0.46, 0, 1, 0, 0.5800, 0
-1, 0.45, 0, 0, 1, 0.5180, 1
-1, 0.47, 1, 0, 0, 0.6360, 0
-1, 0.29, 1, 0, 0, 0.4480, 0
-1, 0.57, 0, 0, 1, 0.6930, 2
-1, 0.20, 1, 0, 0, 0.2870, 2
-1, 0.35, 1, 0, 0, 0.4340, 1
-1, 0.61, 0, 0, 1, 0.6700, 2
-1, 0.31, 0, 0, 1, 0.3730, 1
 1, 0.18, 1, 0, 0, 0.2080, 2
 1, 0.26, 0, 0, 1, 0.2920, 2
-1, 0.28, 1, 0, 0, 0.3640, 2
-1, 0.59, 0, 0, 1, 0.6940, 2

This entry was posted in PyTorch.