## The Wheat Seed Problem Using Radius Neighbors Classification with the scikit Library

I was looking at k-NN classification using the scikit library recently. While wading through the scikit documentation, I noticed that scikit has a closely related Radius Neighbors classifier module. When I did my k-NN example, I used the Wheat Seeds dataset where the goal is to predict the species of a wheat seed (0 = Kama, 1 = Rosa, 2 = Canadian) from seven predictor variables: seed length, width, perimeter, and so on.

The most difficult part of the k-NN experiment was preparing the training and test data. Because I’d already done all the data preparation and had data ready, I figured I’d apply the radius neighbors classifier to the prepared wheat seed data.

In k-NN the unknown input to classify is compared to the k closest labeled data items in the training data (using Euclidean distance), and the most common class label is the prediction. For example, if k = 5, and the five closest data items to the item to predict have class labels (2, 0, 2, 2, 1), then the predicted class label is 2.

In radius neighbors classification, instead of specifying k data points to compare, you specify a radius. All the labeled data items within the radius of the item to classify are examined, and the most common label is the prediction.
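The idea can be sketched in a few lines of plain NumPy. This is a toy stand-in for the scikit module, with made-up two-dimensional data, using Euclidean distance and a majority vote over the in-radius neighbors:

```python
import numpy as np

def radius_predict(train_X, train_y, x, radius):
  # Euclidean distance from x to every labeled training item
  dists = np.linalg.norm(train_X - x, axis=1)
  labels = train_y[dists <= radius]  # labels of in-radius items
  if labels.size == 0:
    raise ValueError("no training items within radius")
  return np.argmax(np.bincount(labels))  # most common label wins

train_X = np.array([[0.1, 0.1], [0.2, 0.1], [0.8, 0.9], [0.9, 0.8]])
train_y = np.array([0, 0, 1, 1])
print(radius_predict(train_X, train_y, np.array([0.15, 0.12]), 0.3))  # → 0
```

Unlike k-NN, the number of votes varies from item to item, and an item can have no neighbors at all (the scikit module handles that case with its outlier_label parameter).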

The raw Wheat Seeds data came from archive.ics.uci.edu/ml/datasets/seeds and looks like:

```
15.26  14.84  0.871   5.763  3.312  2.221  5.22   1
14.88  14.57  0.8811  5.554  3.333  1.018  4.956  1
. . .
17.63  15.98  0.8673  6.191  3.561  4.076  6.06   2
16.84  15.67  0.8623  5.998  3.484  4.675  5.877  2
. . .
11.84  13.21  0.8521  5.175  2.836  3.598  5.044  3
12.3   13.34  0.8684  5.243  2.974  5.637  5.063  3
```

There are 210 data items. Each represents one of three species of wheat seeds: Kama, Rosa, Canadian. There are 70 of each species. The first 7 values on each line are the predictors: area, perimeter, compactness, length, width, asymmetry, groove. The eighth value in the raw data is the one-based encoded species. The goal is to predict species from the seven predictor values.

When using any of the neighbors classification techniques, it’s important to normalize the numeric predictors so that they all have roughly the same magnitude; otherwise, a predictor with large values can overwhelm the other predictors. As is often the case in machine learning, data preparation takes most of the time and effort of any exploration.

I dropped the raw data into an Excel spreadsheet. For each predictor, I computed the min and max values of the column. Then I performed min-max normalization where each value x in a column is normalized to x’ = (x – min) / (max – min). The result is that each predictor is a value between 0.0 and 1.0.
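The same computation in NumPy, applied to a toy handful of area values rather than the full column:

```python
import numpy as np

col = np.array([15.26, 14.88, 17.63, 11.84, 12.30])  # a few raw area values
mn, mx = col.min(), col.max()
norm = (col - mn) / (mx - mn)   # x' = (x - min) / (max - min)
print(norm)  # every value now lies in [0.0, 1.0]
```

The column minimum maps to 0.0 and the column maximum maps to 1.0, with everything else in between.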

I recoded the target class labels from one-based to zero-based. The resulting 210-item dataset looks like:

```
0.4410  0.5021  0.5708  0.4865  0.4861  0.1893  0.3452  0
0.4051  0.4463  0.6624  0.3688  0.5011  0.0329  0.2152  0
. . .
0.6648  0.7376  0.5372  0.7275  0.6636  0.4305  0.7587  1
0.5902  0.6736  0.4918  0.6188  0.6087  0.5084  0.6686  1
. . .
0.1917  0.2603  0.3630  0.2877  0.2003  0.3304  0.3506  2
0.2049  0.2004  0.8013  0.0980  0.3742  0.2682  0.1531  2
```

I split the 210-item normalized data into a 180-item training set and a 30-item test set. I used the first 60 of each target class for training and the last 10 of each target class for testing. Put another way, for the training data, items [0] to [59] are class 0, items [60] to [119] are class 1, and items [120] to [179] are class 2.

Using scikit is easy. After loading the training and test data into memory, a radius neighbors multi-class classification model is created and trained like so:

```
import numpy as np
from sklearn.neighbors import RadiusNeighborsClassifier
. . .
model = RadiusNeighborsClassifier(radius=0.38)
model.fit(train_X, train_y)
print("Done ")
```

The hard part is determining the radius value. The default value is 1.0 but that radius was too big and it included virtually all of the training data. When I used a radius of 0.2, none of the labeled items were within that radius. I set up a dummy input to predict of all 0.5 values. With radius = 0.38 the prediction is [0.65, 0.35, 0.00]. Because the largest value is at index [0], the predicted wheat seed species is 0 = Kama.
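The effect of the radius can be probed by counting how many training items fall within each candidate radius of the input. Here is a sketch using random stand-in data (the real demo uses the 180 normalized wheat items):

```python
import numpy as np

rng = np.random.default_rng(1)
train_X = rng.random((180, 7))   # stand-in for the normalized training data
probe = np.full(7, 0.5)          # dummy input of all 0.5 values
dists = np.linalg.norm(train_X - probe, axis=1)
for rad in (0.2, 0.38, 1.0):
    print(rad, int(np.sum(dists <= rad)))  # neighbor count grows with radius
```

Scanning counts like this is a quick way to find a radius that captures a reasonable number of neighbors, rather than none or nearly all.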

I analyzed the results of the prediction using the radius_neighbors() method:

```
rad = 0.38
idxs = model.radius_neighbors(X, return_distance=False)[0]
print("The idxs of neighbors within %0.2f are: " % rad)
np.set_printoptions(linewidth=40)
print(idxs)
```

This code gave me the indexes of the labeled training items that are within 0.38 of the input X = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]:

```
[  0,   6,  10,  20,  31,  32,
34,  36,  38,  43,  49,  50,
51,  52,  55,  61,  62,  64,
65,  66,  70,  90, 112]
```

There are 23 data items. Recall that the 15 items at indexes 0, 6, 10, 20, 31, 32, 34, 36, 38, 43, 49, 50, 51, 52, 55 are class 0, and the 8 items at indexes 61, 62, 64, 65, 66, 70, 90, 112 are class 1, and none of the class 2 items at indexes 120 to 179 appear.

Therefore, the probability of class 0 is 15 / 23 = 0.65, and the probability of class 1 is 8 / 23 = 0.35, and the probability of class 2 is 0 / 23 = 0.00.
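That arithmetic is exactly what predict_proba() does: count the class labels within the radius and divide by the total count. In NumPy:

```python
import numpy as np

labels = np.array([0]*15 + [1]*8)  # labels of the 23 in-radius neighbors
probs = np.bincount(labels, minlength=3) / labels.size
print(probs)  # approximately [0.65, 0.35, 0.00]
```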

Radius neighbors classification isn’t used very often. In practice, it’s a bit too difficult to specify the radius value.

The key to the radius neighbors classification algorithm is the concept of a radius. Most people, including me, first come into contact with radius in a geometry class: circles and spheres. Three memorable spherical spaceships from science fiction movies. Left: The alien ship from “It Came From Outer Space” (1953). The ship crashed on Earth and the (good) aliens impersonated townspeople to get supplies to repair their craft. Center: The Aries 1B from “2001: A Space Odyssey” (1968) is a shuttle for travel between a space station in Earth orbit and the Moon. Right: The Heart of Gold from “The Hitchhiker’s Guide to the Galaxy” (2005).

Demo code:

```
# wheat_krn.py
# radius neighbors version of k-NN

# predict wheat seed species (0=Kama, 1=Rosa, 2=Canadian)
# from area, perimeter, compactness, length, width,
#   asymmetry, groove

# Anaconda3-2020.02  Python 3.7.6  scikit 0.22.1
# Windows 10/11

import numpy as np
from sklearn.neighbors import RadiusNeighborsClassifier

# ---------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

# ---------------------------------------------------------

def main():
  # 0. prepare
  print("\nBegin Wheat Seeds radius neighbors using scikit ")
  np.set_printoptions(precision=4, suppress=True)
  np.random.seed(1)

  # 1. load data
  # 0.4410  0.5021  0.5708  0.4865  0.4861  0.1893  0.3452  0
  # 0.4051  0.4463  0.6624  0.3688  0.5011  0.0329  0.2152  0
  # . . .
  # 0.1917  0.2603  0.3630  0.2877  0.2003  0.3304  0.3506  2
  # 0.2049  0.2004  0.8013  0.0980  0.3742  0.2682  0.1531  2

  train_file = ".\\Data\\wheat_train.txt"  # 180 items
  train_xy = np.loadtxt(train_file, usecols=range(0,8),
    delimiter="\t", comments="#", dtype=np.float32)
  train_X = train_xy[:,0:7]
  train_y = train_xy[:,7].astype(np.int64)

  test_file = ".\\Data\\wheat_test.txt"  # 30 items
  test_xy = np.loadtxt(test_file, usecols=range(0,8),
    delimiter="\t", comments="#", dtype=np.float32)
  test_X = test_xy[:,0:7]
  test_y = test_xy[:,7].astype(np.int64)

  print("\nTraining data:")
  print(train_X[0:4])
  print(". . . \n")
  print(train_y[0:4])
  print(". . . ")

  # 2. create and train model
  # RadiusNeighborsClassifier(radius=1.0, weights='uniform',
  #   algorithm='auto', leaf_size=30, p=2, metric='minkowski',
  #   outlier_label=None, metric_params=None, n_jobs=None)
  # algorithm: 'ball_tree', 'kd_tree', 'brute', 'auto'.
  rad = 0.38
  print("\nCreating classifier with radius = %0.2f " % rad)
  model = RadiusNeighborsClassifier(radius=rad)
  model.fit(train_X, train_y)
  print("Done ")

  # 3. evaluate model
  train_acc = model.score(train_X, train_y)
  test_acc = model.score(test_X, test_y)
  print("\nAccuracy on train data = %0.4f " % train_acc)
  print("Accuracy on test data = %0.4f " % test_acc)

  from sklearn.metrics import confusion_matrix
  y_predicteds = model.predict(test_X)
  cm = confusion_matrix(test_y, y_predicteds)
  print("\nConfusion matrix: \n")
  # print(cm)  # raw
  show_confusion(cm)  # custom formatted

  # 4. use model
  X = np.array([[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]],
    dtype=np.float32)
  print("\nPredicting wheat species for: ")
  print(X)
  probs = model.predict_proba(X)
  print("\nPrediction probs: ")
  print(probs)

  idxs = model.radius_neighbors(X, return_distance=False)[0]
  print("\nThe idxs of neighbors within %0.2f are: " % rad)
  np.set_printoptions(linewidth=40)
  print(idxs)

  predicted = model.predict(X)
  print("\nPredicted class: ")
  print(predicted)

  # 5. TODO: save model using pickle
  print("\nEnd demo ")

if __name__ == "__main__":
  main()
```

Training data. Replace commas with tabs or modify program code.

```
# wheat_train.txt
#
# http://archive.ics.uci.edu/ml/datasets/seeds
# 210 total items. train is first 60 each of 3 classes
# 180 training, 30 test
# area, perimeter, compactness, length, width, asymmetry, groove
# predictors are all min-max normalized
# 0 = Kama, 1 = Rosa, 2 = Canadian
#
0.4410,0.5021,0.5708,0.4865,0.4861,0.1893,0.3452,0
0.4051,0.4463,0.6624,0.3688,0.5011,0.0329,0.2152,0
0.3494,0.3471,0.8793,0.2207,0.5039,0.2515,0.1507,0
0.3069,0.3161,0.7931,0.2393,0.5339,0.1942,0.1408,0
0.5241,0.5331,0.8648,0.4274,0.6643,0.0767,0.3230,0
0.3579,0.3719,0.7895,0.2742,0.4861,0.2206,0.2152,0
0.3872,0.4298,0.6515,0.3739,0.4483,0.3668,0.3447,0
0.3324,0.3492,0.7532,0.2934,0.4790,0.2516,0.2368,0
0.5703,0.6302,0.6044,0.6498,0.5952,0.1658,0.6686,0
0.5524,0.5868,0.7250,0.5546,0.6237,0.1565,0.4993,0
0.4410,0.5041,0.5581,0.4589,0.4362,0.4912,0.3914,0
0.3248,0.3616,0.6488,0.3035,0.4070,0.1238,0.2373,0
0.3116,0.3326,0.7250,0.3041,0.4056,0.4188,0.1078,0
0.3012,0.3409,0.6152,0.3266,0.3749,0.3083,0.1738,0
0.2975,0.3388,0.6016,0.3283,0.3450,0.2817,0.1507,0
0.3777,0.3864,0.8276,0.2545,0.5011,0.4447,0.1290,0
0.3211,0.2934,1.0000,0.1239,0.5367,0.5811,0.1290,0
0.4816,0.4835,0.8866,0.3536,0.6301,0.1084,0.2595,0
0.3881,0.3719,0.9728,0.1723,0.5959,0.1303,0.0640,0
0.2011,0.2397,0.5490,0.1841,0.2986,0.4339,0.1945,0
0.3371,0.4112,0.4564,0.4274,0.3557,0.3000,0.3235,0
0.3324,0.3822,0.5817,0.3497,0.3835,0.2500,0.3447,0
0.4995,0.5145,0.8230,0.4048,0.6251,0.0000,0.2816,0
0.1407,0.1694,0.5290,0.1126,0.2181,0.0845,0.2176,0
0.4174,0.4855,0.5227,0.5011,0.4383,0.1334,0.2373,0
0.5288,0.5682,0.6969,0.5259,0.5638,0.0179,0.3880,0
0.2295,0.2789,0.5082,0.2793,0.2823,0.3391,0.1507,0
0.2030,0.2603,0.4383,0.2793,0.2324,0.2261,0.1723,0
0.3324,0.3657,0.6706,0.3615,0.4212,0.2586,0.2555,0
0.2701,0.3326,0.4746,0.3474,0.3100,0.3596,0.2846,0
0.2427,0.2913,0.5272,0.3125,0.2459,0.0117,0.2644,0
0.4627,0.5227,0.5835,0.4831,0.5282,0.3442,0.3491,0
0.3305,0.4132,0.4065,0.4606,0.3963,0.4102,0.3840,0
0.3163,0.3636,0.5871,0.3863,0.3706,0.1767,0.2427,0
0.4212,0.4690,0.6334,0.4578,0.4975,0.1773,0.4141,0
0.5222,0.5351,0.8339,0.4561,0.6094,0.1957,0.4549,0
0.5297,0.5909,0.5926,0.5220,0.5944,0.2676,0.4963,0
0.6128,0.6136,0.9056,0.5253,0.7505,0.2849,0.4751,0
0.3975,0.4360,0.6733,0.4262,0.4690,0.3052,0.3890,0
0.3484,0.3636,0.7831,0.2804,0.4761,0.7697,0.2373,0
0.2786,0.2975,0.7169,0.2528,0.3749,0.2369,0.3245,0
0.2748,0.2975,0.6996,0.2545,0.3763,0.1929,0.3235,0
0.2427,0.2355,0.8421,0.1346,0.4070,0.2205,0.1300,0
0.4636,0.5062,0.6706,0.5507,0.5460,0.5131,0.4968,0
0.4268,0.4401,0.8212,0.3829,0.5930,0.3072,0.3255,0
0.3031,0.3368,0.6470,0.2686,0.3742,0.1034,0.2176,0
0.4504,0.4855,0.7078,0.4516,0.5438,0.0783,0.3018,0
0.4155,0.4442,0.7278,0.3778,0.5324,0.2851,0.3230,0
0.3966,0.4360,0.6697,0.3637,0.4711,0.2521,0.2915,0
0.4032,0.4669,0.5399,0.4386,0.4476,0.1773,0.4097,0
0.3626,0.4112,0.6080,0.3863,0.4576,0.4174,0.3077,0
0.4901,0.5165,0.7641,0.4364,0.5731,0.6277,0.3038,0
0.3683,0.4545,0.4147,0.4595,0.3443,0.4357,0.4318,0
0.3532,0.3864,0.6806,0.3407,0.4056,0.3332,0.3471,0
0.3711,0.4525,0.4319,0.4741,0.3443,0.0931,0.4766,0
0.4193,0.4876,0.5236,0.4521,0.4148,0.1519,0.4530,0
0.3654,0.4008,0.6688,0.2753,0.5324,0.2648,0.2585,0
0.4089,0.4174,0.8394,0.2731,0.5574,0.0490,0.2802,0
0.4523,0.4876,0.7042,0.4296,0.5624,0.1604,0.3461,0
0.1435,0.2190,0.2822,0.1464,0.2865,0.0958,0.0000,0
0.6648,0.7376,0.5372,0.7275,0.6636,0.4305,0.7587,1
0.5902,0.6736,0.4918,0.6188,0.6087,0.5084,0.6686,1
0.6298,0.6860,0.6189,0.6075,0.6871,0.4907,0.6263,1
0.8045,0.7955,0.9074,0.7066,0.9266,0.2823,0.7681,1
0.5883,0.6405,0.6397,0.6295,0.6101,0.4211,0.6509,1
0.5836,0.6632,0.5054,0.5788,0.5759,0.5402,0.6283,1
0.6355,0.7231,0.4701,0.6560,0.5510,0.3977,0.6908,1
0.9556,0.9959,0.6189,0.9459,0.8439,0.4793,0.9513,1
0.7885,0.8430,0.6071,0.8705,0.7192,0.5590,0.9074,1
0.6166,0.6488,0.7359,0.5355,0.6671,0.2721,0.6041,1
0.5609,0.6054,0.6733,0.5495,0.5966,0.6198,0.6701,1
0.7677,0.7810,0.8131,0.6233,0.8746,0.5928,0.6696,1
0.9075,0.9256,0.7377,0.7804,0.8795,0.5731,0.8213,1
0.8480,0.8946,0.6334,0.8361,0.8140,0.0919,0.8636,1
0.8423,0.8884,0.6343,0.8260,0.8346,0.2856,0.8203,1
0.7252,0.7603,0.7160,0.7173,0.7277,0.2182,0.8262,1
0.7828,0.7955,0.8058,0.6672,0.8083,0.1149,0.7829,1
0.7923,0.8781,0.4619,0.9291,0.7413,0.3804,0.9744,1
1.0000,0.9917,0.8240,0.9426,1.0000,0.6521,0.8429,1
0.9717,0.9587,0.8621,0.8733,0.9993,0.5527,0.8872,1
0.8980,0.9463,0.6034,0.9471,0.8232,0.1547,0.9503,1
0.7715,0.7831,0.8194,0.7168,0.8311,0.3062,0.7553,1
0.7762,0.8017,0.7486,0.7731,0.7577,0.3214,0.7553,1
0.7554,0.7521,0.8938,0.6408,0.8767,0.6808,0.6686,1
0.7337,0.8492,0.3367,0.9949,0.6094,0.5419,0.9498,1
0.5930,0.6694,0.5145,0.6982,0.5937,0.3811,0.7129,1
0.8234,0.8636,0.6661,0.8119,0.8411,0.3526,0.8464,1
0.7923,0.8595,0.5499,0.8727,0.6572,0.1793,0.9522,1
0.7158,0.7955,0.5045,0.7725,0.6287,0.2715,0.8636,1
0.7677,0.8120,0.6615,0.7432,0.7512,0.1850,0.7770,1
0.5496,0.5868,0.7123,0.4611,0.6379,0.4488,0.5411,1
0.6988,0.7128,0.8267,0.5580,0.7584,0.1694,0.6489,1
0.8376,0.8450,0.8203,0.6836,0.8995,0.4607,0.7336,1
0.8111,0.8719,0.5771,0.8277,0.7491,0.3370,0.8419,1
0.7894,0.8285,0.6788,0.7596,0.8019,0.3384,0.8021,1
0.7781,0.8017,0.7586,0.6408,0.8239,0.2325,0.6696,1
0.7800,0.7769,0.8848,0.7055,0.8382,0.2702,0.8277,1
0.6648,0.7128,0.6525,0.6385,0.6721,0.3877,0.6942,1
0.8829,0.9318,0.6089,1.0000,0.8076,0.3234,1.0000,1
0.7517,0.7872,0.7114,0.7061,0.7441,0.1265,0.6770,1
0.7422,0.7665,0.7623,0.6802,0.8118,0.1911,0.6278,1
0.8300,0.8905,0.5762,0.7905,0.8275,0.3787,0.7120,1
0.8064,0.8058,0.8657,0.7230,0.9066,0.1747,0.6918,1
0.8074,0.8678,0.5817,0.7658,0.7890,0.7693,0.7553,1
0.9802,1.0000,0.7060,0.9369,0.9701,0.5086,0.8848,1
0.7998,0.8347,0.7015,0.8542,0.7762,0.1928,0.8095,1
0.7904,0.7831,0.9038,0.6486,0.9031,0.4640,0.6061,1
0.8083,0.8347,0.7341,0.7579,0.8446,0.3015,0.8203,1
0.7838,0.7893,0.8412,0.7477,0.8118,0.3737,0.7125,1
0.8914,0.9277,0.6624,0.8975,0.8746,0.2988,0.8868,1
0.9112,0.9298,0.7405,0.7973,0.9494,0.6678,0.8218,1
0.7129,0.7665,0.6270,0.6532,0.6650,0.3711,0.7346,1
0.5269,0.6136,0.4601,0.4859,0.5396,0.4578,0.5830,1
0.7403,0.7355,0.9038,0.6087,0.8133,0.2885,0.6824,1
0.5099,0.5124,0.8920,0.2613,0.6785,0.3343,0.3077,1
0.7705,0.7789,0.8330,0.6824,0.8831,0.4451,0.7253,1
0.7611,0.8264,0.5599,0.7804,0.6871,0.4715,0.7794,1
0.6978,0.7107,0.8276,0.6081,0.7534,0.1940,0.6893,1
0.9037,0.9545,0.5935,0.9088,0.8147,0.1489,0.8203,1
0.6572,0.6715,0.8258,0.5023,0.7555,0.5982,0.5623,1
0.2342,0.3120,0.3621,0.3226,0.2594,0.5902,0.4313,2
0.2578,0.3161,0.4828,0.3615,0.3158,0.8152,0.4535,2
0.2597,0.3182,0.4891,0.2759,0.3165,0.6800,0.3880,2
0.1539,0.1880,0.5181,0.1830,0.2402,0.6116,0.3456,2
0.1161,0.2045,0.1751,0.2337,0.1048,0.4819,0.3245,2
0.0585,0.1488,0.0780,0.2140,0.0406,0.7026,0.3722,2
0.0793,0.1488,0.2305,0.1560,0.0634,0.1893,0.3018,2
0.1794,0.2169,0.5236,0.2072,0.2402,0.4754,0.2378,2
0.1992,0.2686,0.3721,0.2742,0.2003,0.3244,0.3924,2
0.0189,0.1074,0.0236,0.2354,0.0128,0.6107,0.3323,2
0.1171,0.1694,0.3766,0.2050,0.1497,0.5760,0.3880,2
0.1341,0.2293,0.1525,0.2849,0.1041,0.8096,0.3698,2
0.1577,0.2459,0.2287,0.2866,0.1447,0.5189,0.4141,2
0.0557,0.1302,0.1679,0.1807,0.0449,0.3338,0.2373,2
0.0727,0.1322,0.2731,0.1554,0.0891,0.4269,0.3663,2
0.0567,0.1322,0.1561,0.1976,0.0321,0.6563,0.3447,2
0.0708,0.0950,0.4673,0.0867,0.1561,0.3357,0.2383,2
0.1454,0.2727,0.0000,0.2787,0.0820,0.5279,0.3452,2
0.1095,0.2293,0.0009,0.3069,0.0342,0.4698,0.3895,2
0.0850,0.1674,0.1652,0.2280,0.0463,0.6011,0.3895,2
0.1841,0.2603,0.3122,0.3108,0.1775,0.3013,0.4786,2
0.1350,0.1901,0.3829,0.2539,0.1283,0.4559,0.3885,2
0.1379,0.2066,0.3040,0.2072,0.1547,0.5491,0.2595,2
0.1851,0.2397,0.4328,0.2444,0.2409,0.4751,0.3235,2
0.0519,0.0785,0.4328,0.0631,0.1169,0.7311,0.2610,2
0.1426,0.1529,0.6461,0.1160,0.2217,0.1867,0.2644,2
0.1747,0.2438,0.3457,0.2365,0.1903,0.5408,0.3698,2
0.1473,0.2149,0.3285,0.2917,0.1475,0.3735,0.4032,2
0.0718,0.1467,0.1906,0.1560,0.0271,0.4644,0.3018,2
0.0614,0.1219,0.2523,0.1075,0.0606,0.3583,0.2802,2
0.0406,0.1219,0.0980,0.2399,0.0506,0.7762,0.3171,2
0.0907,0.1426,0.3394,0.1509,0.1532,0.7736,0.2152,2
0.0642,0.1157,0.3067,0.1064,0.0948,0.4608,0.2368,2
0.0765,0.1384,0.2668,0.1334,0.0948,0.6271,0.2806,2
0.0227,0.1136,0.0163,0.2134,0.0078,0.5743,0.3279,2
0.0198,0.0331,0.4619,0.0462,0.1361,0.5211,0.2678,2
0.0633,0.1240,0.2486,0.1616,0.0570,0.5942,0.2821,2
0.0142,0.0661,0.2250,0.1385,0.0086,0.5119,0.2186,2
0.0840,0.1322,0.3557,0.1582,0.0912,0.6645,0.2378,2
0.1530,0.2190,0.3376,0.2579,0.1875,0.1165,0.3245,2
0.0774,0.1116,0.4347,0.1075,0.1033,0.5450,0.1507,2
0.1766,0.2066,0.5672,0.1898,0.2758,0.5489,0.3092,2
0.1511,0.1963,0.4519,0.1920,0.1989,0.5320,0.3146,2
0.1001,0.1364,0.4483,0.1177,0.1568,0.5778,0.3033,2
0.2172,0.2810,0.4174,0.3356,0.2823,0.7047,0.3924,2
0.0916,0.1860,0.1062,0.2613,0.0378,0.4287,0.3264,2
0.1152,0.2149,0.1062,0.2894,0.0613,0.5374,0.4101,2
0.0302,0.0806,0.2641,0.1064,0.0321,0.4439,0.2152,2
0.0604,0.0847,0.4655,0.1070,0.1361,0.8788,0.2157,2
0.0000,0.0000,0.5145,0.0000,0.1119,0.5474,0.1354,2
0.0321,0.0806,0.2804,0.0828,0.0620,0.6024,0.2590,2
0.0642,0.0930,0.4374,0.1081,0.1240,0.4187,0.2373,2
0.1209,0.1260,0.6479,0.1312,0.2302,0.3682,0.3018,2
0.0217,0.0868,0.1588,0.1582,0.0000,0.5315,0.2806,2
0.1435,0.1777,0.5064,0.1898,0.2459,0.4378,0.2427,2
0.2087,0.2190,0.7069,0.1470,0.3535,0.5341,0.1945,2
0.2077,0.2314,0.6397,0.1830,0.3022,0.6134,0.2161,2
0.2625,0.2831,0.6969,0.2370,0.3550,0.5077,0.2816,2
0.1917,0.2603,0.3630,0.2877,0.2003,0.3304,0.3506,2
0.2049,0.2004,0.8013,0.0980,0.3742,0.2682,0.1531,2
```

Test data:

```
# wheat_test.txt
#
0.0784,0.0930,0.5463,0.0614,0.1568,0.2516,0.0433,0
0.0604,0.0455,0.6887,0.0017,0.1775,0.1955,0.0906,0
0.1671,0.1612,0.7641,0.0997,0.2937,0.3192,0.0423,0
0.2483,0.2955,0.5436,0.2793,0.3136,0.4410,0.2802,0
0.2068,0.2397,0.5762,0.2044,0.2823,0.0534,0.1295,0
0.2162,0.2252,0.7241,0.1351,0.3485,0.2063,0.0433,0
0.3541,0.4050,0.5853,0.4116,0.3991,0.0712,0.3107,0
0.3229,0.3884,0.4936,0.3998,0.3763,0.1888,0.3018,0
0.3569,0.4091,0.5853,0.3773,0.3728,0.0909,0.3845,0
0.2021,0.2769,0.3421,0.2889,0.1796,0.3599,0.2698,0
0.7280,0.7190,0.9319,0.6081,0.8019,0.2694,0.7105,1
0.7885,0.8079,0.7813,0.7010,0.8517,0.2786,0.7041,1
0.4523,0.5145,0.5672,0.5546,0.4547,0.4807,0.6283,1
0.5260,0.6033,0.5109,0.5327,0.5453,0.4552,0.6283,1
0.4693,0.5124,0.6733,0.4938,0.5545,0.5470,0.6539,1
0.4523,0.4649,0.8249,0.3255,0.5952,0.3686,0.4530,1
0.6393,0.6921,0.6388,0.7016,0.6728,0.3590,0.7149,1
0.4703,0.5661,0.4047,0.5749,0.4284,0.2438,0.6696,1
0.4731,0.5579,0.4528,0.5253,0.4676,0.2548,0.6071,1
0.5326,0.5723,0.6978,0.5479,0.6001,0.3906,0.6908,1
0.1690,0.2128,0.4791,0.1802,0.2559,0.6120,0.2590,2
0.1964,0.1880,0.8131,0.0479,0.3599,0.1996,0.1113,2
0.0557,0.0640,0.5436,0.0619,0.1283,0.4272,0.1521,2
0.1992,0.2066,0.7196,0.1599,0.3286,1.0000,0.2368,2
0.1681,0.2190,0.4410,0.1717,0.2352,0.4101,0.2373,2
0.1511,0.1632,0.6370,0.1340,0.2502,0.3726,0.1728,2
0.0604,0.0971,0.3902,0.1357,0.1176,0.4629,0.2383,2
0.2465,0.2583,0.7278,0.1898,0.4291,0.9817,0.2644,2
0.1180,0.1653,0.3993,0.1554,0.1468,0.3683,0.2585,2
0.1615,0.1921,0.5472,0.1937,0.2452,0.6335,0.2678,2
```

## Example of Kernel Ridge Regression Using the scikit Library

A regression problem is one where the goal is to predict a single numeric value. For example, you might want to predict the income of a person based on their sex, age, State, and political leaning. (Note: Somewhat confusingly, “logistic regression” is a binary classification technique in spite of its name).

The scikit (short for scikit-learn or sklearn) library has a Kernel Ridge Regression (KRR) module to predict a numeric value. KRR is an advanced version of basic linear regression. The “Kernel” in KRR means the technique uses the kernel trick which allows KRR to deal with complex data that’s not linearly separable. The “Ridge” indicates KRR uses ridge regularization to limit model overfitting. I hadn’t looked at KRR in a long time so I decided to code up a quick demo.

I used one of my standard demo datasets that looks like:

```
# sex age   state   income   politics
-1  0.27  0 1 0   0.7610   0 0 1
+1  0.19  0 0 1   0.6550   1 0 0
. . .
```

The goal is to predict income from sex, age, State and politics. The sex column is encoded as Male = -1, Female = +1. Ages are divided by 100. The States are Michigan = 100, Nebraska = 010, Oklahoma = 001. Incomes are divided by \$100,000. The politics are conservative = 100, moderate = 010, liberal = 001.
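Under that scheme, encoding one person reduces to table lookups and scaling. A hypothetical helper (the dictionaries and function name are mine, not part of the demo):

```python
# one-hot lookup tables for the categorical predictors
STATES = {"michigan": [1,0,0], "nebraska": [0,1,0], "oklahoma": [0,0,1]}
POLITICS = {"conservative": [1,0,0], "moderate": [0,1,0], "liberal": [0,0,1]}

def encode(sex, age, state, income, politics):
  s = -1 if sex == "M" else +1          # Male = -1, Female = +1
  return [s, age / 100] + STATES[state] \
    + [income / 100_000] + POLITICS[politics]

print(encode("M", 27, "nebraska", 76100, "liberal"))
# → [-1, 0.27, 0, 1, 0, 0.761, 0, 0, 1]
```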

I made a training file with 200 items and a test file with 40 items. The complete data is at: jamesmccaffrey.wordpress.com/2022/10/10/regression-people-income-using-pytorch-1-12-on-windows-10-11/.

Kernel ridge regression is difficult to explain. The technique is based on simple linear regression where each predictor value is multiplied by a weight. But the technique uses a kernel method where a kernel function is applied to each training item and the item to predict. This allows the technique to deal with data that isn’t linearly separable.

The ridge part of the KRR name means that L2 regularization is applied to prevent model overfitting, which kernel techniques are often highly vulnerable to.
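Under the hood, the fit reduces to linear algebra: with kernel matrix K over the training items and regularization strength lam (scikit's alpha parameter), KRR solves (K + lam*I) a = y for the dual weights a, and a prediction is the kernel-weighted sum of the a values. A minimal NumPy sketch with a polynomial kernel, using synthetic data rather than the people dataset:

```python
import numpy as np

def poly_kernel(A, B, gamma=1.0, coef0=1.0, degree=4):
  # polynomial kernel: (gamma * <a,b> + coef0)^degree
  return (gamma * (A @ B.T) + coef0) ** degree

rng = np.random.default_rng(0)
X = rng.random((20, 3))             # synthetic predictors
y = X[:, 0]**2 + 0.5 * X[:, 1]      # synthetic target values

lam = 1.0                           # ridge strength (scikit's alpha)
K = poly_kernel(X, X)
a = np.linalg.solve(K + lam * np.eye(len(X)), y)  # dual weights

X_new = rng.random((2, 3))
preds = poly_kernel(X_new, X) @ a   # kernel-weighted sum over training items
print(preds)
```

Every training item contributes to every prediction, weighted by its dual coefficient and its kernel similarity to the input, which is why KRR can fit data that plain linear regression can't.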

After loading the training data into memory, the key statements in my demo program are:

```
print("Creating and training KRR poly(4) model ")
model = KernelRidge(alpha=1.0, kernel='poly', degree=4)
model.fit(train_X, train_y)
```

The parameters to the KernelRidge class would take forever to explain in detail, and this is one of the difficulties with using KRR. The kernel function can be one of ‘additive_chi2’, ‘chi2’, ‘linear’, ‘poly’, ‘polynomial’, ‘rbf’, ‘laplacian’, ‘sigmoid’, ‘cosine’ and a good one must be determined by trial and error. The two most common are ‘polynomial’ and ‘rbf’ (radial basis function), but weirdly the default is ‘linear’.

One issue with regression problems is that you must implement a program-defined accuracy function. For a classification problem, a prediction is either correct or wrong. But with regression, when you predict a numeric value, you must specify what counts as a correct prediction. I defined an accuracy function where a prediction that is within 10% of the true value is considered correct.
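A vectorized version of the same within-10-percent idea can be written as a standalone sketch (this is not the demo's item-by-item version):

```python
import numpy as np

def accuracy_within(preds, actuals, pct=0.10):
  # fraction of predictions within pct of the true value
  return np.mean(np.abs(preds - actuals) < np.abs(pct * actuals))

preds   = np.array([0.50, 0.70, 0.30])
actuals = np.array([0.52, 0.90, 0.31])
print(accuracy_within(preds, actuals))  # 2 of 3 within 10% → 0.6667
```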

I haven’t seen kernel ridge regression used very much. Neural networks are more powerful than KRR, but neural networks require lots of training data and neural networks are more difficult to fine tune.

Kernel ridge regression has been around for a long time — since about 1970. St. Patrick’s Day has been celebrated on March 17 since 1631. Here are three examples of St. Patrick’s Day garb that have varying degrees of sophistication.

Demo code:

```
# kernel_ridge_regression.py
# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11
# scikit / sklearn 0.22.1

# predict income from sex, age, State, politics

import numpy as np
from sklearn.kernel_ridge import KernelRidge
import pickle

# sex age   state   income   politics
# -1  0.27  0 1 0   0.7610   0 0 1
# +1  0.19  0 0 1   0.6550   1 0 0

# -----------------------------------------------------------

def accuracy(model, data_X, data_y, pct_close):
  # correct within pct of true income
  n_correct = 0; n_wrong = 0

  for i in range(len(data_X)):
    X = data_X[i].reshape(1, -1)  # make one-item batch
    y = data_y[i]
    pred = model.predict(X)       # predicted income

    if np.abs(pred - y) < np.abs(pct_close * y):
      n_correct += 1
    else:
      n_wrong += 1
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def main():
  print("\nBegin kernel ridge regression using scikit demo ")
  print("Predict income from sex, age, State, political ")

  # 0. prepare
  np.random.seed(1)

  # 1. load data
  train_file = ".\\Data\\people_train.txt"
  train_xy = np.loadtxt(train_file, usecols=range(0,9),
    delimiter="\t", comments="#", dtype=np.float32)
  train_X = train_xy[:,[0,1,2,3,4,6,7,8]]
  train_y = train_xy[:,5].flatten()  # 1-D required

  print("\nX = ")
  print(train_X[0:4,:])
  print(" . . . ")
  print("\ny = ")
  print(train_y[0:4])
  print(" . . . ")

  test_file = ".\\Data\\people_test.txt"
  test_xy = np.loadtxt(test_file, usecols=range(0,9),
    delimiter="\t", comments="#", dtype=np.float32)
  test_X = test_xy[:,[0,1,2,3,4,6,7,8]]
  test_y = test_xy[:,5].flatten()  # 1-D required

  # -----------------------------------------------------------

  # 2. create and train KRR model
  print("\nCreating and training KRR poly(4) model ")
  # KernelRidge(alpha=1.0, *, kernel='linear', gamma=None,
  #   degree=3, coef0=1, kernel_params=None)
  # kernel: ['additive_chi2', 'chi2', 'linear', 'poly',
  #   'polynomial', 'rbf', 'laplacian', 'sigmoid', 'cosine']
  model = KernelRidge(alpha=1.0, kernel='poly', degree=4)
  model.fit(train_X, train_y)

  # 3. compute model accuracy
  acc_train = accuracy(model, train_X, train_y, 0.10)
  print("\nAccuracy on train data = %0.4f " % acc_train)
  acc_test = accuracy(model, test_X, test_y, 0.10)
  print("Accuracy on test data = %0.4f " % acc_test)

  # 4. make a prediction
  print("\nPredicting income for M 34 Oklahoma moderate: ")
  X = np.array([[-1, 0.34, 0,0,1,  0,1,0]],
    dtype=np.float32)
  pred_inc = model.predict(X)
  print("$%0.2f" % (pred_inc * 100_000))  # un-normalized

  # 5. save model
  print("\nSaving model ")
  fn = ".\\Models\\krr_model.pkl"
  with open(fn,'wb') as f:
    pickle.dump(model, f)

  # load and use saved model:
  # with open(fn, 'rb') as f:
  #   model2 = pickle.load(f)
  # pi = model2.predict(X)
  # print("$%0.2f" % (pi * 100_000))  # un-normalized

  print("\nEnd scikit KRR demo ")

if __name__ == "__main__":
  main()
```

## Revisiting My Neural Network Regression System with Raw JavaScript

A couple of years ago I implemented a neural network regression system (predict a single numeric value) in raw JavaScript. I enjoy coding, even in raw JavaScript, so one Saturday evening I figured I’d revise my old example.

I didn’t run into any major problems but working with raw JavaScript is always a bit slow. My raw JavaScript neural network can only handle a single hidden layer, but even so the results were pretty good.

I used one of my standard regression problem datasets. The data looks like:

```
1  0.24  1  0  0  0.2950  0  0  1
0  0.39  0  0  1  0.5120  0  1  0
1  0.63  0  1  0  0.7580  1  0  0
. . .
```

Each line represents a person. The fields are sex (male = 0, female = 1), age (divided by 100), state (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by \$100,000) and political leaning (conservative = 100, moderate = 010, liberal = 001). The goal is to predict income from sex, age, state, and politics.

Implementing a neural network from scratch (in any language) is difficult. My implementation is several hundred lines of code long so I can’t present it in its entirety in this blog post.

```
let U = require("../../Utilities/utilities_lib.js");
let FS = require("fs");

function main()
{
  console.log("Begin neural network regression demo  ");

  // 1  0.24  1  0  0  0.2950  0  0  1
  // 0  0.39  0  0  1  0.5120  0  1  0

  // 1. load data (file paths assumed)
  let trainFile = "..\\Data\\people_train.txt";
  let testFile = "..\\Data\\people_test.txt";
  let trainX = U.loadTxt(trainFile, "\t", [0,1,2,3,4,6,7,8], "//");
  let trainY = U.loadTxt(trainFile, "\t", [5], "//");
  let testX = U.loadTxt(testFile, "\t", [0,1,2,3,4,6,7,8], "//");
  let testY = U.loadTxt(testFile, "\t", [5], "//");
  . . .
```

And creating and training the network is:

```
  // 2. create network
console.log("\nCreating 8-100-1 tanh, Identity NN ");
let seed = 0;
let nn = new NeuralNet(8, 100, 1, seed);

// 3. train network
let lrnRate = 0.005;
let maxEpochs = 500;
console.log("\nStarting train learn rate = " +
lrnRate.toString());
nn.train(trainX, trainY, lrnRate, maxEpochs);
console.log("Done ");
. . .
```

Notice that I made train() a method that belongs to the NeuralNet class rather than a standalone function that accepts a neural net object. Design decisions like this are often more difficult than coding implementation. Anyway, it was a good way to spend a Saturday evening.

I’ve always been fascinated by aircraft design. Left: The German Albatross D5 (1917) featured an early streamlined design with a spinner fairing over the propeller. The Albatross could fly at 115 mph. Center: Just 20 years later, the Vought Corsair F4U (1937) featured an inverted gull wing design to allow a huge propeller. The Corsair could fly at 445 mph. Right: And just another 20 years later, the Convair F-106 Delta Dart (1957) featured a delta shaped wing for large wing leading edge angle for speed with large surface area for lift. The F-106 could fly at 1,525 mph.

## Computing and Displaying a Confusion Matrix for a PyTorch Neural Network Multi-Class Classifier

After training a PyTorch multi-class classifier, it’s important to evaluate the accuracy of the trained model. Simple classification accuracy is OK but in many scenarios you want a so-called confusion matrix that gives details of the number of correct and wrong predictions for each target class label.

For example, suppose you’re predicting the political leaning (0 = conservative, 1 = moderate, 2 = liberal) of a person based on their sex (-1 = male, 1 = female), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), and income (divided by \$100,000). An example of a formatted confusion matrix might look like:

```
Computing model accuracy
Accuracy on training data = 0.8150
Accuracy on test data = 0.7500

Computing raw confusion matrix:
[[ 6  4  1]
[ 1 13  0]
[ 2  2 11]]

Formatted version:
actual     0:  6  4  1
actual     1:  1 13  0
actual     2:  2  2 11
------------
predicted      0  1  2
```
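The bookkeeping behind such a matrix is just `cm[actual][predicted] += 1` over the evaluation items. For example, with made-up actual and predicted labels:

```python
import numpy as np

actual    = np.array([0, 0, 1, 2, 2, 1])
predicted = np.array([0, 1, 1, 2, 0, 1])
cm = np.zeros((3, 3), dtype=np.int64)
for a, p in zip(actual, predicted):
  cm[a][p] += 1
print(cm)  # rows = actual class, columns = predicted class
```

Correct predictions land on the main diagonal; off-diagonal entries show exactly which classes get confused with which.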

Here’s my function to compute a raw confusion matrix for a multi-class classifier:

```
def confusion_matrix_multi(model, ds, n_classes):
  if n_classes <= 2:
    print("ERROR: n_classes must be 3 or greater ")
    return None

  cm = np.zeros((n_classes,n_classes), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)   # make it a batch
    Y = ds[i][1].reshape(1)      # actual class 0 1 or 2, 1-D
    oupt = model(X)              # logits form
    pred_class = T.argmax(oupt)  # 0, 1, or 2
    cm[Y][pred_class] += 1
  return cm
```

The function accepts a trained PyTorch classifier and a PyTorch Dataset object that is composed of either a Tuple or a Dictionary where the predictors are at [0] and the target labels are at [1]. The n_classes could be determined programmatically but it’s easier to pass that value in as a parameter.

Note: A function to compute a confusion matrix for a PyTorch binary classifier, where there are just two possible outcomes, uses slightly different code.

The raw confusion matrix is difficult to interpret so I wrote a function to format the matrix by adding some labels:

```
def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")
```

If you do have a confusion matrix, it’s possible to compute an overall accuracy metric from it. The total number of data items is the sum of the entries in the matrix. The number of correct predictions is the sum of the entries on the main diagonal. For example:

```
def accuracy_from_confusion_multi(cm):
  # return (overall accuracy, list of class accuracies)
  N = np.sum(cm)  # total count
  dim = len(cm)
  row_sums = cm.sum(axis=1)  # collapse on cols, process rows

  n_correct = 0
  for i in range(dim):
    n_correct += cm[i][i]  # on the diagonal
  overall = n_correct / N

  class_accs = []
  for i in range(dim):
    class_accs.append(cm[i][i] / row_sums[i])

  return (overall, class_accs)
```

The output for the demo data looks like:

```
Computing test accuracies from confusion matrix
Accuracy on test data = 0.7500
Class accuracies:
0 = 0.5455
1 = 0.9286
2 = 0.7333
```
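As a sanity check, the diagonal-over-total arithmetic can be verified on a small, made-up 3-by-3 matrix (the counts below are hypothetical, chosen only for illustration):

```python
import numpy as np

# hypothetical 3-class confusion matrix (counts are made up)
cm = np.array([[ 6,  3,  2],
               [ 1, 13,  0],
               [ 2,  2, 11]])

N = np.sum(cm)            # 40 total items
n_correct = np.trace(cm)  # 6 + 13 + 11 = 30 on the main diagonal
print(n_correct / N)      # prints 0.75
print(cm.diagonal() / cm.sum(axis=1))  # per-class accuracies
```

Each class accuracy is a diagonal count divided by its row sum, which is exactly what the accuracy_from_confusion_multi() function computes.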

Good fun. Demo code below. The training and test data can be found at https://jamesmccaffrey.wordpress.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.

Easter is coming soon. I grew up Catholic so Easter is important to me. Here are three photos of child confusion caused by disturbing Easter Bunnies.

```# people_politics.py
# predict politics type from sex, age, state, income
# PyTorch 1.13.1-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
# sex  age    state    income   politics
# -1   0.27   0  1  0   0.7610   2
# +1   0.19   0  0  1   0.6550   0
# sex: -1 = male, +1 = female
# politics: 0 = conservative, 1 = moderate, 2 = liberal

def __init__(self, src_file):
all_xy = np.loadtxt(src_file, usecols=range(0,7),
delimiter="\t", comments="#", dtype=np.float32)
tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
tmp_y = all_xy[:,6]     # 1-D

self.x_data = T.tensor(tmp_x,
dtype=T.float32).to(device)
self.y_data = T.tensor(tmp_y,
dtype=T.int64).to(device)  # 1-D

def __len__(self):
return len(self.x_data)

def __getitem__(self, idx):
preds = self.x_data[idx]
trgts = self.y_data[idx]
return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class Net(T.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
self.hid2 = T.nn.Linear(10, 10)
self.oupt = T.nn.Linear(10, 3)

T.nn.init.xavier_uniform_(self.hid1.weight)
T.nn.init.zeros_(self.hid1.bias)
T.nn.init.xavier_uniform_(self.hid2.weight)
T.nn.init.zeros_(self.hid2.bias)
T.nn.init.xavier_uniform_(self.oupt.weight)
T.nn.init.zeros_(self.oupt.bias)

def forward(self, x):
z = T.tanh(self.hid1(x))
z = T.tanh(self.hid2(z))
z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss()
return z

# -----------------------------------------------------------

def accuracy(model, ds):
# assumes model.eval()
# item-by-item version
n_correct = 0; n_wrong = 0
for i in range(len(ds)):
X = ds[i][0].reshape(1,-1)  # make it a batch
Y = ds[i][1].reshape(1)  # 0 1 or 2, 1D
oupt = model(X)  # logits form

big_idx = T.argmax(oupt)  # 0 or 1 or 2
if big_idx == Y:
n_correct += 1
else:
n_wrong += 1

acc = (n_correct * 1.0) / (n_correct + n_wrong)
return acc

# -----------------------------------------------------------

def accuracy_quick(model, dataset):
# assumes model.eval()
X = dataset[0:len(dataset)][0]
# Y = T.flatten(dataset[0:len(dataset)][1])
Y = dataset[0:len(dataset)][1]
oupt = model(X)
# (_, arg_maxs) = T.max(oupt, dim=1)
arg_maxs = T.argmax(oupt, dim=1)  # argmax() is new
num_correct = T.sum(Y==arg_maxs)
acc = (num_correct * 1.0 / len(dataset))
return acc.item()

# -----------------------------------------------------------

def confusion_matrix_multi(model, ds, n_classes):
if n_classes <= 2:
print("ERROR: n_classes must be 3 or greater ")
return None

cm = np.zeros((n_classes,n_classes), dtype=np.int64)
for i in range(len(ds)):
X = ds[i][0].reshape(1,-1)  # make it a batch
Y = ds[i][1].reshape(1)  # actual class 0 1 or 2, 1D
oupt = model(X)  # logits form
pred_class = T.argmax(oupt)  # 0,1,2
cm[Y][pred_class] += 1
return cm

# -----------------------------------------------------------

def show_confusion(cm):
dim = len(cm)
mx = np.max(cm)             # largest count in cm
wid = len(str(mx)) + 1      # width to print
fmt = "%" + str(wid) + "d"  # like "%3d"
for i in range(dim):
print("actual   ", end="")
print("%3d:" % i, end="")
for j in range(dim):
print(fmt % cm[i][j], end="")
print("")
print("------------")
print("predicted    ", end="")
for j in range(dim):
print(fmt % j, end="")
print("")

# -----------------------------------------------------------

def accuracy_from_confusion_multi(cm):
# return (overall accuracy, list of class accuracies)
N = np.sum(cm)  # total count
dim = len(cm)
row_sums = cm.sum(axis=1)

n_correct = 0
for i in range(dim):
n_correct += cm[i][i]  # on the diagonal
overall = n_correct / N

class_accs = []
for i in range(dim):
class_accs.append(cm[i][i] / row_sums[i])

return (overall, class_accs)

# -----------------------------------------------------------

def main():
# 0. get started
print("\nBegin People predict politics type ")
T.manual_seed(1)
np.random.seed(1)

print("\nCreating People Datasets ")

train_file = ".\\Data\\people_train.txt"
train_ds = PeopleDataset(train_file)  # 200 rows

test_file = ".\\Data\\people_test.txt"
test_ds = PeopleDataset(test_file)    # 40 rows

# 1. create DataLoader object
bat_size = 10
train_ldr = T.utils.data.DataLoader(train_ds,
batch_size=bat_size, shuffle=True)

# -----------------------------------------------------------

# 2. create and compile network
print("\nCreating 6-(10-10)-3 neural network ")
net = Net().to(device)
net.train()

# -----------------------------------------------------------

# 3. train model
max_epochs = 1000
ep_log_interval = 200
lrn_rate = 0.01

loss_func = T.nn.NLLLoss()  # assumes log_softmax()
optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

print("\nbat_size = %3d " % bat_size)
print("loss = " + str(loss_func))
print("optimizer = SGD")
print("max_epochs = %3d " % max_epochs)
print("lrn_rate = %0.3f " % lrn_rate)

print("\nStarting training ")
for epoch in range(0, max_epochs):
# T.manual_seed(epoch+1)  # checkpoint reproducibility
epoch_loss = 0  # for one full epoch

for (batch_idx, batch) in enumerate(train_ldr):
X = batch[0]  # inputs
Y = batch[1]  # correct class/label/politics

optimizer.zero_grad()  # reset accumulated gradients
oupt = net(X)
loss_val = loss_func(oupt, Y)  # a tensor
epoch_loss += loss_val.item()  # accumulate
loss_val.backward()  # compute gradients
optimizer.step()  # update weights

if epoch % ep_log_interval == 0:
print("epoch = %5d  |  loss = %10.4f" % \
(epoch, epoch_loss))

print("Training done ")

# -----------------------------------------------------------

# 4. evaluate model accuracy
print("\nComputing model accuracy")
net.eval()
acc_train = accuracy(net, train_ds)  # item-by-item
print("Accuracy on training data = %0.4f" % acc_train)
acc_test = accuracy(net, test_ds)
print("Accuracy on test data = %0.4f" % acc_test)

# 4b. confusion matrix
print("\nComputing confusion matrix ")
cm = confusion_matrix_multi(net, test_ds, n_classes=3)
# print(cm)  # raw matrix
print("Formatted version: \n")
show_confusion(cm)

# 4c. accuracy metrics from confusion
print("\nComputing test accuracies from confusion matrix ")
(test_acc, class_accs) = accuracy_from_confusion_multi(cm)
print("Accuracy on test data = %0.4f" % test_acc)
print("Class accuracies: ")
for i in range(len(cm)):
print("%4d = %0.4f " % (i, class_accs[i]))

# 5. make a prediction
print("\nPredicting politics for M  30  oklahoma  $50,000: ")
X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
X = T.tensor(X, dtype=T.float32).to(device)

logits = net(X)  # do not sum to 1.0
probs = T.exp(logits)  # sum to 1.0
probs = probs.detach().numpy()  # numpy vector prints better
np.set_printoptions(precision=4, suppress=True)
print(probs)

# 6. save model (state_dict approach)
print("\nSaving trained model state ")
# fn = ".\\Models\\people_model.pt"
# T.save(net.state_dict(), fn)

# model = Net()  # requires class definition
# model.load_state_dict(T.load(fn))
# use model to make prediction(s)

print("\nEnd People predict politics demo ")

if __name__ == "__main__":
main()
```

## The Wheat Seed Problem Using k-NN Classification With the scikit Library

One of my work laptops died so I tried to reimage it by reinstalling everything, including OS, from the ground up. While that was going on, I decided to entertain myself by doing a k-nearest neighbors (k-NN) classification example using the scikit library on the Wheat Seeds dataset (on one of my working machines).

The k-NN (k-nearest neighbors) classification technique is intended only for data that has strictly numeric (i.e., no categorical) predictor variables. The raw Wheat Seeds data came from archive.ics.uci.edu/ml/datasets/seeds and looks like:

```15.26  14.84  0.871   5.763  3.312  2.221  5.22   1
14.88  14.57  0.8811  5.554  3.333  1.018  4.956  1
. . .
17.63  15.98  0.8673  6.191  3.561  4.076  6.06   2
16.84  15.67  0.8623  5.998  3.484  4.675  5.877  2
. . .
11.84  13.21  0.8521  5.175  2.836  3.598  5.044  3
12.3   13.34  0.8684  5.243  2.974  5.637  5.063  3
```
```---------------------------------------------------
10.59  12.41  0.8081  4.899  2.63   0.765  4.519 (min values)
21.18  17.25  0.9183  6.675  4.033  8.456  6.55  (max values)
```

There are 210 data items. Each represents one of three species of wheat seeds: Kama, Rosa, Canadian. There are 70 of each species. The first 7 values on each line are the predictors: area, perimeter, compactness, length, width, asymmetry, groove. The eighth value is the one-based encoded species. The goal is to predict species from the seven predictor values.

When using the k-NN classification technique, it's important to normalize the numeric predictors so that they all have roughly the same magnitude, so that a predictor with large values doesn't overwhelm the others. As is often the case in machine learning, data preparation took most of the time and effort of my exploration.

I dropped the raw data into an Excel spreadsheet. For each predictor, I computed the min and max values of the column. Then I performed min-max normalization where each value x in a column is normalized to x’ = (x – min) / (max – min). The result is that each predictor is a value between 0.0 and 1.0.
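I did the normalization in Excel, but the same computation is a few lines of numpy. This sketch assumes the raw predictors are in a 2-D array with one column per predictor; the example rows include the area and perimeter column min and max values shown earlier:

```python
import numpy as np

def min_max_normalize(data):
  # normalize each column to [0.0, 1.0] using x' = (x - min) / (max - min)
  mins = data.min(axis=0)
  maxs = data.max(axis=0)
  return (data - mins) / (maxs - mins)

raw = np.array([[15.26, 14.84],   # area, perimeter of first data item
                [14.88, 14.57],
                [10.59, 12.41],   # column min values
                [21.18, 17.25]])  # column max values
print(min_max_normalize(raw))
```

The first row normalizes to (0.4410, 0.5021), matching the first line of the normalized data shown below.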

I recoded the target class labels from one-based to zero-based. The resulting 210-item dataset looks like:

```0.4410  0.5021  0.5708  0.4865  0.4861  0.1893  0.3452  0
0.4051  0.4463  0.6624  0.3688  0.5011  0.0329  0.2152  0
. . .
0.6648  0.7376  0.5372  0.7275  0.6636  0.4305  0.7587  1
0.5902  0.6736  0.4918  0.6188  0.6087  0.5084  0.6686  1
. . .
0.1917  0.2603  0.3630  0.2877  0.2003  0.3304  0.3506  2
0.2049  0.2004  0.8013  0.0980  0.3742  0.2682  0.1531  2
```

I split the 210-item normalized data into a 180-item training set and a 30-item test set. I used the first 60 of each target class for training and the last 10 of each target class for testing.
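Because the normalized data is grouped by class, with 70 consecutive items per species, the 60/10 per-class split can be sketched as follows (a hypothetical helper, assuming the 210 rows are already in a numpy array):

```python
import numpy as np

def split_by_class(data, n_per_class=70, n_train=60):
  # rows are grouped by class: items [0,70) are class 0, [70,140) class 1, etc.
  train_rows = []; test_rows = []
  for c in range(3):
    start = c * n_per_class
    train_rows.append(data[start : start + n_train])               # first 60
    test_rows.append(data[start + n_train : start + n_per_class])  # last 10
  return np.vstack(train_rows), np.vstack(test_rows)
```

Applied to the 210-by-8 normalized array, this produces a 180-by-8 training array and a 30-by-8 test array.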

Using scikit is easy. After loading the training and test data into memory, a k-NN multi-class classification model is created and trained like so:

```  k = 7
print("Creating kNN model, with k=" + str(k) )
model = KNeighborsClassifier(n_neighbors=k, algorithm='brute')
model.fit(train_X, train_y)
print("Done ")
```

The default number of nearest neighbors is k=5 but I used k=7 which gave more representative results.
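Under the hood, the classifier finds the k nearest training items by Euclidean distance and takes a majority vote. A from-scratch numpy sketch of the idea (not the scikit implementation) looks like:

```python
import numpy as np

def knn_predict(train_X, train_y, x, k):
  # Euclidean distance from x to every training item
  dists = np.sqrt(np.sum((train_X - x)**2, axis=1))
  nearest = np.argsort(dists)[:k]     # indices of the k closest items
  votes = train_y[nearest]            # their class labels
  return np.bincount(votes).argmax()  # most common label wins

# tiny illustration: two clusters, classes 0 and 1
train_X = np.array([[0.1, 0.1], [0.2, 0.1], [0.9, 0.8], [0.8, 0.9]])
train_y = np.array([0, 0, 1, 1])
print(knn_predict(train_X, train_y, np.array([0.15, 0.12]), k=3))  # predicts 0
```

With k = 3 here, the two nearby class-0 items outvote the single class-1 item among the three nearest neighbors.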

Well, I wasn’t able to revive my dead laptop, but I had fun with k-NN classification.

Three interesting photos of unusual-looking seed pods. Left: Brachychiton rupestris tree (Australia). Center: Cojoba arborea tree (Mexico). Right: Nelumbo nucifera (water lotus) plant (Asia).

Demo code:

```# wheat_knn.py

# predict wheat seed species (0=Kama, 1=Rosa, 2=Canadian)
# from area, perimeter, compactness, length, width,
#   asymmetry, groove

# Anaconda3-2020.02  Python 3.7.6  scikit 0.22.1
# Windows 10/11

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# ---------------------------------------------------------

def show_confusion(cm):
dim = len(cm)
mx = np.max(cm)             # largest count in cm
wid = len(str(mx)) + 1      # width to print
fmt = "%" + str(wid) + "d"  # like "%3d"
for i in range(dim):
print("actual   ", end="")
print("%3d:" % i, end="")
for j in range(dim):
print(fmt % cm[i][j], end="")
print("")
print("------------")
print("predicted    ", end="")
for j in range(dim):
print(fmt % j, end="")
print("")

# ---------------------------------------------------------

def main():
# 0. prepare
print("\nBegin Wheat Seeds k-NN using scikit ")
np.set_printoptions(precision=4, suppress=True)
np.random.seed(1)

# 0.4410  0.5021  0.5708  0.4865  0.4861  0.1893  0.3452  0
# 0.4051  0.4463  0.6624  0.3688  0.5011  0.0329  0.2152  0
# . . .
# 0.1917  0.2603  0.3630  0.2877  0.2003  0.3304  0.3506  2
# 0.2049  0.2004  0.8013  0.0980  0.3742  0.2682  0.1531  2

# 1. load data
train_file = ".\\Data\\wheat_train.txt"  # 180 items
train_xy = np.loadtxt(train_file, usecols=range(0,8),
delimiter="\t", comments="#", dtype=np.float32)
train_X = train_xy[:,0:7]
train_y = train_xy[:,7].astype(np.int64)

test_file = ".\\Data\\wheat_test.txt"  # 30 items
test_xy = np.loadtxt(test_file, usecols=range(0,8),
delimiter="\t", comments="#", dtype=np.float32)
test_X = test_xy[:,0:7]
test_y = test_xy[:,7].astype(np.int64)

print("\nTraining data:")
print(train_X[0:4])
print(". . . \n")
print(train_y[0:4])
print(". . . ")

# 2. create and train model
# KNeighborsClassifier(n_neighbors=5, *, weights='uniform',
#   algorithm='auto', leaf_size=30, p=2, metric='minkowski',
#   metric_params=None, n_jobs=None)
# algorithm: 'ball_tree', 'kd_tree', 'brute', 'auto'.

k = 7
print("\nCreating kNN model, with k=" + str(k) )
model = KNeighborsClassifier(n_neighbors=k, algorithm='brute')
model.fit(train_X, train_y)
print("Done ")

# 3. evaluate model
train_acc = model.score(train_X, train_y)
test_acc= model.score(test_X, test_y)
print("\nAccuracy on train data = %0.4f " % train_acc)
print("Accuracy on test data = %0.4f " % test_acc)

from sklearn.metrics import confusion_matrix
y_predicteds = model.predict(test_X)
cm = confusion_matrix(test_y, y_predicteds)
print("\nConfusion matrix: \n")
# print(cm)
show_confusion(cm)  # custom formatted

# 4. use model
X = np.array([[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]],
dtype=np.float32)
print("\nPredicting class for: ")
print(X)
probs = model.predict_proba(X)
print("\nPrediction probs: ")
print(probs)

predicted = model.predict(X)
print("\nPredicted class: ")
print(predicted)

# 5. TODO: save model using pickle
import pickle
print("\nSaving trained kNN model ")
# path = ".\\Models\\wheat_knn_model.sav"
# pickle.dump(model, open(path, "wb"))

# usage:
# X = np.array([[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]],
#   dtype=np.float32)
# with open(path, 'rb') as f:
#   loaded_model = pickle.load(f)
# pa = loaded_model.predict_proba(X)
# print(pa)

print("\nEnd demo ")

if __name__ == "__main__":
main()
```

Training data. Replace commas with tabs or modify program code.

```# wheat_train.txt
#
# http://archive.ics.uci.edu/ml/datasets/seeds
# 210 total items. train is first 60 each of 3 classes
# 180 training, 30 test
# area, perimeter, compactness, length, width, asymmetry, groove
# predictors are all min-max normalized
# 0 = Kama, 1 = Rosa, 2 = Canadian
#
0.4410,0.5021,0.5708,0.4865,0.4861,0.1893,0.3452,0
0.4051,0.4463,0.6624,0.3688,0.5011,0.0329,0.2152,0
0.3494,0.3471,0.8793,0.2207,0.5039,0.2515,0.1507,0
0.3069,0.3161,0.7931,0.2393,0.5339,0.1942,0.1408,0
0.5241,0.5331,0.8648,0.4274,0.6643,0.0767,0.3230,0
0.3579,0.3719,0.7895,0.2742,0.4861,0.2206,0.2152,0
0.3872,0.4298,0.6515,0.3739,0.4483,0.3668,0.3447,0
0.3324,0.3492,0.7532,0.2934,0.4790,0.2516,0.2368,0
0.5703,0.6302,0.6044,0.6498,0.5952,0.1658,0.6686,0
0.5524,0.5868,0.7250,0.5546,0.6237,0.1565,0.4993,0
0.4410,0.5041,0.5581,0.4589,0.4362,0.4912,0.3914,0
0.3248,0.3616,0.6488,0.3035,0.4070,0.1238,0.2373,0
0.3116,0.3326,0.7250,0.3041,0.4056,0.4188,0.1078,0
0.3012,0.3409,0.6152,0.3266,0.3749,0.3083,0.1738,0
0.2975,0.3388,0.6016,0.3283,0.3450,0.2817,0.1507,0
0.3777,0.3864,0.8276,0.2545,0.5011,0.4447,0.1290,0
0.3211,0.2934,1.0000,0.1239,0.5367,0.5811,0.1290,0
0.4816,0.4835,0.8866,0.3536,0.6301,0.1084,0.2595,0
0.3881,0.3719,0.9728,0.1723,0.5959,0.1303,0.0640,0
0.2011,0.2397,0.5490,0.1841,0.2986,0.4339,0.1945,0
0.3371,0.4112,0.4564,0.4274,0.3557,0.3000,0.3235,0
0.3324,0.3822,0.5817,0.3497,0.3835,0.2500,0.3447,0
0.4995,0.5145,0.8230,0.4048,0.6251,0.0000,0.2816,0
0.1407,0.1694,0.5290,0.1126,0.2181,0.0845,0.2176,0
0.4174,0.4855,0.5227,0.5011,0.4383,0.1334,0.2373,0
0.5288,0.5682,0.6969,0.5259,0.5638,0.0179,0.3880,0
0.2295,0.2789,0.5082,0.2793,0.2823,0.3391,0.1507,0
0.2030,0.2603,0.4383,0.2793,0.2324,0.2261,0.1723,0
0.3324,0.3657,0.6706,0.3615,0.4212,0.2586,0.2555,0
0.2701,0.3326,0.4746,0.3474,0.3100,0.3596,0.2846,0
0.2427,0.2913,0.5272,0.3125,0.2459,0.0117,0.2644,0
0.4627,0.5227,0.5835,0.4831,0.5282,0.3442,0.3491,0
0.3305,0.4132,0.4065,0.4606,0.3963,0.4102,0.3840,0
0.3163,0.3636,0.5871,0.3863,0.3706,0.1767,0.2427,0
0.4212,0.4690,0.6334,0.4578,0.4975,0.1773,0.4141,0
0.5222,0.5351,0.8339,0.4561,0.6094,0.1957,0.4549,0
0.5297,0.5909,0.5926,0.5220,0.5944,0.2676,0.4963,0
0.6128,0.6136,0.9056,0.5253,0.7505,0.2849,0.4751,0
0.3975,0.4360,0.6733,0.4262,0.4690,0.3052,0.3890,0
0.3484,0.3636,0.7831,0.2804,0.4761,0.7697,0.2373,0
0.2786,0.2975,0.7169,0.2528,0.3749,0.2369,0.3245,0
0.2748,0.2975,0.6996,0.2545,0.3763,0.1929,0.3235,0
0.2427,0.2355,0.8421,0.1346,0.4070,0.2205,0.1300,0
0.4636,0.5062,0.6706,0.5507,0.5460,0.5131,0.4968,0
0.4268,0.4401,0.8212,0.3829,0.5930,0.3072,0.3255,0
0.3031,0.3368,0.6470,0.2686,0.3742,0.1034,0.2176,0
0.4504,0.4855,0.7078,0.4516,0.5438,0.0783,0.3018,0
0.4155,0.4442,0.7278,0.3778,0.5324,0.2851,0.3230,0
0.3966,0.4360,0.6697,0.3637,0.4711,0.2521,0.2915,0
0.4032,0.4669,0.5399,0.4386,0.4476,0.1773,0.4097,0
0.3626,0.4112,0.6080,0.3863,0.4576,0.4174,0.3077,0
0.4901,0.5165,0.7641,0.4364,0.5731,0.6277,0.3038,0
0.3683,0.4545,0.4147,0.4595,0.3443,0.4357,0.4318,0
0.3532,0.3864,0.6806,0.3407,0.4056,0.3332,0.3471,0
0.3711,0.4525,0.4319,0.4741,0.3443,0.0931,0.4766,0
0.4193,0.4876,0.5236,0.4521,0.4148,0.1519,0.4530,0
0.3654,0.4008,0.6688,0.2753,0.5324,0.2648,0.2585,0
0.4089,0.4174,0.8394,0.2731,0.5574,0.0490,0.2802,0
0.4523,0.4876,0.7042,0.4296,0.5624,0.1604,0.3461,0
0.1435,0.2190,0.2822,0.1464,0.2865,0.0958,0.0000,0
0.6648,0.7376,0.5372,0.7275,0.6636,0.4305,0.7587,1
0.5902,0.6736,0.4918,0.6188,0.6087,0.5084,0.6686,1
0.6298,0.6860,0.6189,0.6075,0.6871,0.4907,0.6263,1
0.8045,0.7955,0.9074,0.7066,0.9266,0.2823,0.7681,1
0.5883,0.6405,0.6397,0.6295,0.6101,0.4211,0.6509,1
0.5836,0.6632,0.5054,0.5788,0.5759,0.5402,0.6283,1
0.6355,0.7231,0.4701,0.6560,0.5510,0.3977,0.6908,1
0.9556,0.9959,0.6189,0.9459,0.8439,0.4793,0.9513,1
0.7885,0.8430,0.6071,0.8705,0.7192,0.5590,0.9074,1
0.6166,0.6488,0.7359,0.5355,0.6671,0.2721,0.6041,1
0.5609,0.6054,0.6733,0.5495,0.5966,0.6198,0.6701,1
0.7677,0.7810,0.8131,0.6233,0.8746,0.5928,0.6696,1
0.9075,0.9256,0.7377,0.7804,0.8795,0.5731,0.8213,1
0.8480,0.8946,0.6334,0.8361,0.8140,0.0919,0.8636,1
0.8423,0.8884,0.6343,0.8260,0.8346,0.2856,0.8203,1
0.7252,0.7603,0.7160,0.7173,0.7277,0.2182,0.8262,1
0.7828,0.7955,0.8058,0.6672,0.8083,0.1149,0.7829,1
0.7923,0.8781,0.4619,0.9291,0.7413,0.3804,0.9744,1
1.0000,0.9917,0.8240,0.9426,1.0000,0.6521,0.8429,1
0.9717,0.9587,0.8621,0.8733,0.9993,0.5527,0.8872,1
0.8980,0.9463,0.6034,0.9471,0.8232,0.1547,0.9503,1
0.7715,0.7831,0.8194,0.7168,0.8311,0.3062,0.7553,1
0.7762,0.8017,0.7486,0.7731,0.7577,0.3214,0.7553,1
0.7554,0.7521,0.8938,0.6408,0.8767,0.6808,0.6686,1
0.7337,0.8492,0.3367,0.9949,0.6094,0.5419,0.9498,1
0.5930,0.6694,0.5145,0.6982,0.5937,0.3811,0.7129,1
0.8234,0.8636,0.6661,0.8119,0.8411,0.3526,0.8464,1
0.7923,0.8595,0.5499,0.8727,0.6572,0.1793,0.9522,1
0.7158,0.7955,0.5045,0.7725,0.6287,0.2715,0.8636,1
0.7677,0.8120,0.6615,0.7432,0.7512,0.1850,0.7770,1
0.5496,0.5868,0.7123,0.4611,0.6379,0.4488,0.5411,1
0.6988,0.7128,0.8267,0.5580,0.7584,0.1694,0.6489,1
0.8376,0.8450,0.8203,0.6836,0.8995,0.4607,0.7336,1
0.8111,0.8719,0.5771,0.8277,0.7491,0.3370,0.8419,1
0.7894,0.8285,0.6788,0.7596,0.8019,0.3384,0.8021,1
0.7781,0.8017,0.7586,0.6408,0.8239,0.2325,0.6696,1
0.7800,0.7769,0.8848,0.7055,0.8382,0.2702,0.8277,1
0.6648,0.7128,0.6525,0.6385,0.6721,0.3877,0.6942,1
0.8829,0.9318,0.6089,1.0000,0.8076,0.3234,1.0000,1
0.7517,0.7872,0.7114,0.7061,0.7441,0.1265,0.6770,1
0.7422,0.7665,0.7623,0.6802,0.8118,0.1911,0.6278,1
0.8300,0.8905,0.5762,0.7905,0.8275,0.3787,0.7120,1
0.8064,0.8058,0.8657,0.7230,0.9066,0.1747,0.6918,1
0.8074,0.8678,0.5817,0.7658,0.7890,0.7693,0.7553,1
0.9802,1.0000,0.7060,0.9369,0.9701,0.5086,0.8848,1
0.7998,0.8347,0.7015,0.8542,0.7762,0.1928,0.8095,1
0.7904,0.7831,0.9038,0.6486,0.9031,0.4640,0.6061,1
0.8083,0.8347,0.7341,0.7579,0.8446,0.3015,0.8203,1
0.7838,0.7893,0.8412,0.7477,0.8118,0.3737,0.7125,1
0.8914,0.9277,0.6624,0.8975,0.8746,0.2988,0.8868,1
0.9112,0.9298,0.7405,0.7973,0.9494,0.6678,0.8218,1
0.7129,0.7665,0.6270,0.6532,0.6650,0.3711,0.7346,1
0.5269,0.6136,0.4601,0.4859,0.5396,0.4578,0.5830,1
0.7403,0.7355,0.9038,0.6087,0.8133,0.2885,0.6824,1
0.5099,0.5124,0.8920,0.2613,0.6785,0.3343,0.3077,1
0.7705,0.7789,0.8330,0.6824,0.8831,0.4451,0.7253,1
0.7611,0.8264,0.5599,0.7804,0.6871,0.4715,0.7794,1
0.6978,0.7107,0.8276,0.6081,0.7534,0.1940,0.6893,1
0.9037,0.9545,0.5935,0.9088,0.8147,0.1489,0.8203,1
0.6572,0.6715,0.8258,0.5023,0.7555,0.5982,0.5623,1
0.2342,0.3120,0.3621,0.3226,0.2594,0.5902,0.4313,2
0.2578,0.3161,0.4828,0.3615,0.3158,0.8152,0.4535,2
0.2597,0.3182,0.4891,0.2759,0.3165,0.6800,0.3880,2
0.1539,0.1880,0.5181,0.1830,0.2402,0.6116,0.3456,2
0.1161,0.2045,0.1751,0.2337,0.1048,0.4819,0.3245,2
0.0585,0.1488,0.0780,0.2140,0.0406,0.7026,0.3722,2
0.0793,0.1488,0.2305,0.1560,0.0634,0.1893,0.3018,2
0.1794,0.2169,0.5236,0.2072,0.2402,0.4754,0.2378,2
0.1992,0.2686,0.3721,0.2742,0.2003,0.3244,0.3924,2
0.0189,0.1074,0.0236,0.2354,0.0128,0.6107,0.3323,2
0.1171,0.1694,0.3766,0.2050,0.1497,0.5760,0.3880,2
0.1341,0.2293,0.1525,0.2849,0.1041,0.8096,0.3698,2
0.1577,0.2459,0.2287,0.2866,0.1447,0.5189,0.4141,2
0.0557,0.1302,0.1679,0.1807,0.0449,0.3338,0.2373,2
0.0727,0.1322,0.2731,0.1554,0.0891,0.4269,0.3663,2
0.0567,0.1322,0.1561,0.1976,0.0321,0.6563,0.3447,2
0.0708,0.0950,0.4673,0.0867,0.1561,0.3357,0.2383,2
0.1454,0.2727,0.0000,0.2787,0.0820,0.5279,0.3452,2
0.1095,0.2293,0.0009,0.3069,0.0342,0.4698,0.3895,2
0.0850,0.1674,0.1652,0.2280,0.0463,0.6011,0.3895,2
0.1841,0.2603,0.3122,0.3108,0.1775,0.3013,0.4786,2
0.1350,0.1901,0.3829,0.2539,0.1283,0.4559,0.3885,2
0.1379,0.2066,0.3040,0.2072,0.1547,0.5491,0.2595,2
0.1851,0.2397,0.4328,0.2444,0.2409,0.4751,0.3235,2
0.0519,0.0785,0.4328,0.0631,0.1169,0.7311,0.2610,2
0.1426,0.1529,0.6461,0.1160,0.2217,0.1867,0.2644,2
0.1747,0.2438,0.3457,0.2365,0.1903,0.5408,0.3698,2
0.1473,0.2149,0.3285,0.2917,0.1475,0.3735,0.4032,2
0.0718,0.1467,0.1906,0.1560,0.0271,0.4644,0.3018,2
0.0614,0.1219,0.2523,0.1075,0.0606,0.3583,0.2802,2
0.0406,0.1219,0.0980,0.2399,0.0506,0.7762,0.3171,2
0.0907,0.1426,0.3394,0.1509,0.1532,0.7736,0.2152,2
0.0642,0.1157,0.3067,0.1064,0.0948,0.4608,0.2368,2
0.0765,0.1384,0.2668,0.1334,0.0948,0.6271,0.2806,2
0.0227,0.1136,0.0163,0.2134,0.0078,0.5743,0.3279,2
0.0198,0.0331,0.4619,0.0462,0.1361,0.5211,0.2678,2
0.0633,0.1240,0.2486,0.1616,0.0570,0.5942,0.2821,2
0.0142,0.0661,0.2250,0.1385,0.0086,0.5119,0.2186,2
0.0840,0.1322,0.3557,0.1582,0.0912,0.6645,0.2378,2
0.1530,0.2190,0.3376,0.2579,0.1875,0.1165,0.3245,2
0.0774,0.1116,0.4347,0.1075,0.1033,0.5450,0.1507,2
0.1766,0.2066,0.5672,0.1898,0.2758,0.5489,0.3092,2
0.1511,0.1963,0.4519,0.1920,0.1989,0.5320,0.3146,2
0.1001,0.1364,0.4483,0.1177,0.1568,0.5778,0.3033,2
0.2172,0.2810,0.4174,0.3356,0.2823,0.7047,0.3924,2
0.0916,0.1860,0.1062,0.2613,0.0378,0.4287,0.3264,2
0.1152,0.2149,0.1062,0.2894,0.0613,0.5374,0.4101,2
0.0302,0.0806,0.2641,0.1064,0.0321,0.4439,0.2152,2
0.0604,0.0847,0.4655,0.1070,0.1361,0.8788,0.2157,2
0.0000,0.0000,0.5145,0.0000,0.1119,0.5474,0.1354,2
0.0321,0.0806,0.2804,0.0828,0.0620,0.6024,0.2590,2
0.0642,0.0930,0.4374,0.1081,0.1240,0.4187,0.2373,2
0.1209,0.1260,0.6479,0.1312,0.2302,0.3682,0.3018,2
0.0217,0.0868,0.1588,0.1582,0.0000,0.5315,0.2806,2
0.1435,0.1777,0.5064,0.1898,0.2459,0.4378,0.2427,2
0.2087,0.2190,0.7069,0.1470,0.3535,0.5341,0.1945,2
0.2077,0.2314,0.6397,0.1830,0.3022,0.6134,0.2161,2
0.2625,0.2831,0.6969,0.2370,0.3550,0.5077,0.2816,2
0.1917,0.2603,0.3630,0.2877,0.2003,0.3304,0.3506,2
0.2049,0.2004,0.8013,0.0980,0.3742,0.2682,0.1531,2
```

Test data:

```# wheat_test.txt
#
0.0784,0.0930,0.5463,0.0614,0.1568,0.2516,0.0433,0
0.0604,0.0455,0.6887,0.0017,0.1775,0.1955,0.0906,0
0.1671,0.1612,0.7641,0.0997,0.2937,0.3192,0.0423,0
0.2483,0.2955,0.5436,0.2793,0.3136,0.4410,0.2802,0
0.2068,0.2397,0.5762,0.2044,0.2823,0.0534,0.1295,0
0.2162,0.2252,0.7241,0.1351,0.3485,0.2063,0.0433,0
0.3541,0.4050,0.5853,0.4116,0.3991,0.0712,0.3107,0
0.3229,0.3884,0.4936,0.3998,0.3763,0.1888,0.3018,0
0.3569,0.4091,0.5853,0.3773,0.3728,0.0909,0.3845,0
0.2021,0.2769,0.3421,0.2889,0.1796,0.3599,0.2698,0
0.7280,0.7190,0.9319,0.6081,0.8019,0.2694,0.7105,1
0.7885,0.8079,0.7813,0.7010,0.8517,0.2786,0.7041,1
0.4523,0.5145,0.5672,0.5546,0.4547,0.4807,0.6283,1
0.5260,0.6033,0.5109,0.5327,0.5453,0.4552,0.6283,1
0.4693,0.5124,0.6733,0.4938,0.5545,0.5470,0.6539,1
0.4523,0.4649,0.8249,0.3255,0.5952,0.3686,0.4530,1
0.6393,0.6921,0.6388,0.7016,0.6728,0.3590,0.7149,1
0.4703,0.5661,0.4047,0.5749,0.4284,0.2438,0.6696,1
0.4731,0.5579,0.4528,0.5253,0.4676,0.2548,0.6071,1
0.5326,0.5723,0.6978,0.5479,0.6001,0.3906,0.6908,1
0.1690,0.2128,0.4791,0.1802,0.2559,0.6120,0.2590,2
0.1964,0.1880,0.8131,0.0479,0.3599,0.1996,0.1113,2
0.0557,0.0640,0.5436,0.0619,0.1283,0.4272,0.1521,2
0.1992,0.2066,0.7196,0.1599,0.3286,1.0000,0.2368,2
0.1681,0.2190,0.4410,0.1717,0.2352,0.4101,0.2373,2
0.1511,0.1632,0.6370,0.1340,0.2502,0.3726,0.1728,2
0.0604,0.0971,0.3902,0.1357,0.1176,0.4629,0.2383,2
0.2465,0.2583,0.7278,0.1898,0.4291,0.9817,0.2644,2
0.1180,0.1653,0.3993,0.1554,0.1468,0.3683,0.2585,2
0.1615,0.1921,0.5472,0.1937,0.2452,0.6335,0.2678,2
```

## Regression (People Income) Using a scikit MLPRegressor Neural Network

The scikit-learn library was originally designed for classical machine learning techniques like logistic regression and naive Bayes classification. The library eventually added the ability to do binary and multi-class classification via the MLPClassifier (multi-layer perceptron) class and regression via the MLPRegressor class. As best as I can determine by wading through the scikit change logs, these two classes were added in version 0.18 in early 2017.

I decided to take a look at regression using the scikit MLPRegressor class.

In my work environment, when I need to tackle a regression problem (i.e., predict a single numeric value such as a person's annual income), I use PyTorch. PyTorch is complex, but it gives me the flexibility I need and can do much more sophisticated things than scikit, notably image classification, natural language processing, unsupervised anomaly detection, and Transformer architecture systems.

But scikit is easy to use and makes sense in some scenarios.

My data is synthetic and looks like:

``` 1   0.24   1 0 0   0.2950   0 0 1
-1   0.39   0 0 1   0.5120   0 1 0
1   0.63   0 1 0   0.7580   1 0 0
-1   0.36   1 0 0   0.4450   0 1 0
1   0.27   0 1 0   0.2860   0 0 1
. . .
```

There are 200 training items and 40 test items.

The first value in column [0] is sex (M = -1, F = +1). Column [1] is age, normalized by dividing by 100. Columns [2,3,4] is State one-hot encoded (Michigan = 100, Nebraska = 010, Oklahoma = 001). Column [5] is annual income, divided by \$100,000, and is the value to predict. Columns [6,7,8] is political leaning (conservative = 100, moderate = 010, liberal = 001).
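A new data item must be encoded using the same scheme before being fed to the model. A small hypothetical helper (not part of the demo code) might look like:

```python
import numpy as np

def encode_person(sex, age, state, politics):
  # sex: 'M' = -1, 'F' = +1; age is divided by 100
  sex_code = -1.0 if sex == 'M' else 1.0
  states = {'michigan': [1,0,0], 'nebraska': [0,1,0], 'oklahoma': [0,0,1]}
  pols = {'conservative': [1,0,0], 'moderate': [0,1,0], 'liberal': [0,0,1]}
  return np.array([sex_code, age / 100.0] + states[state] + pols[politics],
    dtype=np.float32)

print(encode_person('M', 34, 'oklahoma', 'moderate'))
# eight predictor values: sex, age, state one-hot, politics one-hot
```

This produces the same (-1, 0.34, 0,0,1, 0,1,0) vector used for the prediction in the demo program.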

Setting up a scikit MLP regressor is daunting because there are a lot of parameters:

```  params = { 'hidden_layer_sizes' : [10,10],
'activation' : 'relu',
'alpha' : 0.0,
'batch_size' : 10,
'random_state' : 0,
'tol' : 0.0001,
'nesterovs_momentum' : False,
'learning_rate' : 'constant',
'learning_rate_init' : 0.01,
'max_iter' : 1000,
'shuffle' : True,
'n_iter_no_change' : 50,
'verbose' : False }

print("Creating 8-(10-10)-1 relu neural network ")
net = MLPRegressor(**params)
```

My demo implements an accuracy() function. Most scikit classifier classes have a score() function that gives simple accuracy, but for a regressor, score() returns the coefficient of determination, so you must define what counts as a correct prediction, for example a prediction within 10% of the correct target value.

Good fun.

Predicting income is a difficult task with real data. For jobs that rely mostly on tips, such as golf course beverage cart driver, predicting income is especially difficult. I suspect the cart driver on the left makes more money from tips than the cart driver on the right.

Demo code:

```# people_income_nn_sckit.py

# predict income
# from sex, age, state, politics

# sex  age   state   income  politics
#  1  0.24  1  0  0  0.2950  0  0  1
# -1  0.39  0  0  1  0.5120  0  1  0
# state: michigan = 100, nebraska = 010, oklahoma = 001
# conservative = 100, moderate = 010, liberal = 001

# Anaconda3-2020.02  Python 3.7.6  scikit 0.22.1
# Windows 10/11

import numpy as np
from sklearn.neural_network import MLPRegressor
import warnings
warnings.filterwarnings('ignore')  # early-stop warnings

# ---------------------------------------------------------

def accuracy(model, data_x, data_y, pct_close=0.10):
# accuracy predicted within pct_close of actual income
# item-by-item allows inspection but is slow
n_correct = 0; n_wrong = 0
predicteds = model.predict(data_x)  # all predicteds
for i in range(len(predicteds)):
actual = data_y[i]
pred = predicteds[i]

if np.abs(pred - actual) < np.abs(pct_close * actual):
n_correct += 1
else:
n_wrong += 1
acc = (n_correct * 1.0) / (n_correct + n_wrong)
return acc

# ---------------------------------------------------------

def accuracy_q(model, data_x, data_y, pct_close=0.10):
# accuracy within pct_close of actual income
# all-at-once is quick
n_items = len(data_y)
preds = model.predict(data_x)  # all predicteds

n_correct = np.sum((np.abs(preds - data_y) < \
np.abs(pct_close * data_y)))
result = (n_correct / n_items)
return result

# ---------------------------------------------------------

def main():
print("\nBegin scikit neural network regression example ")
print("Predict income from sex, age, State, politics ")
np.random.seed(1)
np.set_printoptions(precision=4, suppress=True)

train_file = ".\\Data\\people_train.txt"
train_xy = np.loadtxt(train_file, usecols=range(0,9),
delimiter="\t", comments="#", dtype=np.float32)
train_x = train_xy[:,[0,1,2,3,4,6,7,8]]
train_y = train_xy[:,5]

test_file = ".\\Data\\people_test.txt"
test_xy = np.loadtxt(test_file, usecols=range(0,9),
delimiter="\t", comments="#", dtype=np.float32)
test_x = test_xy[:,[0,1,2,3,4,6,7,8]]
test_y = test_xy[:,5]

print("\nTraining data:")
print(train_x[0:4])
print(". . . \n")
print(train_y[0:4])
print(". . . ")

# ---------------------------------------------------------

# 2. create network
# MLPRegressor(hidden_layer_sizes=(100,),
#  batch_size='auto', learning_rate='constant',
#  learning_rate_init=0.001, power_t=0.5, max_iter=200,
#  shuffle=True, random_state=None, tol=0.0001,
#  verbose=False, warm_start=False, momentum=0.9,
#  nesterovs_momentum=True, early_stopping=False,
#  validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
#  epsilon=1e-08, n_iter_no_change=10, max_fun=15000)

params = { 'hidden_layer_sizes' : [10,10],
'activation' : 'relu',
'alpha' : 0.0,
'batch_size' : 10,
'random_state' : 0,
'tol' : 0.0001,
'nesterovs_momentum' : False,
'learning_rate' : 'constant',
'learning_rate_init' : 0.01,
'max_iter' : 1000,
'shuffle' : True,
'n_iter_no_change' : 50,
'verbose' : False }

print("\nCreating 8-(10-10)-1 relu neural network ")
net = MLPRegressor(**params)

# ---------------------------------------------------------

# 3. train
print("\nTraining with bat sz = " + \
str(params['batch_size']) + " lrn rate = " + \
str(params['learning_rate_init']) + " ")
print("Stop if no change " + \
str(params['n_iter_no_change']) + " iterations ")
net.fit(train_x, train_y)
print("Done ")

# ---------------------------------------------------------

# 4. evaluate model
# score() is coefficient of determination for MLPRegressor
print("\nCompute model accuracy (within 0.10 of actual) ")
acc_train = accuracy(net, train_x, train_y, 0.10)
print("\nAccuracy on train = %0.4f " % acc_train)
acc_test = accuracy(net, test_x, test_y, 0.10)
print("Accuracy on test = %0.4f " % acc_test)

# print("\nModel accuracy quick (within 0.10 of actual) ")
# acc_train = accuracy_q(net, train_x, train_y, 0.10)
# print("\nAccuracy on train = %0.4f " % acc_train)
# acc_test = accuracy_q(net, test_x, test_y, 0.10)
# print("Accuracy on test = %0.4f " % acc_test)

# ---------------------------------------------------------

# 5. use model
# no proba() for MLPRegressor
print("\nSetting X = M 34 Oklahoma moderate: ")
X = np.array([[-1, 0.34, 0,0,1,  0,1,0]])
income = net.predict(X)  # divided by 100,000
income *= 100000  # denormalize
print("Predicted income: %0.2f " % income)

# ---------------------------------------------------------

# 6. TODO: save model using pickle
# import pickle
# print("Saving trained network ")
# path = ".\\Models\\people_income_net.sav"
# pickle.dump(model, open(path, "wb"))

# use saved model
# X = np.array([[-1, 0.34, 0,0,1,  0,1,0]],
#   dtype=np.float32)
# with open(path, 'rb') as f:
#   loaded_net = pickle.load(f)
# inc = loaded_net.predict(X)
# print(inc)

print("\nEnd scikit binary neural network demo ")

if __name__ == "__main__":
main()
```

Training data. The demo program expects tab-delimited data, so replace the commas with tabs or modify the program.
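The comma-to-tab conversion takes only a couple of lines of Python. A minimal sketch, where the file names are placeholders for your local paths:

```python
# Hypothetical file names -- adjust to your local paths.
# Replaces every comma with a tab so the demo program can
# read the data as tab-delimited.
def commas_to_tabs(in_path, out_path):
    with open(in_path) as fin, open(out_path, "w") as fout:
        for line in fin:
            fout.write(line.replace(",", "\t"))

# commas_to_tabs("people_train_raw.txt", "people_train.txt")
```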

```# people_train.txt
#
# sex (-1 = male, 1 = female), age / 100,
# state (michigan = 100, nebraska = 010, oklahoma = 001)
# income / 100_000,
# conservative = 100, moderate = 010, liberal = 001
#
1,0.24,1,0,0,0.2950,0,0,1
-1,0.39,0,0,1,0.5120,0,1,0
1,0.63,0,1,0,0.7580,1,0,0
-1,0.36,1,0,0,0.4450,0,1,0
1,0.27,0,1,0,0.2860,0,0,1
1,0.50,0,1,0,0.5650,0,1,0
1,0.50,0,0,1,0.5500,0,1,0
-1,0.19,0,0,1,0.3270,1,0,0
1,0.22,0,1,0,0.2770,0,1,0
-1,0.39,0,0,1,0.4710,0,0,1
1,0.34,1,0,0,0.3940,0,1,0
-1,0.22,1,0,0,0.3350,1,0,0
1,0.35,0,0,1,0.3520,0,0,1
-1,0.33,0,1,0,0.4640,0,1,0
1,0.45,0,1,0,0.5410,0,1,0
1,0.42,0,1,0,0.5070,0,1,0
-1,0.33,0,1,0,0.4680,0,1,0
1,0.25,0,0,1,0.3000,0,1,0
-1,0.31,0,1,0,0.4640,1,0,0
1,0.27,1,0,0,0.3250,0,0,1
1,0.48,1,0,0,0.5400,0,1,0
-1,0.64,0,1,0,0.7130,0,0,1
1,0.61,0,1,0,0.7240,1,0,0
1,0.54,0,0,1,0.6100,1,0,0
1,0.29,1,0,0,0.3630,1,0,0
1,0.50,0,0,1,0.5500,0,1,0
1,0.55,0,0,1,0.6250,1,0,0
1,0.40,1,0,0,0.5240,1,0,0
1,0.22,1,0,0,0.2360,0,0,1
1,0.68,0,1,0,0.7840,1,0,0
-1,0.60,1,0,0,0.7170,0,0,1
-1,0.34,0,0,1,0.4650,0,1,0
-1,0.25,0,0,1,0.3710,1,0,0
-1,0.31,0,1,0,0.4890,0,1,0
1,0.43,0,0,1,0.4800,0,1,0
1,0.58,0,1,0,0.6540,0,0,1
-1,0.55,0,1,0,0.6070,0,0,1
-1,0.43,0,1,0,0.5110,0,1,0
-1,0.43,0,0,1,0.5320,0,1,0
-1,0.21,1,0,0,0.3720,1,0,0
1,0.55,0,0,1,0.6460,1,0,0
1,0.64,0,1,0,0.7480,1,0,0
-1,0.41,1,0,0,0.5880,0,1,0
1,0.64,0,0,1,0.7270,1,0,0
-1,0.56,0,0,1,0.6660,0,0,1
1,0.31,0,0,1,0.3600,0,1,0
-1,0.65,0,0,1,0.7010,0,0,1
1,0.55,0,0,1,0.6430,1,0,0
-1,0.25,1,0,0,0.4030,1,0,0
1,0.46,0,0,1,0.5100,0,1,0
-1,0.36,1,0,0,0.5350,1,0,0
1,0.52,0,1,0,0.5810,0,1,0
1,0.61,0,0,1,0.6790,1,0,0
1,0.57,0,0,1,0.6570,1,0,0
-1,0.46,0,1,0,0.5260,0,1,0
-1,0.62,1,0,0,0.6680,0,0,1
1,0.55,0,0,1,0.6270,1,0,0
-1,0.22,0,0,1,0.2770,0,1,0
-1,0.50,1,0,0,0.6290,1,0,0
-1,0.32,0,1,0,0.4180,0,1,0
-1,0.21,0,0,1,0.3560,1,0,0
1,0.44,0,1,0,0.5200,0,1,0
1,0.46,0,1,0,0.5170,0,1,0
1,0.62,0,1,0,0.6970,1,0,0
1,0.57,0,1,0,0.6640,1,0,0
-1,0.67,0,0,1,0.7580,0,0,1
1,0.29,1,0,0,0.3430,0,0,1
1,0.53,1,0,0,0.6010,1,0,0
-1,0.44,1,0,0,0.5480,0,1,0
1,0.46,0,1,0,0.5230,0,1,0
-1,0.20,0,1,0,0.3010,0,1,0
-1,0.38,1,0,0,0.5350,0,1,0
1,0.50,0,1,0,0.5860,0,1,0
1,0.33,0,1,0,0.4250,0,1,0
-1,0.33,0,1,0,0.3930,0,1,0
1,0.26,0,1,0,0.4040,1,0,0
1,0.58,1,0,0,0.7070,1,0,0
1,0.43,0,0,1,0.4800,0,1,0
-1,0.46,1,0,0,0.6440,1,0,0
1,0.60,1,0,0,0.7170,1,0,0
-1,0.42,1,0,0,0.4890,0,1,0
-1,0.56,0,0,1,0.5640,0,0,1
-1,0.62,0,1,0,0.6630,0,0,1
-1,0.50,1,0,0,0.6480,0,1,0
1,0.47,0,0,1,0.5200,0,1,0
-1,0.67,0,1,0,0.8040,0,0,1
-1,0.40,0,0,1,0.5040,0,1,0
1,0.42,0,1,0,0.4840,0,1,0
1,0.64,1,0,0,0.7200,1,0,0
-1,0.47,1,0,0,0.5870,0,0,1
1,0.45,0,1,0,0.5280,0,1,0
-1,0.25,0,0,1,0.4090,1,0,0
1,0.38,1,0,0,0.4840,1,0,0
1,0.55,0,0,1,0.6000,0,1,0
-1,0.44,1,0,0,0.6060,0,1,0
1,0.33,1,0,0,0.4100,0,1,0
1,0.34,0,0,1,0.3900,0,1,0
1,0.27,0,1,0,0.3370,0,0,1
1,0.32,0,1,0,0.4070,0,1,0
1,0.42,0,0,1,0.4700,0,1,0
-1,0.24,0,0,1,0.4030,1,0,0
1,0.42,0,1,0,0.5030,0,1,0
1,0.25,0,0,1,0.2800,0,0,1
1,0.51,0,1,0,0.5800,0,1,0
-1,0.55,0,1,0,0.6350,0,0,1
1,0.44,1,0,0,0.4780,0,0,1
-1,0.18,1,0,0,0.3980,1,0,0
-1,0.67,0,1,0,0.7160,0,0,1
1,0.45,0,0,1,0.5000,0,1,0
1,0.48,1,0,0,0.5580,0,1,0
-1,0.25,0,1,0,0.3900,0,1,0
-1,0.67,1,0,0,0.7830,0,1,0
1,0.37,0,0,1,0.4200,0,1,0
-1,0.32,1,0,0,0.4270,0,1,0
1,0.48,1,0,0,0.5700,0,1,0
-1,0.66,0,0,1,0.7500,0,0,1
1,0.61,1,0,0,0.7000,1,0,0
-1,0.58,0,0,1,0.6890,0,1,0
1,0.19,1,0,0,0.2400,0,0,1
1,0.38,0,0,1,0.4300,0,1,0
-1,0.27,1,0,0,0.3640,0,1,0
1,0.42,1,0,0,0.4800,0,1,0
1,0.60,1,0,0,0.7130,1,0,0
-1,0.27,0,0,1,0.3480,1,0,0
1,0.29,0,1,0,0.3710,1,0,0
-1,0.43,1,0,0,0.5670,0,1,0
1,0.48,1,0,0,0.5670,0,1,0
1,0.27,0,0,1,0.2940,0,0,1
-1,0.44,1,0,0,0.5520,1,0,0
1,0.23,0,1,0,0.2630,0,0,1
-1,0.36,0,1,0,0.5300,0,0,1
1,0.64,0,0,1,0.7250,1,0,0
1,0.29,0,0,1,0.3000,0,0,1
-1,0.33,1,0,0,0.4930,0,1,0
-1,0.66,0,1,0,0.7500,0,0,1
-1,0.21,0,0,1,0.3430,1,0,0
1,0.27,1,0,0,0.3270,0,0,1
1,0.29,1,0,0,0.3180,0,0,1
-1,0.31,1,0,0,0.4860,0,1,0
1,0.36,0,0,1,0.4100,0,1,0
1,0.49,0,1,0,0.5570,0,1,0
-1,0.28,1,0,0,0.3840,1,0,0
-1,0.43,0,0,1,0.5660,0,1,0
-1,0.46,0,1,0,0.5880,0,1,0
1,0.57,1,0,0,0.6980,1,0,0
-1,0.52,0,0,1,0.5940,0,1,0
-1,0.31,0,0,1,0.4350,0,1,0
-1,0.55,1,0,0,0.6200,0,0,1
1,0.50,1,0,0,0.5640,0,1,0
1,0.48,0,1,0,0.5590,0,1,0
-1,0.22,0,0,1,0.3450,1,0,0
1,0.59,0,0,1,0.6670,1,0,0
1,0.34,1,0,0,0.4280,0,0,1
-1,0.64,1,0,0,0.7720,0,0,1
1,0.29,0,0,1,0.3350,0,0,1
-1,0.34,0,1,0,0.4320,0,1,0
-1,0.61,1,0,0,0.7500,0,0,1
1,0.64,0,0,1,0.7110,1,0,0
-1,0.29,1,0,0,0.4130,1,0,0
1,0.63,0,1,0,0.7060,1,0,0
-1,0.29,0,1,0,0.4000,1,0,0
-1,0.51,1,0,0,0.6270,0,1,0
-1,0.24,0,0,1,0.3770,1,0,0
1,0.48,0,1,0,0.5750,0,1,0
1,0.18,1,0,0,0.2740,1,0,0
1,0.18,1,0,0,0.2030,0,0,1
1,0.33,0,1,0,0.3820,0,0,1
-1,0.20,0,0,1,0.3480,1,0,0
1,0.29,0,0,1,0.3300,0,0,1
-1,0.44,0,0,1,0.6300,1,0,0
-1,0.65,0,0,1,0.8180,1,0,0
-1,0.56,1,0,0,0.6370,0,0,1
-1,0.52,0,0,1,0.5840,0,1,0
-1,0.29,0,1,0,0.4860,1,0,0
-1,0.47,0,1,0,0.5890,0,1,0
1,0.68,1,0,0,0.7260,0,0,1
1,0.31,0,0,1,0.3600,0,1,0
1,0.61,0,1,0,0.6250,0,0,1
1,0.19,0,1,0,0.2150,0,0,1
1,0.38,0,0,1,0.4300,0,1,0
-1,0.26,1,0,0,0.4230,1,0,0
1,0.61,0,1,0,0.6740,1,0,0
1,0.40,1,0,0,0.4650,0,1,0
-1,0.49,1,0,0,0.6520,0,1,0
1,0.56,1,0,0,0.6750,1,0,0
-1,0.48,0,1,0,0.6600,0,1,0
1,0.52,1,0,0,0.5630,0,0,1
-1,0.18,1,0,0,0.2980,1,0,0
-1,0.56,0,0,1,0.5930,0,0,1
-1,0.52,0,1,0,0.6440,0,1,0
-1,0.18,0,1,0,0.2860,0,1,0
-1,0.58,1,0,0,0.6620,0,0,1
-1,0.39,0,1,0,0.5510,0,1,0
-1,0.46,1,0,0,0.6290,0,1,0
-1,0.40,0,1,0,0.4620,0,1,0
-1,0.60,1,0,0,0.7270,0,0,1
1,0.36,0,1,0,0.4070,0,0,1
1,0.44,1,0,0,0.5230,0,1,0
1,0.28,1,0,0,0.3130,0,0,1
1,0.54,0,0,1,0.6260,1,0,0
```

Test data.

```# people_test.txt
#
-1,0.51,1,0,0,0.6120,0,1,0
-1,0.32,0,1,0,0.4610,0,1,0
1,0.55,1,0,0,0.6270,1,0,0
1,0.25,0,0,1,0.2620,0,0,1
1,0.33,0,0,1,0.3730,0,0,1
-1,0.29,0,1,0,0.4620,1,0,0
1,0.65,1,0,0,0.7270,1,0,0
-1,0.43,0,1,0,0.5140,0,1,0
-1,0.54,0,1,0,0.6480,0,0,1
1,0.61,0,1,0,0.7270,1,0,0
1,0.52,0,1,0,0.6360,1,0,0
1,0.3,0,1,0,0.3350,0,0,1
1,0.29,1,0,0,0.3140,0,0,1
-1,0.47,0,0,1,0.5940,0,1,0
1,0.39,0,1,0,0.4780,0,1,0
1,0.47,0,0,1,0.5200,0,1,0
-1,0.49,1,0,0,0.5860,0,1,0
-1,0.63,0,0,1,0.6740,0,0,1
-1,0.3,1,0,0,0.3920,1,0,0
-1,0.61,0,0,1,0.6960,0,0,1
-1,0.47,0,0,1,0.5870,0,1,0
1,0.3,0,0,1,0.3450,0,0,1
-1,0.51,0,0,1,0.5800,0,1,0
-1,0.24,1,0,0,0.3880,0,1,0
-1,0.49,1,0,0,0.6450,0,1,0
1,0.66,0,0,1,0.7450,1,0,0
-1,0.65,1,0,0,0.7690,1,0,0
-1,0.46,0,1,0,0.5800,1,0,0
-1,0.45,0,0,1,0.5180,0,1,0
-1,0.47,1,0,0,0.6360,1,0,0
-1,0.29,1,0,0,0.4480,1,0,0
-1,0.57,0,0,1,0.6930,0,0,1
-1,0.2,1,0,0,0.2870,0,0,1
-1,0.35,1,0,0,0.4340,0,1,0
-1,0.61,0,0,1,0.6700,0,0,1
-1,0.31,0,0,1,0.3730,0,1,0
1,0.18,1,0,0,0.2080,0,0,1
1,0.26,0,0,1,0.2920,0,0,1
-1,0.28,1,0,0,0.3640,0,0,1
-1,0.59,0,0,1,0.6940,0,0,1
```

## The House Voting Dataset Problem Using PyTorch

A somewhat unusual machine learning problem scenario is one where the predictor variables are all Boolean. This is sometimes called Bernoulli classification. The most well-known example (to me anyway) of this type of problem is the House Voting dataset. I put together a demo using PyTorch.

The raw data looks like:

```republican,n,y,n,y,y,y,n,n,n,y,?,y,y,y,n,y
democrat,n,y,y,n,y,y,n,n,n,n,n,n,y,y,y,y
democrat,n,y,n,y,y,y,n,n,n,n,n,n,?,y,y,y
republican,n,y,n,y,y,y,n,n,n,n,n,y,y,y,n,y
. . .
```

There are 435 data items, one for each of the 435 members of the U.S. House of Representatives. The first column is the member’s political party, Democrat or Republican. Each of the next 16 values is the member’s vote on a particular bill. The possible values are ‘n’ (no), ‘y’ (yes), or ‘?’ (abstain).

I removed all the data items that had one or more ‘?’ values, which left me with 232 items. I encoded democrat as 0, republican as 1, and ‘n’ as 0, ‘y’ as 1. The result looks like:

```0,0,1,1,0,1,1,0,0,0,0,0,0,1,1,1,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
. . .
```

I split the data into the first 200 items for training data, and the last 32 items for test data.
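The cleaning and encoding steps described above can be sketched in a few lines of Python (the raw file name in the comment is an assumption):

```python
# Sketch of the preprocessing: drop items with any '?' vote,
# encode democrat = 0, republican = 1, and 'n' = 0, 'y' = 1.
def encode_line(line):
    toks = line.strip().split(",")
    if "?" in toks:
        return None  # discard items with one or more abstentions
    party = "0" if toks[0] == "democrat" else "1"
    votes = ["0" if t == "n" else "1" for t in toks[1:]]
    return ",".join([party] + votes)

# with open("house-votes-84.data") as f:  # assumed raw file name
#     encoded = [e for ln in f
#                if (e := encode_line(ln)) is not None]
```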

My PyTorch neural network has a 16-(10-10)-1 architecture with Xavier initialization, tanh hidden activation, and logistic sigmoid output activation. For training, I used a batch size of 10, BCELoss loss function, and SGD optimization.

The trained model scored 0.9950 accuracy on the training data (199 out of 200 correct), and 0.9062 accuracy on the test data (29 out of 32 correct).

One thing that’s interesting about problems that have all Boolean predictor variables is that in addition to all the normal classification techniques, there are two specialized techniques: Winnow classification and Bernoulli Naive Bayes classification. I’ll show those techniques at some point in the future.
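As a preview, scikit’s BernoulliNB class handles the all-Boolean scenario directly. A minimal sketch on a tiny made-up dataset (not the voting data):

```python
import numpy as np
from sklearn.naive_bayes import BernoulliNB

# tiny made-up dataset: 4 items, 3 Boolean predictors
X = np.array([[0,1,1],
              [0,1,0],
              [1,0,1],
              [1,0,0]])
y = np.array([0, 0, 1, 1])  # class labels

model = BernoulliNB(alpha=1.0)  # Laplace smoothing
model.fit(X, y)
pred = model.predict(np.array([[0,1,1]]))  # predicts class 0
```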

Internet image searches. Left: “Democratic party”. Center: “Republican party”. Right: “Party party”.

Demo code:

```# vote_pytorch.py
# House Reps Voting Dataset binary classification
# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

#  0=democrat, 1 = republican
#  0=no, 1=yes

class VotesDataset(T.utils.data.Dataset):
  def __init__(self, src_file):
    all_data = np.loadtxt(src_file, usecols=range(0,17),
      delimiter=",", comments="#", dtype=np.float32)

    x_data = all_data[:,1:17]  # 16 binary predictors
    y_data = all_data[:,0]     # target 0 or 1
    y_data = y_data.reshape(-1,1)  # 2-D required

    self.x_data = T.tensor(x_data,
      dtype=T.float32).to(device)
    self.y_data = T.tensor(y_data,
      dtype=T.float32).to(device)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    votes = self.x_data[idx,:]  # idx row, all 16 cols
    party = self.y_data[idx,:]  # idx row, the only col
    return votes, party         # as a Tuple

# ---------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(16, 10)  # 16-(10-10)-1
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 1)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = T.sigmoid(self.oupt(z))  # for BCELoss()
    return z

# ---------------------------------------------------------

def metrics(model, ds, thresh=0.5):
  # note: N = total number of items = TP + FP + TN + FN
  # accuracy  = (TP + TN)  / N
  # precision = TP / (TP + FP)
  # recall    = TP / (TP + FN)
  # F1        = 2 / [(1 / precision) + (1 / recall)]

  tp = 0; tn = 0; fp = 0; fn = 0
  for i in range(len(ds)):
    inpts = ds[i][0]         # Tuple style
    target = ds[i][1]        # float32  [0.0] or [1.0]
    target = target.int()    # int 0 or 1
    p = model(inpts)         # between 0.0 and 1.0

    # FP: "falsely predicted to be positive"
    # FN: "falsely predicted to be negative"
    if target == 1 and p >= thresh:    # TP
      tp += 1
    elif target == 1 and p < thresh:   # FN
      fn += 1
    elif target == 0 and p < thresh:   # TN
      tn += 1
    elif target == 0 and p >= thresh:  # FP
      fp += 1

  N = tp + fp + tn + fn
  if N != len(ds):
    print("FATAL LOGIC ERROR in metrics()")

  accuracy = (tp + tn) / (N * 1.0)
  precision = (1.0 * tp) / (tp + fp)  # tp + fp != 0
  recall = (1.0 * tp) / (tp + fn)     # tp + fn != 0
  f1 = 2.0 / ((1.0 / precision) + (1.0 / recall))
  return (accuracy, precision, recall, f1)  # as a Tuple

# ---------------------------------------------------------

def main():
  # 0. get started
  print("\nVoting Dataset using PyTorch ")
  T.manual_seed(1)
  np.random.seed(1)

  # 1. create Dataset and DataLoader objects
  print("\nCreating Voting train and test Datasets ")
  train_file = ".\\Data\\votes_train.txt"
  test_file = ".\\Data\\votes_test.txt"

  train_ds = VotesDataset(train_file)  # 200 rows
  test_ds = VotesDataset(test_file)    # 32 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create neural network
  print("\nCreating 16-(10-10)-1 NN classifier \n")
  net = Net().to(device)
  net.train()  # set training mode

  # 3. train network
  lrn_rate = 0.01
  loss_func = T.nn.BCELoss()  # binary cross entropy
  # loss_func = T.nn.MSELoss()
  optimizer = T.optim.SGD(net.parameters(),
    lr=lrn_rate)
  max_epochs = 500
  ep_log_interval = 100

  print("Loss function: " + str(loss_func))
  print("Optimizer: " + str(optimizer.__class__.__name__))
  print("Learn rate: " + "%0.3f" % lrn_rate)
  print("Batch size: " + str(bat_size))
  print("Max epochs: " + str(max_epochs))

  print("\nStarting training")
  for epoch in range(0, max_epochs):
    epoch_loss = 0.0            # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]             # [bs,16]  inputs
      Y = batch[1]             # [bs,1]   targets
      oupt = net(X)            # [bs,1]   computeds

      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      optimizer.zero_grad()  # reset gradients
      loss_val.backward()    # compute gradients
      optimizer.step()       # update all weights

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %8.4f" % \
        (epoch, epoch_loss))
  print("Done ")

  # ---------------------------------------------------------

  # 4. evaluate model
  net.eval()
  metrics_train = metrics(net, train_ds, thresh=0.5)
  print("\nMetrics for train data: ")
  print("accuracy  = %0.4f " % metrics_train[0])
  print("precision = %0.4f " % metrics_train[1])
  print("recall    = %0.4f " % metrics_train[2])
  print("F1        = %0.4f " % metrics_train[3])

  metrics_test = metrics(net, test_ds, thresh=0.5)
  print("\nMetrics for test data: ")
  print("accuracy  = %0.4f " % metrics_test[0])
  print("precision = %0.4f " % metrics_test[1])
  print("recall    = %0.4f " % metrics_test[2])
  print("F1        = %0.4f " % metrics_test[3])

  # 5. save model
  print("\nSaving trained model state_dict ")
  net.eval()
  # path = ".\\Models\\voting_model.pt"
  # T.save(net.state_dict(), path)

  # 6. make a prediction
  print("\nSetting dummy voting data ")
  x = np.array([[1,1,1,1, 0,0,0,0,
    1,1,1,1, 0,0,0,0]], dtype=np.float32)
  print(x)
  x = T.tensor(x, dtype=T.float32).to(device)

  net.eval()
  oupt = net(x)    # a Tensor
  pred_prob = oupt.item()  # scalar, [0.0, 1.0]
  print("\nComputed output: ", end="")
  print("%0.4f" % pred_prob)

  if pred_prob < 0.5:
    print("Prediction = Democrat ")
  else:
    print("Prediction = Republican ")

  print("\nEnd Voting classification demo ")

if __name__ == "__main__":
  main()
```

Training data:

```# votes_train.txt
# 1st col = democrat (0), republican (1)
# next 16 cols = no vote (0) or yes (1)
# 200 rows (32 in test)
#
0,0,1,1,0,1,1,0,0,0,0,0,0,1,1,1,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,1,1,1,0,0,0,1,1,1,0,1,0,0,0,1,1
0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,0,1,0,0,0,1,1
0,1,1,1,0,0,0,1,1,1,0,1,0,0,0,1,1
1,1,0,0,1,1,0,1,1,1,0,0,1,1,1,0,1
0,1,1,1,0,0,0,1,1,1,0,1,0,0,0,1,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,0
0,1,1,1,0,0,0,1,1,1,1,0,0,1,0,1,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,1,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,0
1,1,1,0,1,1,1,0,0,0,0,0,0,1,1,0,1
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,0
0,1,0,1,0,0,0,1,1,1,1,1,0,1,0,1,1
0,1,0,1,0,0,0,1,1,1,0,0,0,0,0,0,1
0,1,0,1,0,0,0,1,1,1,0,0,0,0,0,1,1
0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,0,1
0,1,1,1,0,0,0,1,1,0,0,0,0,0,1,0,1
0,1,1,1,0,0,0,1,1,1,0,1,0,0,0,1,1
1,1,1,0,1,1,1,0,0,0,1,0,1,1,1,0,0
1,0,1,0,1,1,1,0,0,0,1,1,1,1,1,0,0
1,0,1,0,1,1,1,0,0,0,1,1,1,1,1,0,1
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,1
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,0
0,1,1,1,0,0,0,1,1,1,0,1,0,0,0,0,1
1,1,1,0,1,1,1,1,0,0,0,0,1,1,1,0,1
1,0,1,0,1,1,1,1,0,0,0,1,1,1,1,0,1
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,0
0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,1,1
1,1,1,1,1,0,0,1,1,1,1,1,0,0,1,0,1
1,1,0,1,1,1,0,1,0,1,1,0,0,1,1,0,1
0,1,0,1,0,0,1,1,1,1,1,1,0,0,1,1,1
0,0,1,1,1,1,1,0,0,0,1,1,0,1,1,0,0
0,0,1,1,1,1,1,0,1,1,1,1,1,1,1,0,1
0,1,1,1,0,1,1,0,0,0,1,1,0,1,1,0,1
1,0,0,0,1,1,0,0,0,0,1,0,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0
0,0,0,1,0,1,1,0,0,0,1,1,1,1,1,0,1
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0
0,0,1,1,0,1,1,1,0,1,1,1,0,1,1,0,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,1,1,0,0,0,1,1
0,1,0,1,0,0,0,1,0,1,1,1,0,0,0,1,1
0,1,0,1,0,1,1,0,0,0,0,0,0,0,0,0,1
0,1,0,0,0,1,1,1,0,0,1,1,0,0,1,0,1
0,1,1,1,0,0,1,1,1,1,1,0,0,0,0,0,1
0,1,0,0,0,1,1,0,0,0,0,1,1,0,1,0,1
0,1,0,1,0,1,1,1,0,0,0,1,0,0,1,0,1
0,1,1,1,0,0,0,0,1,1,0,1,0,0,0,1,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,0,0,1,0,0,0,1,1,1,1,0,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,0,1,0,0,0,1,1
1,1,1,1,1,1,0,1,0,0,0,0,1,1,1,0,1
0,0,1,1,0,0,0,0,1,1,1,1,0,0,0,1,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,1,0,1,0,1,0,1
0,0,0,1,0,0,1,0,1,1,1,0,0,0,1,0,1
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,1
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,0
1,0,1,0,1,1,1,0,0,0,1,1,1,1,0,0,1
0,0,0,1,0,0,1,1,1,1,1,0,0,0,1,0,1
0,1,0,1,0,0,1,1,1,1,0,0,0,0,0,1,1
1,0,0,0,1,0,0,1,1,1,1,0,0,1,1,0,1
1,0,0,0,1,1,1,1,1,1,1,0,1,1,1,0,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,0,0,0,0,0,0,1,1,1,1,0,1,1,1,1,1
1,0,1,0,1,1,1,0,0,0,1,1,1,1,1,0,1
0,0,0,1,0,0,0,1,1,1,1,0,0,1,0,1,1
1,1,1,0,1,1,1,0,0,0,1,0,1,1,1,0,1
0,0,1,1,0,0,1,0,1,1,1,1,0,1,0,1,1
0,0,0,1,0,0,1,1,1,1,1,1,0,1,1,0,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,0
1,1,1,0,1,1,1,1,0,0,0,0,1,1,1,0,0
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,0
0,0,1,0,0,1,1,0,0,0,0,0,1,1,1,1,1
0,0,0,0,0,1,1,1,0,0,0,0,1,1,1,0,1
0,0,1,1,0,1,1,1,0,0,0,1,1,1,1,0,1
1,0,1,0,1,1,1,1,0,0,0,0,1,1,1,0,1
1,1,0,1,1,1,1,1,1,0,1,0,1,0,1,1,1
1,1,0,1,1,1,1,1,1,0,1,1,1,0,1,1,1
0,1,0,1,0,0,0,1,1,1,1,1,0,0,1,0,1
0,0,0,0,0,1,1,0,0,0,1,1,1,1,1,0,1
0,0,1,1,0,0,0,1,1,1,1,0,0,0,0,1,1
1,0,0,1,1,0,0,1,1,1,1,0,0,0,1,1,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,0,1,1
0,0,0,1,0,0,0,1,1,1,1,1,0,0,0,1,1
0,0,0,1,0,0,0,1,1,1,1,1,0,0,0,1,1
0,0,1,1,0,0,0,1,1,1,1,1,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,0,0,0,0,0,1,1
0,0,0,0,0,0,1,1,1,1,0,1,0,0,1,1,1
0,0,0,1,0,0,0,1,1,1,0,0,0,0,0,1,1
0,0,0,1,0,0,0,1,1,1,0,0,0,0,1,1,1
0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,1,1
0,1,0,1,0,0,1,1,1,1,1,1,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,0,1,1
1,0,0,1,1,1,1,1,0,0,0,0,1,1,1,0,1
0,0,0,1,0,0,1,1,1,1,1,0,1,0,0,0,1
1,0,0,0,1,1,1,0,0,0,1,0,1,0,1,0,1
0,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,1
0,0,0,1,0,0,1,1,1,1,0,0,0,0,0,1,1
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,1
0,0,0,1,0,0,0,1,1,1,0,1,0,0,0,1,1
0,0,1,1,0,0,1,0,1,1,0,1,0,1,0,1,1
1,1,1,0,1,1,1,0,0,0,1,0,1,1,1,0,1
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,0
0,0,1,1,0,0,0,0,1,1,0,1,0,0,1,1,1
1,0,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1
0,0,0,1,0,0,1,1,1,1,0,1,0,0,1,1,1
1,0,1,1,1,1,1,1,0,1,1,0,1,1,1,0,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
1,0,1,0,1,1,1,0,0,1,1,0,1,1,1,0,1
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,1
1,0,0,0,1,1,1,0,0,0,1,0,1,0,1,0,1
0,0,0,1,0,0,0,1,1,1,0,0,0,0,0,1,1
0,1,0,1,0,0,1,1,1,0,0,0,1,1,0,0,1
1,0,0,0,1,1,1,1,0,0,1,0,0,0,1,1,1
0,1,0,1,0,0,0,1,1,1,1,1,0,0,1,1,1
0,1,0,1,0,0,0,0,1,1,1,0,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,1,1,0,0,0,1,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,0
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,0
0,1,1,1,0,0,1,1,1,1,0,0,0,0,0,1,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,0,0,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,1,1,1
1,0,0,0,1,1,0,0,0,0,0,0,1,0,1,0,0
0,0,0,1,0,0,0,1,1,1,0,1,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,0,0,0,0,0,0,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,0,0,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,0,0,1
1,0,0,0,1,1,1,0,0,0,1,0,1,0,1,0,1
1,1,0,0,0,0,0,1,1,1,1,0,0,0,1,0,1
0,1,0,1,0,0,0,1,1,1,0,0,0,0,0,0,1
0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,1,1
1,1,0,0,1,1,0,1,0,0,1,0,0,0,1,1,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,1,0
1,0,0,1,1,1,1,1,1,0,1,0,0,0,1,0,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,1
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,0
0,0,0,1,0,0,0,1,1,1,1,0,0,0,1,0,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,0,0,1,0,0,1,1,1,1,1,1,0,0,0,1,1
0,1,1,1,0,1,1,0,1,0,1,1,0,1,1,1,1
0,1,0,1,0,0,1,1,1,0,1,1,0,1,1,1,1
0,1,1,1,0,0,1,1,1,1,1,1,0,1,1,1,1
1,0,0,1,1,1,1,0,0,0,1,0,1,1,1,1,1
0,0,1,0,0,0,0,1,1,1,1,1,0,0,0,1,1
0,0,1,1,0,0,1,1,1,1,1,0,0,1,1,1,1
1,0,0,0,1,1,0,1,1,1,1,0,1,1,1,0,1
1,0,0,0,1,1,1,1,0,0,1,0,1,1,1,0,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,0
0,1,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,1,0
0,0,0,1,0,0,1,1,1,1,1,0,0,1,0,0,1
0,1,1,1,0,0,0,1,1,1,1,0,0,0,0,1,1
1,0,1,1,1,1,1,0,0,0,1,0,1,1,1,0,1
1,0,1,0,1,1,1,1,1,0,0,1,1,1,1,1,1
0,0,0,0,0,0,1,0,1,1,0,1,1,1,1,1,0
0,1,0,0,0,0,0,1,1,1,1,0,0,0,0,1,1
0,0,1,1,0,0,1,0,1,1,1,0,0,1,1,0,1
0,1,1,1,0,0,0,1,1,1,1,0,0,1,0,0,1
1,0,1,0,1,1,1,0,0,0,0,1,1,1,1,0,0
0,1,1,0,1,0,0,1,1,1,0,1,0,0,1,0,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,1,1,1,0,0,0,1,1,1,0,1,0,0,0,0,1
1,0,1,0,1,1,1,0,0,0,1,0,1,1,1,0,0
0,0,0,1,0,0,0,1,1,1,0,0,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,0,0,0,0,0,1,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,1,1,1
1,1,0,0,1,1,1,0,0,0,0,1,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,1,1,1,0,1,0,1
1,0,0,0,1,1,0,1,0,1,1,0,0,0,1,0,1
0,0,0,1,0,0,0,1,1,1,1,1,0,0,0,1,1
1,0,0,0,1,1,1,1,0,0,1,0,1,0,1,1,1
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,1
1,1,0,0,1,1,1,0,0,0,1,0,1,1,1,0,0
1,0,1,1,1,1,1,1,1,1,0,0,1,1,1,0,1
0,0,1,0,0,0,1,1,0,1,0,1,0,0,0,1,1
1,0,0,1,1,1,1,1,1,1,1,0,1,1,1,1,1
1,0,0,1,1,1,1,1,0,0,1,1,1,1,1,0,1
1,1,0,1,1,0,0,0,1,1,1,0,0,0,1,1,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0
0,1,0,1,0,0,1,1,1,1,1,0,0,1,0,0,1
0,1,1,1,0,0,1,1,1,1,1,1,1,1,0,0,1
1,1,1,0,1,1,1,0,0,0,1,1,0,1,0,0,0
1,1,1,0,1,1,1,0,0,0,0,1,0,1,1,0,1
0,0,1,0,0,1,1,0,0,0,1,1,0,1,1,0,0
0,0,1,1,0,0,1,1,1,0,1,0,0,0,0,1,1
1,0,1,0,1,1,1,0,0,0,0,0,0,1,1,0,1
1,0,1,0,1,1,1,0,0,0,0,0,1,1,1,0,1
```

Test data:

```0,0,1,0,1,1,1,0,0,0,0,1,1,0,1,0,0
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,1,1,1,0,1,1,0,1,1,1,1,0,0,0,0,1
0,1,1,1,1,1,1,0,0,0,0,1,1,1,1,0,1
0,1,1,0,0,1,1,0,0,0,0,1,1,1,1,1,0
0,1,1,0,0,0,0,0,1,1,0,1,0,0,0,1,0
1,1,1,0,1,1,1,0,0,0,0,1,1,1,1,0,1
0,1,1,1,0,1,1,0,1,0,0,1,0,1,0,1,1
0,0,1,1,0,1,1,0,1,0,0,0,0,0,0,0,1
1,0,1,0,1,1,1,0,0,0,1,1,1,1,1,0,0
1,1,1,0,1,1,1,0,0,0,1,0,1,1,1,0,1
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,1
0,1,0,1,0,1,1,0,0,1,1,0,0,1,1,0,1
0,0,0,0,1,1,1,0,0,0,0,1,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,0,0,1,1,1,0,0
1,0,0,0,1,1,1,0,0,0,0,1,1,1,1,0,1
0,1,0,1,0,0,1,1,1,1,1,1,0,0,0,0,1
1,0,0,0,1,1,1,0,0,0,1,0,1,1,1,0,1
0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,0,1
1,1,1,0,1,1,1,0,0,0,1,0,0,1,1,0,1
0,1,1,1,0,0,0,1,1,1,1,1,0,1,0,0,1
0,1,1,1,0,0,0,1,1,0,1,0,0,0,0,0,1
0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,0,1
1,1,1,1,1,1,1,1,1,0,1,0,0,1,1,0,1
0,0,1,1,0,1,1,1,1,0,0,1,0,1,0,1,1
0,0,0,1,0,0,1,1,1,1,0,1,0,0,0,1,1
0,0,1,1,0,0,1,1,1,1,0,1,0,0,1,1,1
0,1,0,1,0,0,0,1,1,1,1,0,0,0,0,1,1
1,0,0,0,1,1,1,1,1,0,1,0,1,1,1,0,1
1,0,0,1,1,1,1,0,0,1,1,0,1,1,1,0,1
0,0,0,1,0,0,0,1,1,1,1,0,0,0,0,0,1
```

## Revisiting My Binary Classification Neural Network with Raw JavaScript

Quite some time ago I implemented a neural network binary classifier in raw JavaScript. One Saturday morning, I was going to walk my two dogs but it was raining so I decided to revisit my code while I waited for the rain to stop.

I didn’t run into any major problems, but working in raw JavaScript is always slow going. My raw JavaScript neural network can handle only a single hidden layer, but even so the results were pretty good.

I used one of my standard binary classification datasets. The data looks like:

```1  0.24  1  0  0  0.2950  0  0  1
0  0.39  0  0  1  0.5120  0  1  0
1  0.63  0  1  0  0.7580  1  0  0
. . .
```

Each line represents a person. The fields are sex (male = 0, female = 1), age (divided by 100), state (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by \$100,000) and political leaning (conservative = 100, moderate = 010, liberal = 001). The goal is to predict sex from age, state, income, and politics.

Implementing a neural network from scratch (in any language) is difficult. My implementation is several hundred lines of code long so I can’t present it in its entirety in this blog post.

```let U = require("../../Utilities/utilities_lib.js");
let FS = require("fs");

function main()
{
  console.log("Begin binary classification demo  ");

  // 1. load data, which looks like:
  // 1  0.24  1  0  0  0.2950  0  0  1
  // 0  0.39  0  0  1  0.5120  0  1  0
  let trainX = U.loadTxt(".\\Data\\people_train.txt",
    "\t", [1,2,3,4,5,6,7,8], "//");
  let trainY = U.loadTxt(".\\Data\\people_train.txt",
    "\t", [0], "//");

  let testX = U.loadTxt(".\\Data\\people_test.txt",
    "\t", [1,2,3,4,5,6,7,8], "//");
  let testY = U.loadTxt(".\\Data\\people_test.txt",
    "\t", [0], "//");
. . .
```

And creating and training the network is:

```  // 2. create network
  console.log("Creating 8-100-1 tanh, sigmoid NN ");
  let seed = 0;
  let nn = new NeuralNet(8, 100, 1, seed);

  // 3. train network
  let lrnRate = 0.01;
  let maxEpochs = 10000;
  console.log("Starting train learn rate = " +
    lrnRate.toString());
  nn.train(trainX, trainY, lrnRate, maxEpochs);
  console.log("Done ");
. . .
```

Notice that I made train() a method that belongs to the NeuralNet class rather than a standalone function that accepts a neural net object. Design decisions like this are often more difficult than the implementation coding itself. Anyway, it was a good mental exercise on a rainy Pacific Northwest morning.
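The design choice can be illustrated with a Python analogue. TinyNet below is a hypothetical one-weight "network", just to show the train-as-method structure rather than any real architecture:

```python
# Design-choice sketch: train() as a method of the network class
# rather than a standalone function that accepts a network object.
class TinyNet:
    def __init__(self, lr=0.01):
        self.w = 0.0   # single weight
        self.lr = lr   # learning rate

    def predict(self, x):
        return self.w * x

    def train(self, xs, ys, epochs=100):
        # plain SGD on squared error
        for _ in range(epochs):
            for x, y in zip(xs, ys):
                grad = 2.0 * (self.predict(x) - y) * x
                self.w -= self.lr * grad

net = TinyNet()
net.train([1.0, 2.0], [2.0, 4.0])  # learns y = 2x, so w -> 2.0
```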

Left: Riley (girl, black and white) and Kevin (boy, brown) waiting for the rain to stop. Right: When dog Llama stopped in from Denver to visit, I made a mini golf hole for her and my two dogs, but none of them were very interested.

## Example of Multinomial Naive Bayes Classification Using the scikit Library

The scikit-learn code library has a MultinomialNB class that can be used to create prediction models for multinomial data. The most common form of multinomial data has predictor variables where the values are counts. For example, suppose you want to predict the college course type (history = 0, math = 1, psychology = 2) from the counts of each letter grade students received.

I coded up a demo. My raw demo data is:

```# college_grades_train_raw.txt
# As Bs Cs Ds Fs Course
#
5,7,12,6,4,math
1,6,10,3,0,math
0,9,12,2,1,math
8,8,10,3,2,psychology
7,14,8,0,0,psychology
5,12,9,1,3,psychology
2,16,7,0,2,psychology
3,11,5,4,4,history
5,9,7,4,2,history
8,6,8,0,1,history
```

The first line of data means that in a particular math course, 5 students received As, and there were 7 Bs, 12 Cs, 6 Ds, and 4 Fs.

To use the scikit MultinomialNB class, the labels to predict should be ordinal/integer encoded. So the data used by the demo is:

```# college_grades_train.txt
# As Bs Cs Ds Fs Course
# history = 0, math = 1, psych = 2
#
5,7,12,6,4,1
1,6,10,3,0,1
0,9,12,2,1,1
8,8,10,3,2,2
7,14,8,0,0,2
5,12,9,1,3,2
2,16,7,0,2,2
3,11,5,4,4,0
5,9,7,4,2,0
8,6,8,0,1,0
```

The data is loaded into memory like so:

```  train_file = ".\\Data\\college_grades_train.txt"
  XY = np.loadtxt(train_file, usecols=range(0,6),
    delimiter=",", comments="#", dtype=np.int64)
  X = XY[:,0:5]
  y = XY[:,5]
```

The model is created and trained like so:

```  import numpy as np
from sklearn.naive_bayes import MultinomialNB

model = MultinomialNB(alpha=1)
model.fit(X, y)
```

The trained model can be evaluated:

```  y_predicteds = model.predict(X)
acc_train = model.score(X, y)
print("\nAccuracy on train data = %0.4f " % acc_train)
```

And the trained model can be used to make a prediction:

```  X = [[7,8,7,3,1]]  # 7 As, 8 Bs, etc.
probs = model.predict_proba(X)
print("Prediction probs: ")
print(probs)
```

The result probs matrix has just one row: [[0.7224 0.0186 0.2590]]. The values are pseudo-probabilities of each of the three possible course types. Because the value at position [0] is the largest, the prediction is class 0 = history.
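Under the hood, multinomial naive Bayes just combines a log class prior with count-weighted log feature probabilities. A sketch that reproduces the calculation by hand on the post's ten training items (the helper function name is mine, not scikit's):

```python
import numpy as np

# the 10 training items from the post: grade counts and course labels
X = np.array([[5,7,12,6,4],[1,6,10,3,0],[0,9,12,2,1],
              [8,8,10,3,2],[7,14,8,0,0],[5,12,9,1,3],[2,16,7,0,2],
              [3,11,5,4,4],[5,9,7,4,2],[8,6,8,0,1]], dtype=np.float64)
y = np.array([1,1,1,2,2,2,2,0,0,0])

def mnb_scores(X, y, x_new, alpha=1.0):
    # log P(c) + sum_i x_i * log theta_ci, with Laplace smoothing
    scores = []
    for c in np.unique(y):
        Xc = X[y == c]
        theta = (Xc.sum(axis=0) + alpha) / \
                (Xc.sum() + alpha * X.shape[1])
        scores.append(np.log(len(Xc) / len(y)) +
                      (x_new * np.log(theta)).sum())
    return np.array(scores)

scores = mnb_scores(X, y, np.array([7,8,7,3,1]))
probs = np.exp(scores - scores.max())
probs /= probs.sum()           # matches [[0.7224 0.0186 0.2590]]
pred = int(np.argmax(scores))  # 0 = history
```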

I wasn’t a very good college student in my undergraduate days at U.C. Irvine. It’s something of a miracle I graduated at all because I spent more time partying than studying (except for my math and computer classes which I loved). Toga parties can trace their origins back to the 1950s. Left: The movie “Animal House” (1978) featured a toga party and popularized the idea. Center and Right: Gender differences when it comes to preparing for a toga party. College girls will spend hours creating their togas. College guys will spend approximately 45 seconds creating their togas.

Demo code:

```# multinomial_bayes.py
# predict college course type from grade counts

# Anaconda3-2020.02  Python 3.7.6
# scikit 0.22.1
# Windows 10/11

import numpy as np
from sklearn.naive_bayes import MultinomialNB

# ---------------------------------------------------------

# Data:
# numAs numBs numCs numDs numFs Course
# history = 0, math = 1, psych = 2
# 5,7,12,6,4,1
# 1,6,10,3,0,1
# . . .

# ---------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

# ---------------------------------------------------------

def main():
  print("\nBegin scikit multinomial Bayes demo ")
  print("Predict (hist = 0, math = 1, psych = 2) from grades ")
  np.random.seed(1)
  np.set_printoptions(precision=4)

  # 1. load data
  train_file = ".\\Data\\college_grades_train.txt"
  XY = np.loadtxt(train_file, usecols=range(0,6),
    delimiter=",", comments="#", dtype=np.int64)
  X = XY[:,0:5]
  y = XY[:,5]

  print("\nPredictor counts: ")
  print(X)

  print("\nCourse types: ")
  print(y)

  # 2. create and train model
  print("\nCreating multinomial Bayes classifier ")
  model = MultinomialNB(alpha=1)
  model.fit(X, y)
  print("Done ")

  # 3. evaluate model
  y_predicteds = model.predict(X)
  print("\nPredicted classes: ")
  print(y_predicteds)

  acc_train = model.score(X, y)
  print("\nAccuracy on train data = %0.4f " % acc_train)

  # 3b. confusion matrix
  # from sklearn.metrics import confusion_matrix
  # cm = confusion_matrix(y, y_predicteds)  # actual, pred
  # print("\nConfusion matrix raw: ")
  # print(cm)
  # print("\nConfusion matrix formatted: ")
  # show_confusion(cm)

  # 3c. precision, recall, F1
  # for binary classification
  # from sklearn.metrics import classification_report
  # report = classification_report(y, y_predicteds)
  # print(report)

  # 4. use model
  X = [[7,8,7,3,1]]  # 7 As, 8 Bs, etc.
  print("\nPredicting course for grade counts: "
    + str(X))
  probs = model.predict_proba(X)
  print("\nPrediction probs: ")
  print(probs)

  pred_course = model.predict(X)  # 0,1,2
  courses = ["history", "math", "psychology"]
  print("\nPredicted course: ")
  print(courses[int(pred_course[0])])

  # 5. TODO: save model using pickle
  # import pickle
  # print("Saving trained naive Bayes model ")
  # path = ".\\Models\\multinomial_scikit_model.sav"
  # pickle.dump(model, open(path, "wb"))

  # use saved model
  # x = np.array([[6, 7, 8, 2, 1]], dtype=np.int64)
  # with open(path, 'rb') as f:
  #   loaded_model = pickle.load(f)
  # pa = loaded_model.predict(x)
  # print(pa)

  print("\nEnd multinomial Bayes demo ")

if __name__ == "__main__":
  main()
```

## “Binary Classification Using a scikit Decision Tree” in Visual Studio Magazine

I wrote an article titled “Binary Classification Using a scikit Decision Tree” in the February 2023 edition of Microsoft Visual Studio Magazine. See https://visualstudiomagazine.com/articles/2023/02/21/scikit-decision-tree.aspx.

A decision tree is a machine learning technique that can be used for binary classification or multi-class classification. My article presents an end-to-end demo that predicts the sex of a person (male = 0, female = 1) based on their age, state where they live, income and political leaning.

There are several tools and code libraries that you can use to perform binary classification using a decision tree. The scikit-learn library (also called scikit or sklearn) is based on the Python language and is one of the most popular machine learning libraries.

The article demo data is one of my standard synthetic datasets and looks like:

```1   0.24   1   0   0   0.2950   0   0   1
0   0.39   0   0   1   0.5120   0   1   0
1   0.63   0   1   0   0.7580   1   0   0
0   0.36   1   0   0   0.4450   0   1   0
1   0.27   0   1   0   0.2860   0   0   1
. . .
```

The tab-delimited fields are sex (0 = male, 1 = female), age (divided by 100), state (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by \$100,000) and political leaning (conservative = 100, moderate = 010, liberal = 001).

One of the advantages of the scikit library is simplicity (at the expense of flexibility). Creating and training a decision tree is easy:

```  md = 4
print("Creating decision tree max_depth=" + str(md))
model = tree.DecisionTreeClassifier(max_depth=md,
random_state=1)
model.fit(train_x, train_y)
print("Done ")
```

An advantage of decision tree classifiers over neural network classifiers is that decision trees are somewhat interpretable because a decision tree is just a set of if-then rules. For example:

```|--- income <= 0.34
|   |--- pol2 <= 0.50
|   |   |--- age <= 0.23
|   |   |   |--- income <= 0.28
|   |   |   |   |--- class: 1.0
. . .
```
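A rule listing like the one above can be produced with scikit's export_text function. A minimal sketch on a tiny made-up dataset (the feature names and values are placeholders, not the article's demo data):

```python
import numpy as np
from sklearn import tree
from sklearn.tree import export_text

# tiny made-up dataset: [age, income] predictors, sex label
X = np.array([[0.24, 0.2950], [0.39, 0.5120],
              [0.63, 0.7580], [0.36, 0.4450]])
y = np.array([1, 0, 1, 0])

model = tree.DecisionTreeClassifier(max_depth=4, random_state=1)
model.fit(X, y)
rules = export_text(model, feature_names=["age", "income"])
print(rules)  # "|--- income <= ..." style if-then rules
```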

The two main downsides to decision trees are that they often don’t work well with large datasets, and they are highly susceptible to model overfitting.
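The overfitting risk is easy to demonstrate: with no max_depth limit, a tree simply memorizes its training data. A sketch on random made-up data with no real signal:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X_train = rng.random((100, 5))     # random predictors
y_train = rng.integers(0, 2, 100)  # random (meaningless) labels
X_test = rng.random((50, 5))
y_test = rng.integers(0, 2, 50)

deep = DecisionTreeClassifier(random_state=1)  # no max_depth limit
deep.fit(X_train, y_train)
acc_train = deep.score(X_train, y_train)  # 1.0: memorized the noise
acc_test = deep.score(X_test, y_test)     # near 0.50: no real signal
```

Limiting max_depth (as in the article's md = 4) trades some training accuracy for a model that generalizes better.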

One of the main areas of social science research is the study of behavioral differences between men and women. It’s well-known that men and women think about relationships differently. Left: How a woman thinks about her relationship with a man. Complicated. Right: How a man thinks about his relationship with a woman. Not so complicated.
