Support Vector Machine Classification

Take a look at the graph below. Each of the nine data points belongs to one of three classes: red = 1, blue = 2, green = 3. The goal of a machine learning classifier is to learn, from labeled points like these, a function that predicts the class of a point it hasn’t seen. For example, if we want to predict the class of a new point (4,5), we’d expect the classifier to respond with 2 (blue).

There are dozens of machine learning classification algorithms. For example, for this demo problem you could use multi-class logistic regression, or neural network classification, or SVM (support vector machine) classification.
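For comparison, here is a minimal sketch of the first alternative, multi-class logistic regression, using scikit-learn's LogisticRegression class. The coordinates are the same nine points used in the demo program in this post. The max_iter value is my own choice to give the solver room to converge, not a tuned setting.

```python
# logistic_demo.py
# A sketch of multi-class logistic regression on the nine demo points.

import numpy as np
from sklearn.linear_model import LogisticRegression

points = np.array([[2,3], [3,2], [4,3],
                   [3,6], [4,7], [5,6],
                   [6,4], [7,3], [7,5]])
labels = np.array([1,1,1, 2,2,2, 3,3,3])

# max_iter=1000 is an assumption, chosen only to avoid convergence warnings
model = LogisticRegression(max_iter=1000)
model.fit(points, labels)

unknown = np.array([[4,5]])
lr_pred = model.predict(unknown)         # predicted class label
lr_probs = model.predict_proba(unknown)  # one probability per class

print("Predicted class:", lr_pred)
print("Class probabilities:", lr_probs)
```

Unlike SVC with its default settings, logistic regression also gives you a probability for each class, which can be handy when you want to know how confident the prediction is.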

Training an SVM means solving a constrained optimization problem, which makes the algorithm one of the more complex in machine learning. Writing SVM from scratch is rarely practical, so in most cases you use a library. I coded up an example using the “svm” module in the “sklearn” (scikit-learn) Python library:

# svm_demo.py

from sklearn import svm
import numpy as np

def get_points():
  return np.array([[2,3], [3,2], [4,3],
                   [3,6], [4,7], [5,6],
                   [6,4], [7,3], [7,5]])

def get_labels():
  return np.array([1,1,1, 2,2,2, 3,3,3])

# --------------------------------------

print("\nBegin SVM using sklearn demo \n")

print("Loading training data")
points = get_points()
labels = get_labels()

print("Creating SVM classifier \n")
classifier = svm.SVC(kernel='rbf', gamma=1.0, C=10.0)
classifier.fit(points, labels)

unknown = np.array([[4,5]])
print("Making prediction for: ")
print(unknown)
pred_class = classifier.predict(unknown)

print("\nPredicted class is: ")
print(pred_class)

print("\nEnd demo \n")

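The SVC object (“support vector classifier”) has default values for all of its parameters, but there are two you will almost always want to set yourself. The first is a kernel function, which typically has one or more of its own parameters. I used “rbf” (radial basis function) with gamma = 1.0; a larger gamma makes the influence of each training point more local. The second parameter, C, is the regularization constant. It controls how heavily the classifier penalizes training points that fall on the wrong side of the margin, which is how it deals with outlier data points. A larger C fits the training data more tightly, and I set its value to 10.0.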
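To make the role of gamma concrete, here is a small sketch of the RBF kernel, K(u, v) = exp(-gamma * ||u - v||^2), computed by hand and checked against scikit-learn's rbf_kernel helper. The function name rbf and the choice of comparison points are mine, for illustration only.

```python
# rbf_sketch.py
# The RBF kernel measures similarity: near 1 for close points, near 0 for distant ones.

import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

def rbf(u, v, gamma):
  # K(u, v) = exp(-gamma * squared Euclidean distance)
  diff = np.asarray(u, dtype=float) - np.asarray(v, dtype=float)
  return np.exp(-gamma * np.dot(diff, diff))

a = np.array([4, 5])  # the point to classify
b = np.array([3, 6])  # a class 2 training point, squared distance 2
c = np.array([7, 3])  # a class 3 training point, squared distance 13

k_ab = rbf(a, b, gamma=1.0)  # exp(-2)
k_ac = rbf(a, c, gamma=1.0)  # exp(-13), far smaller

# the library version expects 2-D arrays and returns a matrix
k_lib = rbf_kernel(a.reshape(1, -1), b.reshape(1, -1), gamma=1.0)[0, 0]

print(k_ab, k_ac, k_lib)
```

With gamma = 1.0 the similarity falls off quickly with distance, so the nearby class 2 point counts for far more than the distant class 3 point. A smaller gamma flattens that falloff.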

I don’t use SVM classification very often. There is no formula for the best kernel function, kernel parameter(s), or C constant; they must be determined by trial and error, usually in the form of a systematic search scored by cross-validation. In general I prefer using neural network classification.
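The trial and error can at least be automated. Here is a minimal sketch using scikit-learn's GridSearchCV, which fits one classifier per combination of candidate values and scores each with cross-validation. The candidate values for gamma and C are arbitrary choices of mine, and nine points is far too little data for the scores to mean much, so treat this as an outline of the mechanics rather than a tuning recipe.

```python
# svm_search.py
# A sketch of a grid search over gamma and C for the demo data.

import numpy as np
from sklearn import svm
from sklearn.model_selection import GridSearchCV, StratifiedKFold

points = np.array([[2,3], [3,2], [4,3],
                   [3,6], [4,7], [5,6],
                   [6,4], [7,3], [7,5]])
labels = np.array([1,1,1, 2,2,2, 3,3,3])

# candidate values are illustrative assumptions, not recommendations
param_grid = {"gamma": [0.01, 0.1, 1.0], "C": [0.1, 1.0, 10.0]}

# three folds, each holding out one point per class
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)

search = GridSearchCV(svm.SVC(kernel="rbf"), param_grid, cv=cv)
search.fit(points, labels)

print("Best parameters:", search.best_params_)
print("Best mean accuracy:", search.best_score_)
```

After fitting, search.best_estimator_ holds a classifier retrained on all of the data with the winning values, so you can call its predict() method directly.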
