## Multi-Class Classification Using a scikit Decision Tree

I’ve been reviewing the scikit-learn (scikit for short) library for several months, so I figured I’d do a multi-class decision tree classification example. Before I go any further, let me comment that machine learning beginners are often seduced by the visual elegance of decision trees, but tree classifiers have several weaknesses, notably a strong tendency to overfit and a high sensitivity to small changes in the training data.

I used one of my standard datasets for multi-class classification. The data looks like:

```
1   0.24   1 0 0   0.2950   2
0   0.39   0 0 1   0.5120   1
1   0.63   0 1 0   0.7580   0
0   0.36   1 0 0   0.4450   1
. . .
```

Each line of data represents a person. The fields are sex (male = 0, female = 1), age (normalized by dividing by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), annual income (divided by 100,000), and politics type (conservative = 0, moderate = 1, liberal = 2). The goal is to predict the politics type of a person from their sex, age, State, and income.
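To make the encoding scheme concrete, here is a minimal sketch that turns one raw record into the six predictor values. The helper name `encode_person()` and the `STATES` list are hypothetical, not part of the demo program; the sketch just applies the rules described above (sex 0/1, age / 100, one-hot State, income / 100,000).

```python
import numpy as np

# Hypothetical helper (not part of the original demo): encode one raw
# record using the scheme described above.
STATES = ["michigan", "nebraska", "oklahoma"]  # one-hot: 100, 010, 001

def encode_person(sex, age, state, income):
    """Return [sex, age, state0, state1, state2, income] as float32."""
    sex_val = 0.0 if sex == "M" else 1.0           # male = 0, female = 1
    state_vec = [0.0, 0.0, 0.0]
    state_vec[STATES.index(state.lower())] = 1.0   # one-hot encode State
    return np.array([sex_val, age / 100.0, *state_vec, income / 100_000.0],
                    dtype=np.float32)

# A 24-year-old woman from Michigan earning $29,500 should match the first
# training item: 1, 0.24, 1 0 0, 0.2950
x = encode_person("F", 24, "Michigan", 29_500)
print(x)
```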

It isn’t necessary to normalize age and income for a decision tree, because each split compares a single predictor against a threshold, so rescaling a predictor just rescales its thresholds. Converting categorical predictors like State is conceptually tricky, but the bottom line is that in most scenarios it’s best to one-hot encode. For binary predictor variables, I recommend using 0 or 1 encoding, but again, there are a lot of subtle details.
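A quick way to check the scaling claim is to train one tree on the normalized values and another on the same values put back into raw units (age in years, income in dollars). This is a sketch that uses the first twelve rows of the training data listed at the end of the post; with the same `random_state`, the two trees should choose the same split features and make the same predictions, with only the threshold values differing by the scale factor.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# First twelve rows of people_train.txt: sex, age, state (3), income, politics.
data = np.array([
    [1, 0.24, 1, 0, 0, 0.2950, 2],
    [0, 0.39, 0, 0, 1, 0.5120, 1],
    [1, 0.63, 0, 1, 0, 0.7580, 0],
    [0, 0.36, 1, 0, 0, 0.4450, 1],
    [1, 0.27, 0, 1, 0, 0.2860, 2],
    [1, 0.50, 0, 1, 0, 0.5650, 1],
    [1, 0.50, 0, 0, 1, 0.5500, 1],
    [0, 0.19, 0, 0, 1, 0.3270, 0],
    [1, 0.22, 0, 1, 0, 0.2770, 1],
    [0, 0.39, 0, 0, 1, 0.4710, 2],
    [1, 0.34, 1, 0, 0, 0.3940, 1],
    [0, 0.22, 1, 0, 0, 0.3350, 0],
], dtype=np.float32)
X, y = data[:, 0:6], data[:, 6].astype(int)

# The same predictors with age and income put back into raw units.
X_raw = X.copy()
X_raw[:, 1] *= 100        # age in years
X_raw[:, 5] *= 100_000    # income in dollars

m_norm = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
m_raw = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_raw, y)

# Same split features at every node, and same predictions on the training items.
same_splits = np.array_equal(m_norm.tree_.feature, m_raw.tree_.feature)
same_preds = np.array_equal(m_norm.predict(X), m_raw.predict(X_raw))
print(same_splits, same_preds)  # both should be True
```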

The key lines of code are:

```
import numpy as np
from sklearn import tree

md = 3  # max depth
print("Creating decision tree max_depth=" + str(md))
model = tree.DecisionTreeClassifier(max_depth=md)
model.fit(train_x, train_y)
print("Done ")
```
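After fit() finishes, the demo program calls predict_proba() to get "pseudo-probabilities." For a tree trained without sample weights, these are simply the class fractions of the training items that landed in the same leaf as the input. The sketch below checks that on synthetic data from `make_classification()` (a stand-in, not the politics data), using `apply()` to get each item's leaf index.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic 3-class stand-in data (classes 0, 1, 2).
X, y = make_classification(n_samples=300, n_features=6, n_informative=4,
                           n_redundant=0, n_classes=3, random_state=0)
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

leaf_of_train = model.apply(X)   # leaf index for every training item
x_new = X[:1]                    # one item, shape (1, 6)
leaf = model.apply(x_new)[0]     # the leaf that item lands in

# Class counts of the training items that share x_new's leaf.
counts = np.bincount(y[leaf_of_train == leaf], minlength=3)
manual = counts / counts.sum()
print(manual)
print(model.predict_proba(x_new)[0])
```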

Decision trees are highly prone to overfitting. If you set a large max_depth, you can get 100% classification accuracy on the training data, but the accuracy on test data and on new, previously unseen data will likely be very poor.
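Here is a minimal sketch of that effect, using synthetic 3-class data from `make_classification()` with some label noise as a stand-in for the politics data. It trains trees of increasing max_depth and reports train and test accuracy; with max_depth=None the tree keeps splitting until every leaf is pure, so training accuracy reaches 1.0 while test accuracy does not. The exact numbers will depend on the data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic 3-class data (a stand-in for the politics data);
# flip_y=0.15 assigns random labels to about 15% of items to simulate noise.
X, y = make_classification(n_samples=400, n_features=6, n_informative=4,
                           n_redundant=0, n_classes=3, flip_y=0.15,
                           random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25,
                                          random_state=0)

results = {}
for md in [1, 2, 3, 5, 8, None]:   # None = keep splitting until leaves are pure
    m = DecisionTreeClassifier(max_depth=md, random_state=0).fit(X_tr, y_tr)
    results[md] = (m.score(X_tr, y_tr), m.score(X_te, y_te))
    print(f"max_depth={md}: train acc = {results[md][0]:.4f}, "
          f"test acc = {results[md][1]:.4f}")
```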

The accuracy of the model can be displayed using the built-in score() method or indirectly in the form of a confusion matrix:

```
from sklearn.metrics import confusion_matrix
y_predicteds = model.predict(test_x)
cm = confusion_matrix(test_y, y_predicteds)
print("Confusion matrix: \n")
# print(cm)  # no formatting
show_confusion(cm)  # custom formatting
```
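The two views are related: the diagonal of the confusion matrix holds the correct predictions, so accuracy is the trace divided by the total count. For a scikit classifier, score() returns the same value as `accuracy_score()` applied to the model's predictions. The labels below are made up purely for illustration.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Made-up labels for illustration only (0 = con, 1 = mod, 2 = lib).
y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 2, 2, 2, 1, 1])

cm = confusion_matrix(y_true, y_pred)   # rows = actual, columns = predicted
print(cm)
# [[1 1 0]
#  [0 1 1]
#  [0 0 2]]

acc_from_cm = np.trace(cm) / cm.sum()   # correct predictions lie on the diagonal
print(acc_from_cm, accuracy_score(y_true, y_pred))
```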

There are several ways to visualize a trained tree classifier. The model can be displayed as text pseudo-code like this:

```
pseudo = tree.export_text(model,
  feature_names=["sex", "age",
  "state0", "state1", "state2",
  "income"])
print("Model in pseudo-code: ")
print(pseudo)
```
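Each if/else line in the pseudo-code corresponds to a node stored in the parallel arrays of `model.tree_` (`children_left`, `children_right`, `feature`, `threshold`, `value`). As a sketch, walking those arrays by hand classifies an item the same way predict() does. The function name `predict_one()` is hypothetical, and the data is synthetic; the values are cast to float32 because scikit compares float32 predictor values against the stored thresholds.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

def predict_one(model, x):
    """Hypothetical helper: classify one item by walking model.tree_."""
    t = model.tree_
    node = 0                                    # start at the root
    while t.children_left[node] != -1:          # -1 marks a leaf
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]        # the "if" branch
        else:
            node = t.children_right[node]       # the "else" branch
    return model.classes_[np.argmax(t.value[node])]  # majority class in leaf

# Synthetic 3-class stand-in data, cast to float32 like scikit does internally.
X, y = make_classification(n_samples=200, n_features=6, n_informative=4,
                           n_redundant=0, n_classes=3, random_state=0)
X = X.astype(np.float32)
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

manual = np.array([predict_one(model, row) for row in X])
print(np.array_equal(manual, model.predict(X)))
```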

The wackiness of the pseudo-code points out a weakness of decision trees: they’re highly sensitive to small changes in the training data.
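One way to see that sensitivity is to train two trees with identical settings on two disjoint halves of the same dataset and compare them. The sketch below does this with synthetic data from `make_classification()` and generic feature names `f0` to `f5` (assumptions, not the politics data); the two export_text() printouts will typically show different splits, and the last line reports the fraction of items the two trees classify differently.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic 3-class data with some label noise; generic feature names.
X, y = make_classification(n_samples=300, n_features=6, n_informative=4,
                           n_redundant=0, n_classes=3, flip_y=0.1,
                           random_state=1)
names = [f"f{i}" for i in range(6)]

# Split the rows into two disjoint halves drawn from the same distribution.
rng = np.random.default_rng(0)
idx = rng.permutation(len(X))
half_a, half_b = idx[:150], idx[150:]

m_a = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X[half_a], y[half_a])
m_b = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X[half_b], y[half_b])

text_a = export_text(m_a, feature_names=names)
text_b = export_text(m_b, feature_names=names)
print(text_a)
print(text_b)

# Fraction of all 300 items that the two trees classify differently.
disagree = np.mean(m_a.predict(X) != m_b.predict(X))
print(f"disagreement = {disagree:.4f}")
```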

The tree model can be displayed graphically using the plot_tree() function like so:

```
import matplotlib.pyplot as plt
plt.figure(figsize=(14,8),
tight_layout=True)  # w,h inches
tree.plot_tree(model,
feature_names=["sex", "age",
"state0", "state1", "state2",
"income"],
class_names=["con", "mod", "lib"],
fontsize=8)
plt.show()
```
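If you're running on a machine without a display, a variation is to render the plot to an image file instead of calling plt.show(). This sketch assumes matplotlib's non-interactive "Agg" backend and uses the Iris dataset that ships with scikit as stand-in 3-class data; `filled=True` colors each node by its majority class, and the output file name is arbitrary.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")   # non-interactive backend: render to a file, no window
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

# Stand-in data: the Iris dataset that ships with scikit (3 classes).
iris = load_iris()
model = tree.DecisionTreeClassifier(max_depth=3, random_state=0)
model.fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(14, 8), tight_layout=True)  # w,h inches
tree.plot_tree(model,
               feature_names=list(iris.feature_names),
               class_names=list(iris.target_names),
               filled=True,       # color nodes by majority class
               fontsize=8,
               ax=ax)
path = os.path.join(tempfile.gettempdir(), "tree_plot.png")
fig.savefig(path, dpi=100)
plt.close(fig)
print("saved", path)
```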

Anyway, the demo was a good refresher for me.

One of my favorite science fiction series is the Mars series by author Edgar Rice Burroughs. The fictional world has a lot of politics and races: Red Martians (human-like), Green Martians (fierce, 15 feet tall with six limbs), Yellow Martians (secretive), White Martians (predecessors to Red Martians), and Black Martians (evil).

“A Fighting Man of Mars” is the seventh book in the series. It was first published in book form in 1931. The book tells the story of low-born soldier Tan Hadron, who sets off to rescue snooty noblewoman Sanoma. He has many adventures and rescues, and he falls in love with beautiful slave Tavia, who turns out to be a princess.

Left: Cover art by Robert Abbett. Center: Cover art by Michael Whelan. Right: Cover art by Roy Krenkel.

Demo code. Replace “lte” with Boolean less-than-or-equal operator. The data is also listed below.

```
# people_politics_tree_sckit.py

# predict politics (0 = con, 1 = mod, 2 = lib)
# from sex, age, state, income

# sex  age    state    income   politics
#  0   0.27   0  1  0   0.7610   2
#  1   0.19   0  0  1   0.6550   0
# sex: 0 = male, 1 = female
# state: michigan = 100, nebraska = 010, oklahoma = 001
# politics: conservative, moderate, liberal

# Anaconda3-2020.02  Python 3.7.6  scikit 0.22.1
# Windows 10/11

import numpy as np
from sklearn import tree

# ---------------------------------------------------------

def tree_to_pseudo(model, feature_names):
  # custom function to display tree model pseudo-code
  left = model.tree_.children_left
  right = model.tree_.children_right
  threshold = model.tree_.threshold
  features = [feature_names[i] for i in model.tree_.feature]
  value = model.tree_.value

  def recurse(left, right, threshold, features, node, depth=0):
    indent = "  " * depth
    if (threshold[node] != -2):  # -2 means leaf node (no split)
      v = "%0.4f" % threshold[node]
      print(indent,"if ( " + features[node] + " lte " +
        str(v) + " ) {")
      if left[node] != -1:
        recurse(left, right, threshold, features,
          left[node], depth+1)
      print(indent,"} else {")
      if right[node] != -1:
        recurse(left, right, threshold, features,
          right[node], depth+1)
      print(indent,"}")
    else:
      idx = np.argmax(value[node])
      # print(indent,"return " + str(value[node]))
      print(indent,"return " + str(model.classes_[idx]))

  recurse(left, right, threshold, features, 0)

# ---------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

# ---------------------------------------------------------

def main():
  print("\nBegin scikit decision tree example ")
  print("Predict politics from sex, age, State, income ")
  np.random.seed(0)
  np.set_printoptions(precision=4, suppress=True)

  # sex  age    state    income   politics
  #  0   0.27   0  1  0   0.7610   2
  #  1   0.19   0  0  1   0.6550   0

  # 1. load data (tab-delimited, lines starting with # are skipped)
  train_file = ".\\Data\\people_train.txt"
  train_xy = np.loadtxt(train_file, usecols=range(0,7),
    delimiter="\t", comments="#", dtype=np.float32)
  train_x = train_xy[:,0:6]
  train_y = train_xy[:,6].astype(int)

  test_file = ".\\Data\\people_test.txt"
  test_xy = np.loadtxt(test_file, usecols=range(0,7),
    delimiter="\t", comments="#", dtype=np.float32)
  test_x = test_xy[:,0:6]
  test_y = test_xy[:,6].astype(int)

  print("\nTraining data:")
  print(train_x[0:4])
  print(". . . \n")
  print(train_y[0:4])
  print(". . . ")

  # 2. create and train
  md = 3
  print("\nCreating decision tree max_depth=" + str(md))
  model = tree.DecisionTreeClassifier(max_depth=md)
  model.fit(train_x, train_y)
  print("Done ")

  # 3. evaluate
  acc_train = model.score(train_x, train_y)
  print("\nAccuracy on train = %0.4f " % acc_train)
  acc_test = model.score(test_x, test_y)
  print("Accuracy on test = %0.4f " % acc_test)

  # 3b. display formatted confusion matrix
  from sklearn.metrics import confusion_matrix
  y_predicteds = model.predict(test_x)
  cm = confusion_matrix(test_y, y_predicteds)
  print("\nConfusion matrix: \n")
  show_confusion(cm)

  # 4a. visualize using custom function
  # print("\nModel in pseudo-code: ")
  # tree_to_pseudo(model, ["sex", "age",
  #   "state0", "state1", "state2",
  #   "income"])

  # 4b. use built-in export_text()
  pseudo = tree.export_text(model,
    feature_names=["sex", "age",
    "state0", "state1", "state2",
    "income"])
  print("\nModel in pseudo-code: ")
  print(pseudo)

  # 4c. use built-in plot_tree()
  import matplotlib.pyplot as plt
  plt.figure(figsize=(14,8),
    tight_layout=True)  # w,h inches
  tree.plot_tree(model,
    feature_names=["sex", "age",
    "state0", "state1", "state2",
    "income"],
    class_names=["con", "mod", "lib"],
    fontsize=8)
  plt.show()

  # 5. use model
  print("\nPredict for: M 35 Nebraska $55K ")
  X = np.array([[0, 0.35, 0,1,0, 0.5500]],
    dtype=np.float32)
  probs = model.predict_proba(X)
  print("\nPrediction pseudo-probs: ")
  print(probs)

  politic = model.predict(X)
  print("\nPredicted class: ")
  print(politic)

  # 6. TODO: save model using pickle
  # import pickle
  # print("Saving trained tree model ")
  # path = ".\\Models\\tree_scikit_model.sav"
  # pickle.dump(model, open(path, "wb"))

  # use saved model
  # X = np.array([[0, 0.35, 0,1,0, 0.5500]],
  #   dtype=np.float32)
  # with open(path, 'rb') as f:
  #   loaded_model = pickle.load(f)
  # pa = loaded_model.predict_proba(X)
  # print(pa)

  print("\nEnd scikit decision tree demo ")

if __name__ == "__main__":
  main()
```
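The pickle step is left as a TODO in the demo. As a minimal sketch of what it could look like, the code below saves a trained tree, loads it back, and confirms the loaded model gives identical pseudo-probabilities. It uses synthetic stand-in data and a temporary-directory path instead of the demo's `.\Models\` folder; note that pickle files should only be loaded from sources you trust.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data; in the demo this would be train_x and train_y.
X, y = make_classification(n_samples=200, n_features=6, n_informative=4,
                           n_redundant=0, n_classes=3, random_state=0)
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

path = os.path.join(tempfile.gettempdir(), "tree_scikit_model.sav")
with open(path, "wb") as f:          # save the trained model
    pickle.dump(model, f)

with open(path, "rb") as f:          # load it back
    loaded_model = pickle.load(f)

pa = loaded_model.predict_proba(X[:1])
print(pa)
same = np.array_equal(model.predict_proba(X), loaded_model.predict_proba(X))
print(same)
```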

Training data. Replace commas with tab characters or modify program.

```
# people_train.txt
# sex (M=0, F=1), age (div 100)
# state (michigan = 100, nebraska = 010,
#  oklahoma = 001)
# income (div 100,000)
# politics (con = 0, mod = 1, lib = 2)
#
1,0.24,1,0,0,0.2950,2
0,0.39,0,0,1,0.5120,1
1,0.63,0,1,0,0.7580,0
0,0.36,1,0,0,0.4450,1
1,0.27,0,1,0,0.2860,2
1,0.50,0,1,0,0.5650,1
1,0.50,0,0,1,0.5500,1
0,0.19,0,0,1,0.3270,0
1,0.22,0,1,0,0.2770,1
0,0.39,0,0,1,0.4710,2
1,0.34,1,0,0,0.3940,1
0,0.22,1,0,0,0.3350,0
1,0.35,0,0,1,0.3520,2
0,0.33,0,1,0,0.4640,1
1,0.45,0,1,0,0.5410,1
1,0.42,0,1,0,0.5070,1
0,0.33,0,1,0,0.4680,1
1,0.25,0,0,1,0.3000,1
0,0.31,0,1,0,0.4640,0
1,0.27,1,0,0,0.3250,2
1,0.48,1,0,0,0.5400,1
0,0.64,0,1,0,0.7130,2
1,0.61,0,1,0,0.7240,0
1,0.54,0,0,1,0.6100,0
1,0.29,1,0,0,0.3630,0
1,0.50,0,0,1,0.5500,1
1,0.55,0,0,1,0.6250,0
1,0.40,1,0,0,0.5240,0
1,0.22,1,0,0,0.2360,2
1,0.68,0,1,0,0.7840,0
0,0.60,1,0,0,0.7170,2
0,0.34,0,0,1,0.4650,1
0,0.25,0,0,1,0.3710,0
0,0.31,0,1,0,0.4890,1
1,0.43,0,0,1,0.4800,1
1,0.58,0,1,0,0.6540,2
0,0.55,0,1,0,0.6070,2
0,0.43,0,1,0,0.5110,1
0,0.43,0,0,1,0.5320,1
0,0.21,1,0,0,0.3720,0
1,0.55,0,0,1,0.6460,0
1,0.64,0,1,0,0.7480,0
0,0.41,1,0,0,0.5880,1
1,0.64,0,0,1,0.7270,0
0,0.56,0,0,1,0.6660,2
1,0.31,0,0,1,0.3600,1
0,0.65,0,0,1,0.7010,2
1,0.55,0,0,1,0.6430,0
0,0.25,1,0,0,0.4030,0
1,0.46,0,0,1,0.5100,1
0,0.36,1,0,0,0.5350,0
1,0.52,0,1,0,0.5810,1
1,0.61,0,0,1,0.6790,0
1,0.57,0,0,1,0.6570,0
0,0.46,0,1,0,0.5260,1
0,0.62,1,0,0,0.6680,2
1,0.55,0,0,1,0.6270,0
0,0.22,0,0,1,0.2770,1
0,0.50,1,0,0,0.6290,0
0,0.32,0,1,0,0.4180,1
0,0.21,0,0,1,0.3560,0
1,0.44,0,1,0,0.5200,1
1,0.46,0,1,0,0.5170,1
1,0.62,0,1,0,0.6970,0
1,0.57,0,1,0,0.6640,0
0,0.67,0,0,1,0.7580,2
1,0.29,1,0,0,0.3430,2
1,0.53,1,0,0,0.6010,0
0,0.44,1,0,0,0.5480,1
1,0.46,0,1,0,0.5230,1
0,0.20,0,1,0,0.3010,1
0,0.38,1,0,0,0.5350,1
1,0.50,0,1,0,0.5860,1
1,0.33,0,1,0,0.4250,1
0,0.33,0,1,0,0.3930,1
1,0.26,0,1,0,0.4040,0
1,0.58,1,0,0,0.7070,0
1,0.43,0,0,1,0.4800,1
0,0.46,1,0,0,0.6440,0
1,0.60,1,0,0,0.7170,0
0,0.42,1,0,0,0.4890,1
0,0.56,0,0,1,0.5640,2
0,0.62,0,1,0,0.6630,2
0,0.50,1,0,0,0.6480,1
1,0.47,0,0,1,0.5200,1
0,0.67,0,1,0,0.8040,2
0,0.40,0,0,1,0.5040,1
1,0.42,0,1,0,0.4840,1
1,0.64,1,0,0,0.7200,0
0,0.47,1,0,0,0.5870,2
1,0.45,0,1,0,0.5280,1
0,0.25,0,0,1,0.4090,0
1,0.38,1,0,0,0.4840,0
1,0.55,0,0,1,0.6000,1
0,0.44,1,0,0,0.6060,1
1,0.33,1,0,0,0.4100,1
1,0.34,0,0,1,0.3900,1
1,0.27,0,1,0,0.3370,2
1,0.32,0,1,0,0.4070,1
1,0.42,0,0,1,0.4700,1
0,0.24,0,0,1,0.4030,0
1,0.42,0,1,0,0.5030,1
1,0.25,0,0,1,0.2800,2
1,0.51,0,1,0,0.5800,1
0,0.55,0,1,0,0.6350,2
1,0.44,1,0,0,0.4780,2
0,0.18,1,0,0,0.3980,0
0,0.67,0,1,0,0.7160,2
1,0.45,0,0,1,0.5000,1
1,0.48,1,0,0,0.5580,1
0,0.25,0,1,0,0.3900,1
0,0.67,1,0,0,0.7830,1
1,0.37,0,0,1,0.4200,1
0,0.32,1,0,0,0.4270,1
1,0.48,1,0,0,0.5700,1
0,0.66,0,0,1,0.7500,2
1,0.61,1,0,0,0.7000,0
0,0.58,0,0,1,0.6890,1
1,0.19,1,0,0,0.2400,2
1,0.38,0,0,1,0.4300,1
0,0.27,1,0,0,0.3640,1
1,0.42,1,0,0,0.4800,1
1,0.60,1,0,0,0.7130,0
0,0.27,0,0,1,0.3480,0
1,0.29,0,1,0,0.3710,0
0,0.43,1,0,0,0.5670,1
1,0.48,1,0,0,0.5670,1
1,0.27,0,0,1,0.2940,2
0,0.44,1,0,0,0.5520,0
1,0.23,0,1,0,0.2630,2
0,0.36,0,1,0,0.5300,2
1,0.64,0,0,1,0.7250,0
1,0.29,0,0,1,0.3000,2
0,0.33,1,0,0,0.4930,1
0,0.66,0,1,0,0.7500,2
0,0.21,0,0,1,0.3430,0
1,0.27,1,0,0,0.3270,2
1,0.29,1,0,0,0.3180,2
0,0.31,1,0,0,0.4860,1
1,0.36,0,0,1,0.4100,1
1,0.49,0,1,0,0.5570,1
0,0.28,1,0,0,0.3840,0
0,0.43,0,0,1,0.5660,1
0,0.46,0,1,0,0.5880,1
1,0.57,1,0,0,0.6980,0
0,0.52,0,0,1,0.5940,1
0,0.31,0,0,1,0.4350,1
0,0.55,1,0,0,0.6200,2
1,0.50,1,0,0,0.5640,1
1,0.48,0,1,0,0.5590,1
0,0.22,0,0,1,0.3450,0
1,0.59,0,0,1,0.6670,0
1,0.34,1,0,0,0.4280,2
0,0.64,1,0,0,0.7720,2
1,0.29,0,0,1,0.3350,2
0,0.34,0,1,0,0.4320,1
0,0.61,1,0,0,0.7500,2
1,0.64,0,0,1,0.7110,0
0,0.29,1,0,0,0.4130,0
1,0.63,0,1,0,0.7060,0
0,0.29,0,1,0,0.4000,0
0,0.51,1,0,0,0.6270,1
0,0.24,0,0,1,0.3770,0
1,0.48,0,1,0,0.5750,1
1,0.18,1,0,0,0.2740,0
1,0.18,1,0,0,0.2030,2
1,0.33,0,1,0,0.3820,2
0,0.20,0,0,1,0.3480,0
1,0.29,0,0,1,0.3300,2
0,0.44,0,0,1,0.6300,0
0,0.65,0,0,1,0.8180,0
0,0.56,1,0,0,0.6370,2
0,0.52,0,0,1,0.5840,1
0,0.29,0,1,0,0.4860,0
0,0.47,0,1,0,0.5890,1
1,0.68,1,0,0,0.7260,2
1,0.31,0,0,1,0.3600,1
1,0.61,0,1,0,0.6250,2
1,0.19,0,1,0,0.2150,2
1,0.38,0,0,1,0.4300,1
0,0.26,1,0,0,0.4230,0
1,0.61,0,1,0,0.6740,0
1,0.40,1,0,0,0.4650,1
0,0.49,1,0,0,0.6520,1
1,0.56,1,0,0,0.6750,0
0,0.48,0,1,0,0.6600,1
1,0.52,1,0,0,0.5630,2
0,0.18,1,0,0,0.2980,0
0,0.56,0,0,1,0.5930,2
0,0.52,0,1,0,0.6440,1
0,0.18,0,1,0,0.2860,1
0,0.58,1,0,0,0.6620,2
0,0.39,0,1,0,0.5510,1
0,0.46,1,0,0,0.6290,1
0,0.40,0,1,0,0.4620,1
0,0.60,1,0,0,0.7270,2
1,0.36,0,1,0,0.4070,2
1,0.44,1,0,0,0.5230,1
1,0.28,1,0,0,0.3130,2
1,0.54,0,0,1,0.6260,0
```
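If you'd rather keep the commas than convert them to tabs, loading with `np.loadtxt()` only needs `delimiter=","`. A minimal sketch, which writes the header and first three rows of the training data to a temporary file (the file name is arbitrary) and then loads it:

```python
import os
import tempfile

import numpy as np

# Header and first three rows of people_train.txt, in comma-delimited form.
sample = """# people_train.txt
# politics (con = 0, mod = 1, lib = 2)
1,0.24,1,0,0,0.2950,2
0,0.39,0,0,1,0.5120,1
1,0.63,0,1,0,0.7580,0
"""
path = os.path.join(tempfile.gettempdir(), "people_train_sample.txt")
with open(path, "w") as f:
    f.write(sample)

# Comma-delimited load; lines starting with "#" are skipped.
train_xy = np.loadtxt(path, usecols=range(0, 7), delimiter=",",
                      comments="#", dtype=np.float32)
train_x = train_xy[:, 0:6]
train_y = train_xy[:, 6].astype(int)
print(train_x.shape, train_y)   # (3, 6) [2 1 0]
```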

Test data. Replace commas with tab characters or modify program.

```
0,0.51,1,0,0,0.6120,1
0,0.32,0,1,0,0.4610,1
1,0.55,1,0,0,0.6270,0
1,0.25,0,0,1,0.2620,2
1,0.33,0,0,1,0.3730,2
0,0.29,0,1,0,0.4620,0
1,0.65,1,0,0,0.7270,0
0,0.43,0,1,0,0.5140,1
0,0.54,0,1,0,0.6480,2
1,0.61,0,1,0,0.7270,0
1,0.52,0,1,0,0.6360,0
1,0.30,0,1,0,0.3350,2
1,0.29,1,0,0,0.3140,2
0,0.47,0,0,1,0.5940,1
1,0.39,0,1,0,0.4780,1
1,0.47,0,0,1,0.5200,1
0,0.49,1,0,0,0.5860,1
0,0.63,0,0,1,0.6740,2
0,0.30,1,0,0,0.3920,0
0,0.61,0,0,1,0.6960,2
0,0.47,0,0,1,0.5870,1
1,0.30,0,0,1,0.3450,2
0,0.51,0,0,1,0.5800,1
0,0.24,1,0,0,0.3880,1
0,0.49,1,0,0,0.6450,1
1,0.66,0,0,1,0.7450,0
0,0.65,1,0,0,0.7690,0
0,0.46,0,1,0,0.5800,0
0,0.45,0,0,1,0.5180,1
0,0.47,1,0,0,0.6360,0
0,0.29,1,0,0,0.4480,0
0,0.57,0,0,1,0.6930,2
0,0.20,1,0,0,0.2870,2
0,0.35,1,0,0,0.4340,1
0,0.61,0,0,1,0.6700,2
0,0.31,0,0,1,0.3730,1
1,0.18,1,0,0,0.2080,2
1,0.26,0,0,1,0.2920,2
0,0.28,1,0,0,0.3640,2
0,0.59,0,0,1,0.6940,2
```
This entry was posted in Scikit.