Gaussian Process Classification on the Wheat Seeds Dataset Using the scikit Library

A classification problem is one where the goal is to predict a single categorical value. For example, you might want to predict the sex of a person (0 = male, 1 = female) based on age, income, and so on (a binary classification problem). Or you might want to predict the species of a wheat seed (0 = Kama, 1 = Rosa, 2 = Canadian) based on predictors like seed length, width, and so on (a multi-class problem).

There are many completely different techniques for classification. Examples include logistic regression (binary classification), naive Bayes (different versions depending on the data types of the predictor variables), decision trees (and many tree-based variations such as AdaBoost), and neural networks.

A relatively rare technique for classification (among my colleagues at least) is Gaussian process classification (GPC). When GPC works, it often works very well, but when GPC fails, it often fails spectacularly.

I put together a brief demo using the scikit GaussianProcessClassifier class. For my demo I used the Wheat Seeds Dataset. There are 210 source data items, each representing one of three species of wheat seeds, and there are 7 predictor variables.

As is usually the case, preparing the data was time-consuming and took far longer than building the machine learning model. I used divide-by-k normalization on the 7 predictors. The divide constants I used are (25, 20, 1, 10, 10, 10, 10) so that all predictors are between 0.0 and 1.0. After normalizing the 210 data items, I used the first 60 of each class/species for a 180-item training dataset, and the last 10 of each class/species as a 30-item test dataset. The resulting data looks like:

0.6104  0.7420  0.8710  0.5763  0.3312  0.2221  0.5220  0
0.5952  0.7285  0.8811  0.5554  0.3333  0.1018  0.4956  0
. . .
0.7052  0.7990  0.8673  0.6191  0.3561  0.4076  0.6060  1
0.6736  0.7835  0.8623  0.5998  0.3484  0.4675  0.5877  1
. . .
0.5048  0.6835  0.8481  0.5410  0.2911  0.3306  0.5231  2
0.5104  0.6690  0.8964  0.5073  0.3155  0.2828  0.4830  2
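Here is a minimal sketch of the divide-by-k normalization and the per-class train/test split described above. It is not my actual data preparation script. It assumes the raw data is already in a NumPy array with the 7 raw predictors in columns 0 to 6 and the species label (0, 1, 2) in column 7, with 70 items per species. The helper name normalize_and_split() is hypothetical, and the fake raw data exists only so the snippet runs on its own.

```python
import numpy as np

# divide constants from the text, one per predictor column
DIVIDE_K = np.array([25.0, 20.0, 1.0, 10.0, 10.0, 10.0, 10.0])

def normalize_and_split(raw, n_train_per_class=60, n_test_per_class=10):
  # hypothetical helper: raw is an (n, 8) array, columns 0-6 are
  # raw predictors and column 7 is the species label
  X = raw[:, 0:7] / DIVIDE_K              # divide-by-k normalization
  y = raw[:, 7].astype(np.int64)
  train_idx, test_idx = [], []
  for c in np.unique(y):
    idx = np.flatnonzero(y == c)          # rows of this species, file order
    train_idx.extend(idx[:n_train_per_class])   # first 60 of the species
    test_idx.extend(idx[-n_test_per_class:])    # last 10 of the species
  return X[train_idx], y[train_idx], X[test_idx], y[test_idx]

# fake raw data (NOT the real wheat data): uniform values roughly in the
# raw predictor ranges, 70 rows per species, just to exercise the helper
rng = np.random.default_rng(0)
lo = np.array([10.5, 12.4, 0.80, 4.8, 2.6, 0.7, 4.5])
hi = np.array([21.2, 17.3, 0.92, 6.7, 4.1, 8.5, 6.6])
raw = np.hstack([rng.uniform(lo, hi, size=(210, 7)),
                 np.repeat([0, 1, 2], 70).reshape(-1, 1)])

train_X, train_y, test_X, test_y = normalize_and_split(raw)
print(train_X.shape, test_X.shape)   # (180, 7) (30, 7)
```

Because the first 60 and last 10 of each 70-item species block don't overlap, no item ends up in both the training and test data.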

Using scikit GPC is simultaneously easy and difficult. The key code that creates my demo classification model is:

  print("Creating GPC model with default RBF kernel ")
  krnl = 1.0 * RBF(1.0)  # RBF parameter will be optimized
  model = GaussianProcessClassifier(kernel=krnl, random_state=0)

  print("Training model ")
  model.fit(train_X, train_y)
  print("Done ")

The code is deceptively simple because the main effort goes into choosing which kernel functions to use and what their parameters should be. In most GPC scenarios you spend a lot of time trying different kernel functions with different parameters. Surprisingly, the default RBF (radial basis function) kernel worked very well and gave 80.00% accuracy (24 out of 30 correct) on the test data. This is pretty much the same result I get when using a sophisticated PyTorch neural network classifier.
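Because most of the work is in choosing a kernel, a small loop that fits one model per candidate kernel and compares results can save time. The sketch below is an illustration, not my demo code. It uses the kernels from the commented-out lines in the demo, and so that it runs on its own it substitutes synthetic three-class data from scikit's make_classification() for the wheat files. On the real data you'd pass train_X, train_y, test_X, test_y from the demo instead, and the accuracy values would be different.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import (RBF, DotProduct,
  WhiteKernel, ConstantKernel)

# synthetic stand-in data (NOT the wheat data): 210 items, 7 predictors
X, y = make_classification(n_samples=210, n_features=7, n_informative=5,
  n_redundant=0, n_classes=3, random_state=1)
train_X, test_X = X[:180], X[180:]
train_y, test_y = y[:180], y[180:]

# candidate kernels taken from the demo's commented-out alternatives
candidates = {
  "1.0 * RBF(1.0)": 1.0 * RBF(1.0),
  "DotProduct + White": DotProduct() + WhiteKernel(noise_level=0.5),
  "Constant * RBF(10.0)": ConstantKernel(1.0, (1e-1, 1e3)) *
    RBF(10.0, (1e-3, 1e3)),
}

results = {}
for name, krnl in candidates.items():
  model = GaussianProcessClassifier(kernel=krnl, random_state=0)
  model.fit(train_X, train_y)   # kernel hyperparameters are tuned here
  acc = model.score(test_X, test_y)   # fraction of test items correct
  results[name] = acc
  print("%-22s test acc = %0.4f  log-marginal-likelihood = %0.2f" %
    (name, acc, model.log_marginal_likelihood_value_))
```

Test accuracy on 30 items is a noisy signal, so the log-marginal-likelihood of the fitted kernel is a useful second opinion when two kernels give similar accuracy.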

Gaussian process classification is very complex mathematically. It is an extension of Gaussian process regression (where the output is a single numeric value), plus a logistic sigmoid link function to squash the output to a probability (for binary classification), plus a one-vs-rest scheme for multi-class classification.
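The one-vs-rest part can be made concrete. The sketch below, on synthetic data rather than the wheat data, fits one binary GaussianProcessClassifier per class, where each binary model answers "is this item class k or not?", and then normalizes the three per-class probabilities so each row sums to 1. As far as I can tell from the scikit source, this is essentially what GaussianProcessClassifier does internally with its default multi_class='one_vs_rest' setting, so the manual and built-in probabilities should agree closely.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

# synthetic 3-class stand-in data (NOT the wheat data)
X, y = make_classification(n_samples=90, n_features=7, n_informative=5,
  n_redundant=0, n_classes=3, random_state=0)

# one binary GPC per class: "is it class k, or not?"
cols = []
for k in range(3):
  bin_model = GaussianProcessClassifier(kernel=1.0 * RBF(1.0),
    random_state=0)
  bin_model.fit(X, (y == k).astype(int))
  # column 1 is the binary model's probability of label 1 (class k)
  cols.append(bin_model.predict_proba(X[:5])[:, 1])
manual = np.column_stack(cols)
manual /= manual.sum(axis=1, keepdims=True)   # normalize rows to sum to 1

# built-in multi-class GPC (one-vs-rest is the default scheme)
model = GaussianProcessClassifier(kernel=1.0 * RBF(1.0), random_state=0)
model.fit(X, y)
builtin = model.predict_proba(X[:5])

print(manual)
print(builtin)
```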

There are dozens of blog posts and YouTube videos that explain GPR and GPC at many different levels. There's no single best explanation of GPR or GPC; it all depends on what sort of background you have. Do you already understand kernels or not? Do you already understand multivariate Gaussian distributions or not? Do you already understand covariance matrices or not? Do you already understand the logistic sigmoid function or not? And so on.
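For a concrete anchor on the kernel and covariance matrix questions, the sketch below computes the covariance matrix that a 1.0 * RBF(1.0) kernel, like the one in my demo, produces for three items. It does so once by hand from the RBF formula k(x, x') = c * exp(-||x - x'||^2 / (2 * l^2)) with c = 1.0 and length scale l = 1.0, and once by calling the scikit kernel object directly. The three items are made-up values that merely resemble normalized wheat data.

```python
import numpy as np
from sklearn.gaussian_process.kernels import RBF

# three made-up, already-normalized 7-predictor items
X = np.array([[0.61, 0.74, 0.87, 0.58, 0.33, 0.22, 0.52],
              [0.70, 0.80, 0.87, 0.62, 0.36, 0.41, 0.61],
              [0.50, 0.68, 0.85, 0.54, 0.29, 0.33, 0.52]])

c, ell = 1.0, 1.0    # constant (signal variance) and RBF length scale

# by hand: K[i, j] = c * exp(-||x_i - x_j||^2 / (2 * ell^2))
sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)
K_manual = c * np.exp(-sq_dists / (2.0 * ell ** 2))

# same thing with the scikit kernel object (calling it returns K)
krnl = c * RBF(length_scale=ell)
K_sklearn = krnl(X)

print(K_manual)
print(np.allclose(K_manual, K_sklearn))   # True
```

Each diagonal entry is 1.0 (an item is perfectly similar to itself), and entries shrink toward 0 as items get farther apart. During fit(), scikit tunes c and the length scale to maximize the log-marginal-likelihood.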

Good fun though.



When Gaussian process regression or classification models fail, they often fail spectacularly. I grew up during the early days of the U.S. space program when there were spectacular successes and failures.

Left: The very first U.S. attempt to launch a satellite on a Vanguard rocket in 1957 was an embarrassing failure but the lessons learned eventually led to the first man on the moon just 12 years later.

Center: A Juno rocket failed a few seconds after launch in 1959 when its guidance system malfunctioned. Less than two years later, on May 5, 1961, Alan Shepard sat on top of a modified Redstone rocket, a close relative of the Juno, and became the first American in space. Tremendous bravery!

Right: An Atlas rocket failed on the launch pad in September 1961. The design of the Atlas required internal pressure to retain structural integrity, and loss of pressure was disastrous. Just five months later, on February 20, 1962, John Glenn sat on top of an Atlas booster and became the first American in orbit. What incredible courage the early American astronauts and Soviet cosmonauts had! Sadly, NASA has lost its way and the upcoming mission to the moon is a politically correct embarrassment with two clearly unqualified astronauts, and Russia’s invasion of Ukraine is beyond despicable. But I’m optimistic that NASA will eventually get back on track and sanity will eventually return to Russia.


Demo code. The training and test data can be found at https://jamesmccaffrey.wordpress.com/2023/04/04/the-wheat-seeds-dataset-problem-using-pytorch/.

# wheat_gauss_process_classify.py
# Gaussian process classification

# Anaconda3-2022.10  Python 3.9.13
# scikit 1.0.2  Windows 10/11 

import numpy as np
from sklearn.gaussian_process import GaussianProcessClassifier

from sklearn.gaussian_process.kernels import RBF
from sklearn.gaussian_process.kernels import DotProduct
from sklearn.gaussian_process.kernels import WhiteKernel
from sklearn.gaussian_process.kernels import ConstantKernel

# -----------------------------------------------------------

def main():
  # 0. prepare
  print("\nBegin scikit Gaussian process classification ")
  print("Predict wheat seed species (Kama=0, Rosa=1, "
    "Canadian=2) ")
  np.random.seed(1)
  np.set_printoptions(precision=4, suppress=True)

# -----------------------------------------------------------

  # 1. load data
  print("\nLoading train and test data ")
  train_file = ".\\Data\\wheat_train_k.txt"
  train_X = np.loadtxt(train_file, delimiter="\t", 
    usecols=[0,1,2,3,4,5,6],
    comments="#", dtype=np.float32)
  train_y = np.loadtxt(train_file, delimiter="\t", 
    usecols=7, comments="#", dtype=np.int64) 

  test_file = ".\\Data\\wheat_test_k.txt"
  test_X = np.loadtxt(test_file, delimiter="\t",
    usecols=[0,1,2,3,4,5,6],
    comments="#", dtype=np.float32)
  test_y = np.loadtxt(test_file, delimiter="\t",
    usecols=7, comments="#", dtype=np.int64) 
  print("Done ")

  print("\nFirst few predictor values data: ")
  print(train_X[0:4, :])
  print(". . .")
  print("\nFirst few actual species: ")
  print(train_y[0:4])
  print(". . .")

# -----------------------------------------------------------

  # 2. create and train GPC model
  print("\nCreating GPC model with default RBF kernel ")

  # GaussianProcessClassifier(kernel=None, *,
  #  optimizer='fmin_l_bfgs_b', n_restarts_optimizer=0,
  #  max_iter_predict=100, warm_start=False, copy_X_train=True,
  #  random_state=None, multi_class='one_vs_rest', n_jobs=None)
  #
  # RBF(length_scale=1.0, length_scale_bounds=(1e-5, 100000.0))
  # DotProduct(sigma_0=1.0, sigma_0_bounds=(1e-5, 100000.0))
  # WhiteKernel(noise_level=1.0, 
  #  noise_level_bounds=(1e-5, 100000.0))

  krnl = 1.0 * RBF(1.0)
  # krnl = RBF(1.0, length_scale_bounds=(1e-6, 100000.0))
  # krnl = DotProduct() + WhiteKernel(noise_level=0.5)
  # krnl = ConstantKernel(1.0, (1e-1, 1e3)) * \
  #  RBF(10.0, (1e-3, 1e3))
  model = GaussianProcessClassifier(kernel=krnl,
    random_state=0)

  print("\nTraining model ")
  model.fit(train_X, train_y)
  print("Done ")

# -----------------------------------------------------------

  # 3. compute model accuracy
  print("\nComputing accuracy ")
  acc_train = model.score(train_X, train_y)
  print("\nAccuracy on train data = %0.4f " % acc_train)
  acc_test = model.score(test_X, test_y)
  print("Accuracy on test data = %0.4f " % acc_test)

  # 4. use model
  X = np.array([[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]],
    dtype=np.float32)
  print("\nPredicting wheat species for dummy predictors ")
  pred_probs = model.predict_proba(X)  # one probability per species
  print("Predicted probabilities (Kama, Rosa, Canadian): ")
  print(pred_probs)

  # 5. TODO: save model using pickle

  print("\nEnd GPC demo ")

if __name__ == "__main__":
  main()
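Step 5 in the demo is left as a TODO. One possible way to complete it is Python's standard pickle module. This is a sketch under assumptions: the file name wheat_gpc_model.pkl is hypothetical, it is written to the system temp directory, and a small synthetic data set stands in for the wheat data so the snippet runs on its own.

```python
import os
import pickle
import tempfile
import numpy as np
from sklearn.datasets import make_classification
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

# synthetic stand-in data (NOT the wheat data)
X, y = make_classification(n_samples=60, n_features=7, n_informative=5,
  n_redundant=0, n_classes=3, random_state=0)
model = GaussianProcessClassifier(kernel=1.0 * RBF(1.0), random_state=0)
model.fit(X, y)

# hypothetical file name, saved to a temp directory for this sketch
fn = os.path.join(tempfile.gettempdir(), "wheat_gpc_model.pkl")
with open(fn, "wb") as f:
  pickle.dump(model, f)

# later: reload the saved model and use it
with open(fn, "rb") as f:
  loaded = pickle.load(f)

same = np.array_equal(model.predict_proba(X[:3]),
                      loaded.predict_proba(X[:3]))
print("Reloaded model gives identical probabilities:", same)
```

Two cautions: only unpickle files from sources you trust, and a model pickled under one scikit version isn't guaranteed to load correctly under another.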