The Wheat Seeds Dataset Problem Using PyTorch

I was looking at the Wheat Seeds dataset problem recently. The goal is to predict the species of a wheat seed (Kama, Rosa, or Canadian) from seven numeric predictors (seed length, width, and so on). I implemented a PyTorch neural network classifier.

As is often the case, data preparation was time-consuming. The raw data is at archive.ics.uci.edu/ml/datasets/seeds and looks like:

15.26  14.84  0.871   5.763  3.312  2.221  5.22   1
14.88  14.57  0.8811  5.554  3.333  1.018  4.956  1
. . .
17.63  15.98  0.8673  6.191  3.561  4.076  6.06   2
16.84  15.67  0.8623  5.998  3.484  4.675  5.877  2
. . .
11.84  13.21  0.8521  5.175  2.836  3.598  5.044  3
12.3   13.34  0.8684  5.243  2.974  5.637  5.063  3

---------------------------------------------------
10.59  12.41  0.8081  4.899  2.63   0.765  4.519 (min values)
21.18  17.25  0.9183  6.675  4.033  8.456  6.55  (max values)

There are 210 data items. Each represents one of three species of wheat seeds: Kama, Rosa, Canadian. There are 70 of each species. The first seven values on each line are the predictors: area, perimeter, compactness, length, width, asymmetry, groove. The eighth value is the one-based encoded species.

The ranges of the raw predictor values vary significantly. To normalize, I used the divide-by-constant technique. I dropped the raw data into an Excel spreadsheet and divided the seven predictor columns by (25, 20, 1, 10, 10, 10, 10). I also re-coded the target class labels from 1-based to 0-based. The resulting 210 normalized items looked like:

0.6104  0.7420  0.8710  0.5763  0.3312  0.2221  0.5220  0
0.5952  0.7285  0.8811  0.5554  0.3333  0.1018  0.4956  0
. . .
0.7052  0.7990  0.8673  0.6191  0.3561  0.4076  0.6060  1
0.6736  0.7835  0.8623  0.5998  0.3484  0.4675  0.5877  1
. . .
0.5048  0.6835  0.8481  0.5410  0.2911  0.3306  0.5231  2
0.5104  0.6690  0.8964  0.5073  0.3155  0.2828  0.4830  2
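The same divide-by-constant normalization could be scripted rather than done in a spreadsheet. Here is a minimal NumPy sketch, assuming the raw data has been saved locally (the file names are placeholders; the demo data was actually prepared in Excel):

import numpy as np

raw = np.loadtxt("seeds_dataset.txt", dtype=np.float32)  # 210 x 8, file name illustrative
divisors = np.array([25, 20, 1, 10, 10, 10, 10], dtype=np.float32)
normed = raw.copy()
normed[:, 0:7] = raw[:, 0:7] / divisors  # divide-by-constant normalization
normed[:, 7] = raw[:, 7] - 1             # re-code species 1,2,3 to 0,1,2
np.savetxt("wheat_normed.txt", normed,
  fmt=["%0.4f"]*7 + ["%d"], delimiter=",")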

I split the 210-item normalized data into a 180-item training set and a 30-item test set. I used the first 60 of each target class for training and the last 10 of each target class for testing.
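Because the 210 items are grouped by class (70 Kama, then 70 Rosa, then 70 Canadian), the split is just array slicing. Continuing the sketch above with the normed array:

train_rows = []; test_rows = []
for c in range(3):                   # class c occupies rows [70*c, 70*c + 70)
  block = normed[70*c : 70*c + 70]
  train_rows.append(block[:60])      # first 60 of each class for training
  test_rows.append(block[60:])       # last 10 of each class for testing
train_data = np.vstack(train_rows)   # 180 x 8
test_data = np.vstack(test_rows)     # 30 x 8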

I had a lot more trouble than I expected tuning the hyperparameters for my PyTorch classifier. For the architecture I ended up with a 7-(20-20)-3 design with tanh() hidden activation, log_softmax() output activation, and dropout(0.50) on the first hidden layer.

Most of my attempts showed severe overfitting: accuracy on the training data was close to 100% but accuracy on the test data was only around 60%. I tried L2 regularization without much luck, but 50% dropout on the first hidden layer gave decent results.
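For reference, L2 regularization in PyTorch is usually applied through the optimizer's weight_decay parameter rather than by adding a penalty term to the loss by hand. A minimal sketch (the decay value is illustrative, not a tuned setting; T is the torch alias used in the demo code below):

optimizer = T.optim.SGD(net.parameters(), lr=0.001,
  weight_decay=0.001)  # adds an L2 penalty on the weights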

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(7, 20)  # 7-(20-20)-3
    self.drop1 = T.nn.Dropout(0.50)
    self.hid2 = T.nn.Linear(20, 20)
    self.oupt = T.nn.Linear(20, 3)

  # init weights here 

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = self.drop1(z)
    z = T.tanh(self.hid2(z))
    z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss() 
    return z

For training, I used SGD optimization with a 0.001 learning rate, NLLLoss() loss, a batch size of 10, and 1,000 epochs.
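One detail worth noting: NLLLoss() expects log-probabilities, which is why forward() ends with log_softmax(). An equivalent alternative, not used in this demo, is to return the raw output-layer values and use CrossEntropyLoss(), which applies log_softmax() internally:

loss_func = T.nn.CrossEntropyLoss()  # log_softmax() + NLLLoss() combined
# forward() would then end with:  return self.oupt(z)  -- no log_softmax()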

It was an interesting exploration. Good fun.



Wheat is good so wheat seeds are good. Left: “The Bad Seed” (1956) is a scary movie about a not-so-nice adopted girl, Rhoda, who was the result of a “bad seed” (a murderous father). Center: In “The Omen” (1976), adopted Damien came from a very bad seed — the devil, literally. Right: In “The Other” (1972), adopted boys Niles and his twin brother Holland always seem to be around when something very bad happens. But we never actually see the brothers together at the same time. Honorable Mention Bad Seed: adopted child Esther in “Orphan” (2009).


Demo code:

# wheat.py
# predict wheat species from 7 predictors
# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class WheatDataset(T.utils.data.Dataset):
  # seven numeric predictors then species
  # species: 0 = Kama, 1 = Rosa, 2 = Canadian

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,8),
      delimiter="\t", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:7]   # cols [0,7) = [0,6]
    tmp_y = all_xy[:,7]     # 1-D

    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(7, 20)  # 7-(20-20)-3
    self.drop1 = T.nn.Dropout(0.50)
    self.hid2 = T.nn.Linear(20, 20)
    self.oupt = T.nn.Linear(20, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = self.drop1(z)
    z = T.tanh(self.hid2(z))
    z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss() 
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # log-softmax outputs

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def confusion_matrix_multi(model, ds, n_classes):
  if n_classes <= 2:
    print("ERROR: n_classes must be 3 or greater ")
    return None

  cm = np.zeros((n_classes,n_classes), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # actual class 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # log-softmax outputs
    pred_class = T.argmax(oupt)  # 0,1,2
    cm[Y.item()][pred_class.item()] += 1  # index with plain Python ints
  return cm

# -----------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin Wheat Seeds species prediction ")
  T.manual_seed(1)
  np.random.seed(1)
  
  # 1. create DataLoader objects
  print("\nCreating Wheat Datasets ")

  # div by constant normed: (25,20,1,10,10,10,10)
  train_file = ".\\Data\\wheat_train_k.txt"
  train_ds = WheatDataset(train_file)  # 180 rows

  test_file = ".\\Data\\wheat_test_k.txt"
  test_ds = WheatDataset(test_file)    # 30 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

# -----------------------------------------------------------

  # 2. create network
  print("\nCreating 7-(20-20)-3 tanh neural network ")
  net = Net().to(device)
  net.train()

# -----------------------------------------------------------

  # min-max (1) 7-(20-20)-3 relu SGD 10 2000 0.015 = 83.33%
  # k (1) 7-(20-20)-3 tanh SGD 10 1000 0.001 = 80.00%

  # 3. train model
  max_epochs = 1000
  ep_log_interval = 200
  lrn_rate = 0.001

  loss_func = T.nn.NLLLoss()  # assumes log_softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)
  # optimizer = T.optim.Adam(net.parameters(), lr=lrn_rate,
  #   weight_decay=0.001)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD ")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training drop=0.50 on hid1")
  for epoch in range(0, max_epochs):
    # T.manual_seed(epoch+1)  # checkpoint reproducibility
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # inputs
      Y = batch[1]  # correct class/label/species

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %5d  |  loss = %10.4f" % \
        (epoch, epoch_loss))

  print("Done ")

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  print("\nComputing confusion matrix \n ")
  cm = confusion_matrix_multi(net, test_ds, n_classes=3)
  show_confusion(cm)

  # 5. make a prediction
  print("\nPredicting for 0.5 0.5 0.5 0.5 0.5 0.5 0.5 ")
  X = np.array([[0.5,0.5,0.5,0.5,0.5,0.5,0.5]],
    dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device) 

  with T.no_grad():
    logits = net(X)  # do not sum to 1.0
  probs = T.exp(logits)  # sum to 1.0
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True)
  print(probs)

  # 6. save model (state_dict approach)
  print("\nSaving trained model state")
  # fn = ".\\Models\\wheat_model.pt"
  # T.save(net.state_dict(), fn)

  # saved_model = Net()  # requires class definition
  # saved_model.load_state_dict(T.load(fn))
  # use saved_model to make prediction(s)

  print("\nEnd Wheat Seeds demo ")

if __name__ == "__main__":
  main()

Training data. Replace commas with tabs or modify program code.
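If you keep the commas rather than converting to tabs, the only code change needed is the delimiter in the WheatDataset loader, along the lines of:

    all_xy = np.loadtxt(src_file, usecols=range(0,8),
      delimiter=",", comments="#", dtype=np.float32)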

# wheat_train_k.txt
# http://archive.ics.uci.edu/ml/datasets/seeds
# 210 total items. train is first 60 each of 3 classes
# 180 training, 30 test
# area, perimeter, compactness, length, width, asymmetry, groove
# predictors are div-by-k normed: (25, 20, 1, 10, 10, 10, 10)
# 0 = Kama, 1 = Rosa, 2 = Canadian
#
0.6104,0.7420,0.8710,0.5763,0.3312,0.2221,0.5220,0
0.5952,0.7285,0.8811,0.5554,0.3333,0.1018,0.4956,0
0.5716,0.7045,0.9050,0.5291,0.3337,0.2699,0.4825,0
0.5536,0.6970,0.8955,0.5324,0.3379,0.2259,0.4805,0
0.6456,0.7495,0.9034,0.5658,0.3562,0.1355,0.5175,0
0.5752,0.7105,0.8951,0.5386,0.3312,0.2462,0.4956,0
0.5876,0.7245,0.8799,0.5563,0.3259,0.3586,0.5219,0
0.5644,0.7050,0.8911,0.5420,0.3302,0.2700,0.5000,0
0.6652,0.7730,0.8747,0.6053,0.3465,0.2040,0.5877,0
0.6576,0.7625,0.8880,0.5884,0.3505,0.1969,0.5533,0
0.6104,0.7425,0.8696,0.5714,0.3242,0.4543,0.5314,0
0.5612,0.7080,0.8796,0.5438,0.3201,0.1717,0.5001,0
0.5556,0.7010,0.8880,0.5439,0.3199,0.3986,0.4738,0
0.5512,0.7030,0.8759,0.5479,0.3156,0.3136,0.4872,0
0.5496,0.7025,0.8744,0.5482,0.3114,0.2932,0.4825,0
0.5836,0.7140,0.8993,0.5351,0.3333,0.4185,0.4781,0
0.5596,0.6915,0.9183,0.5119,0.3383,0.5234,0.4781,0
0.6276,0.7375,0.9058,0.5527,0.3514,0.1599,0.5046,0
0.5880,0.7105,0.9153,0.5205,0.3466,0.1767,0.4649,0
0.5088,0.6785,0.8686,0.5226,0.3049,0.4102,0.4914,0
0.5664,0.7200,0.8584,0.5658,0.3129,0.3072,0.5176,0
0.5644,0.7130,0.8722,0.5520,0.3168,0.2688,0.5219,0
0.6352,0.7450,0.8988,0.5618,0.3507,0.0765,0.5091,0
0.4832,0.6615,0.8664,0.5099,0.2936,0.1415,0.4961,0
0.6004,0.7380,0.8657,0.5789,0.3245,0.1791,0.5001,0
0.6476,0.7580,0.8849,0.5833,0.3421,0.0903,0.5307,0
0.5208,0.6880,0.8641,0.5395,0.3026,0.3373,0.4825,0
0.5096,0.6835,0.8564,0.5395,0.2956,0.2504,0.4869,0
0.5644,0.7090,0.8820,0.5541,0.3221,0.2754,0.5038,0
0.5380,0.7010,0.8604,0.5516,0.3065,0.3531,0.5097,0
0.5264,0.6910,0.8662,0.5454,0.2975,0.0855,0.5056,0
0.6196,0.7470,0.8724,0.5757,0.3371,0.3412,0.5228,0
0.5636,0.7205,0.8529,0.5717,0.3186,0.3920,0.5299,0
0.5576,0.7085,0.8728,0.5585,0.3150,0.2124,0.5012,0
0.6020,0.7340,0.8779,0.5712,0.3328,0.2129,0.5360,0
0.6448,0.7500,0.9000,0.5709,0.3485,0.2270,0.5443,0
0.6480,0.7635,0.8734,0.5826,0.3464,0.2823,0.5527,0
0.6832,0.7690,0.9079,0.5832,0.3683,0.2956,0.5484,0
0.5920,0.7260,0.8823,0.5656,0.3288,0.3112,0.5309,0
0.5712,0.7085,0.8944,0.5397,0.3298,0.6685,0.5001,0
0.5416,0.6925,0.8871,0.5348,0.3156,0.2587,0.5178,0
0.5400,0.6925,0.8852,0.5351,0.3158,0.2249,0.5176,0
0.5264,0.6775,0.9009,0.5138,0.3201,0.2461,0.4783,0
0.6200,0.7430,0.8820,0.5877,0.3396,0.4711,0.5528,0
0.6044,0.7270,0.8986,0.5579,0.3462,0.3128,0.5180,0
0.5520,0.7020,0.8794,0.5376,0.3155,0.1560,0.4961,0
0.6144,0.7380,0.8861,0.5701,0.3393,0.1367,0.5132,0
0.5996,0.7280,0.8883,0.5570,0.3377,0.2958,0.5175,0
0.5916,0.7260,0.8819,0.5545,0.3291,0.2704,0.5111,0
0.5944,0.7335,0.8676,0.5678,0.3258,0.2129,0.5351,0
0.5772,0.7200,0.8751,0.5585,0.3272,0.3975,0.5144,0
0.6312,0.7455,0.8923,0.5674,0.3434,0.5593,0.5136,0
0.5796,0.7305,0.8538,0.5715,0.3113,0.4116,0.5396,0
0.5732,0.7140,0.8831,0.5504,0.3199,0.3328,0.5224,0
0.5808,0.7300,0.8557,0.5741,0.3113,0.1481,0.5487,0
0.6012,0.7385,0.8658,0.5702,0.3212,0.1933,0.5439,0
0.5784,0.7175,0.8818,0.5388,0.3377,0.2802,0.5044,0
0.5968,0.7215,0.9006,0.5384,0.3412,0.1142,0.5088,0
0.6152,0.7385,0.8857,0.5662,0.3419,0.1999,0.5222,0
0.4844,0.6735,0.8392,0.5159,0.3032,0.1502,0.4519,0
0.7052,0.7990,0.8673,0.6191,0.3561,0.4076,0.6060,1
0.6736,0.7835,0.8623,0.5998,0.3484,0.4675,0.5877,1
0.6904,0.7865,0.8763,0.5978,0.3594,0.4539,0.5791,1
0.7644,0.8130,0.9081,0.6154,0.3930,0.2936,0.6079,1
0.6728,0.7755,0.8786,0.6017,0.3486,0.4004,0.5841,1
0.6708,0.7810,0.8638,0.5927,0.3438,0.4920,0.5795,1
0.6928,0.7955,0.8599,0.6064,0.3403,0.3824,0.5922,1
0.8284,0.8615,0.8763,0.6579,0.3814,0.4451,0.6451,1
0.7576,0.8245,0.8750,0.6445,0.3639,0.5064,0.6362,1
0.6848,0.7775,0.8892,0.5850,0.3566,0.2858,0.5746,1
0.6612,0.7670,0.8823,0.5875,0.3467,0.5532,0.5880,1
0.7488,0.8095,0.8977,0.6006,0.3857,0.5324,0.5879,1
0.8080,0.8445,0.8894,0.6285,0.3864,0.5173,0.6187,1
0.7828,0.8370,0.8779,0.6384,0.3772,0.1472,0.6273,1
0.7804,0.8355,0.8780,0.6366,0.3801,0.2962,0.6185,1
0.7308,0.8045,0.8870,0.6173,0.3651,0.2443,0.6197,1
0.7552,0.8130,0.8969,0.6084,0.3764,0.1649,0.6109,1
0.7592,0.8330,0.8590,0.6549,0.3670,0.3691,0.6498,1
0.8472,0.8605,0.8989,0.6573,0.4033,0.5780,0.6231,1
0.8352,0.8525,0.9031,0.6450,0.4032,0.5016,0.6321,1
0.8040,0.8495,0.8746,0.6581,0.3785,0.1955,0.6449,1
0.7504,0.8100,0.8984,0.6172,0.3796,0.3120,0.6053,1
0.7524,0.8145,0.8906,0.6272,0.3693,0.3237,0.6053,1
0.7436,0.8025,0.9066,0.6037,0.3860,0.6001,0.5877,1
0.7344,0.8260,0.8452,0.6666,0.3485,0.4933,0.6448,1
0.6748,0.7825,0.8648,0.6139,0.3463,0.3696,0.5967,1
0.7724,0.8295,0.8815,0.6341,0.3810,0.3477,0.6238,1
0.7592,0.8285,0.8687,0.6449,0.3552,0.2144,0.6453,1
0.7268,0.8130,0.8637,0.6271,0.3512,0.2853,0.6273,1
0.7488,0.8170,0.8810,0.6219,0.3684,0.2188,0.6097,1
0.6564,0.7625,0.8866,0.5718,0.3525,0.4217,0.5618,1
0.7196,0.7930,0.8992,0.5890,0.3694,0.2068,0.5837,1
0.7784,0.8250,0.8985,0.6113,0.3892,0.4308,0.6009,1
0.7672,0.8315,0.8717,0.6369,0.3681,0.3357,0.6229,1
0.7580,0.8210,0.8829,0.6248,0.3755,0.3368,0.6148,1
0.7532,0.8145,0.8917,0.6037,0.3786,0.2553,0.5879,1
0.7540,0.8085,0.9056,0.6152,0.3806,0.2843,0.6200,1
0.7052,0.7930,0.8800,0.6033,0.3573,0.3747,0.5929,1
0.7976,0.8460,0.8752,0.6675,0.3763,0.3252,0.6550,1
0.7420,0.8110,0.8865,0.6153,0.3674,0.1738,0.5894,1
0.7380,0.8060,0.8921,0.6107,0.3769,0.2235,0.5794,1
0.7752,0.8360,0.8716,0.6303,0.3791,0.3678,0.5965,1
0.7652,0.8155,0.9035,0.6183,0.3902,0.2109,0.5924,1
0.7656,0.8305,0.8722,0.6259,0.3737,0.6682,0.6053,1
0.8388,0.8625,0.8859,0.6563,0.3991,0.4677,0.6316,1
0.7624,0.8225,0.8854,0.6416,0.3719,0.2248,0.6163,1
0.7584,0.8100,0.9077,0.6051,0.3897,0.4334,0.5750,1
0.7660,0.8225,0.8890,0.6245,0.3815,0.3084,0.6185,1
0.7556,0.8115,0.9008,0.6227,0.3769,0.3639,0.5966,1
0.8012,0.8450,0.8811,0.6493,0.3857,0.3063,0.6320,1
0.8096,0.8455,0.8897,0.6315,0.3962,0.5901,0.6188,1
0.7256,0.8060,0.8772,0.6059,0.3563,0.3619,0.6011,1
0.6468,0.7690,0.8588,0.5762,0.3387,0.4286,0.5703,1
0.7372,0.7985,0.9077,0.5980,0.3771,0.2984,0.5905,1
0.6396,0.7445,0.9064,0.5363,0.3582,0.3336,0.5144,1
0.7500,0.8090,0.8999,0.6111,0.3869,0.4188,0.5992,1
0.7460,0.8205,0.8698,0.6285,0.3594,0.4391,0.6102,1
0.7192,0.7925,0.8993,0.5979,0.3687,0.2257,0.5919,1
0.8064,0.8515,0.8735,0.6513,0.3773,0.1910,0.6185,1
0.7020,0.7830,0.8991,0.5791,0.3690,0.5366,0.5661,1
0.5228,0.6960,0.8480,0.5472,0.2994,0.5304,0.5395,2
0.5328,0.6970,0.8613,0.5541,0.3073,0.7035,0.5440,2
0.5336,0.6975,0.8620,0.5389,0.3074,0.5995,0.5307,2
0.4888,0.6660,0.8652,0.5224,0.2967,0.5469,0.5221,2
0.4728,0.6700,0.8274,0.5314,0.2777,0.4471,0.5178,2
0.4484,0.6565,0.8167,0.5279,0.2687,0.6169,0.5275,2
0.4572,0.6565,0.8335,0.5176,0.2719,0.2221,0.5132,2
0.4996,0.6730,0.8658,0.5267,0.2967,0.4421,0.5002,2
0.5080,0.6855,0.8491,0.5386,0.2911,0.3260,0.5316,2
0.4316,0.6465,0.8107,0.5317,0.2648,0.5462,0.5194,2
0.4732,0.6615,0.8496,0.5263,0.2840,0.5195,0.5307,2
0.4804,0.6760,0.8249,0.5405,0.2776,0.6992,0.5270,2
0.4904,0.6800,0.8333,0.5408,0.2833,0.4756,0.5360,2
0.4472,0.6520,0.8266,0.5220,0.2693,0.3332,0.5001,2
0.4544,0.6525,0.8382,0.5175,0.2755,0.4048,0.5263,2
0.4476,0.6525,0.8253,0.5250,0.2675,0.5813,0.5219,2
0.4536,0.6435,0.8596,0.5053,0.2849,0.3347,0.5003,2
0.4852,0.6865,0.8081,0.5394,0.2745,0.4825,0.5220,2
0.4700,0.6760,0.8082,0.5444,0.2678,0.4378,0.5310,2
0.4596,0.6610,0.8263,0.5304,0.2695,0.5388,0.5310,2
0.5016,0.6835,0.8425,0.5451,0.2879,0.3082,0.5491,2
0.4808,0.6665,0.8503,0.5350,0.2810,0.4271,0.5308,2
0.4820,0.6705,0.8416,0.5267,0.2847,0.4988,0.5046,2
0.5020,0.6785,0.8558,0.5333,0.2968,0.4419,0.5176,2
0.4456,0.6395,0.8558,0.5011,0.2794,0.6388,0.5049,2
0.4840,0.6575,0.8793,0.5105,0.2941,0.2201,0.5056,2
0.4976,0.6795,0.8462,0.5319,0.2897,0.4924,0.5270,2
0.4860,0.6725,0.8443,0.5417,0.2837,0.3638,0.5338,2
0.4540,0.6560,0.8291,0.5176,0.2668,0.4337,0.5132,2
0.4496,0.6500,0.8359,0.5090,0.2715,0.3521,0.5088,2
0.4408,0.6500,0.8189,0.5325,0.2701,0.6735,0.5163,2
0.4620,0.6550,0.8455,0.5167,0.2845,0.6715,0.4956,2
0.4508,0.6485,0.8419,0.5088,0.2763,0.4309,0.5000,2
0.4560,0.6540,0.8375,0.5136,0.2763,0.5588,0.5089,2
0.4332,0.6480,0.8099,0.5278,0.2641,0.5182,0.5185,2
0.4320,0.6285,0.8590,0.4981,0.2821,0.4773,0.5063,2
0.4504,0.6505,0.8355,0.5186,0.2710,0.5335,0.5092,2
0.4296,0.6365,0.8329,0.5145,0.2642,0.4702,0.4963,2
0.4592,0.6525,0.8473,0.5180,0.2758,0.5876,0.5002,2
0.4884,0.6735,0.8453,0.5357,0.2893,0.1661,0.5178,2
0.4564,0.6475,0.8560,0.5090,0.2775,0.4957,0.4825,2
0.4984,0.6705,0.8706,0.5236,0.3017,0.4987,0.5147,2
0.4876,0.6680,0.8579,0.5240,0.2909,0.4857,0.5158,2
0.4660,0.6535,0.8575,0.5108,0.2850,0.5209,0.5135,2
0.5156,0.6885,0.8541,0.5495,0.3026,0.6185,0.5316,2
0.4624,0.6655,0.8198,0.5363,0.2683,0.4062,0.5182,2
0.4724,0.6725,0.8198,0.5413,0.2716,0.4898,0.5352,2
0.4364,0.6400,0.8372,0.5088,0.2675,0.4179,0.4956,2
0.4492,0.6410,0.8594,0.5089,0.2821,0.7524,0.4957,2
0.4236,0.6205,0.8648,0.4899,0.2787,0.4975,0.4794,2
0.4372,0.6400,0.8390,0.5046,0.2717,0.5398,0.5045,2
0.4508,0.6430,0.8563,0.5091,0.2804,0.3985,0.5001,2
0.4748,0.6510,0.8795,0.5132,0.2953,0.3597,0.5132,2
0.4328,0.6415,0.8256,0.5180,0.2630,0.4853,0.5089,2
0.4844,0.6635,0.8639,0.5236,0.2975,0.4132,0.5012,2
0.5120,0.6735,0.8860,0.5160,0.3126,0.4873,0.4914,2
0.5116,0.6765,0.8786,0.5224,0.3054,0.5483,0.4958,2
0.5348,0.6890,0.8849,0.5320,0.3128,0.4670,0.5091,2
0.5048,0.6835,0.8481,0.5410,0.2911,0.3306,0.5231,2
0.5104,0.6690,0.8964,0.5073,0.3155,0.2828,0.4830,2

Test data:

# wheat_test_k.txt
#
0.4568,0.6430,0.8683,0.5008,0.2850,0.2700,0.4607,0
0.4492,0.6315,0.8840,0.4902,0.2879,0.2269,0.4703,0
0.4944,0.6595,0.8923,0.5076,0.3042,0.3220,0.4605,0
0.5288,0.6920,0.8680,0.5395,0.3070,0.4157,0.5088,0
0.5112,0.6785,0.8716,0.5262,0.3026,0.1176,0.4782,0
0.5152,0.6750,0.8879,0.5139,0.3119,0.2352,0.4607,0
0.5736,0.7185,0.8726,0.5630,0.3190,0.1313,0.5150,0
0.5604,0.7145,0.8625,0.5609,0.3158,0.2217,0.5132,0
0.5748,0.7195,0.8726,0.5569,0.3153,0.1464,0.5300,0
0.5092,0.6875,0.8458,0.5412,0.2882,0.3533,0.5067,0
0.7320,0.7945,0.9108,0.5979,0.3755,0.2837,0.5962,1
0.7576,0.8160,0.8942,0.6144,0.3825,0.2908,0.5949,1
0.6152,0.7450,0.8706,0.5884,0.3268,0.4462,0.5795,1
0.6464,0.7665,0.8644,0.5845,0.3395,0.4266,0.5795,1
0.6224,0.7445,0.8823,0.5776,0.3408,0.4972,0.5847,1
0.6152,0.7330,0.8990,0.5477,0.3465,0.3600,0.5439,1
0.6944,0.7880,0.8785,0.6145,0.3574,0.3526,0.5971,1
0.6228,0.7575,0.8527,0.5920,0.3231,0.2640,0.5879,1
0.6240,0.7555,0.8580,0.5832,0.3286,0.2725,0.5752,1
0.6492,0.7590,0.8850,0.5872,0.3472,0.3769,0.5922,1
0.4952,0.6720,0.8609,0.5219,0.2989,0.5472,0.5045,2
0.5068,0.6660,0.8977,0.4984,0.3135,0.2300,0.4745,2
0.4472,0.6360,0.8680,0.5009,0.2810,0.4051,0.4828,2
0.5080,0.6705,0.8874,0.5183,0.3091,0.8456,0.5000,2
0.4948,0.6735,0.8567,0.5204,0.2960,0.3919,0.5001,2
0.4876,0.6600,0.8783,0.5137,0.2981,0.3631,0.4870,2
0.4492,0.6440,0.8511,0.5140,0.2795,0.4325,0.5003,2
0.5280,0.6830,0.8883,0.5236,0.3232,0.8315,0.5056,2
0.4736,0.6605,0.8521,0.5175,0.2836,0.3598,0.5044,2
0.4920,0.6670,0.8684,0.5243,0.2974,0.5637,0.5063,2