Multi-Class Classification Using a scikit MLPClassifier Neural Network

The scikit-learn (aka scikit) library has built-in support for basic neural networks. At work, I use PyTorch because it’s very flexible. But one morning, just for fun and mental exercise, I decided to create a scikit neural network multi-class classifier to compare it to PyTorch.

I used one of my standard synthetic datasets. The data looks like:

 1   0.24   1   0   0   0.2950   2
-1   0.39   0   0   1   0.5120   1
 1   0.63   0   1   0   0.7580   0
-1   0.36   1   0   0   0.4450   1
. . . 

Each line of data represents a person. The fields are sex (male = -1, female = 1), age (normalized by dividing by 100), state (michigan = 100, nebraska = 010, oklahoma = 001), annual income (divided by 100,000), and politics type (0 = conservative, 1 = moderate, 2 = liberal). The goal is to predict politics type from sex, age, state, income.
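
The encoding can be sketched as a small helper. The function name encode_person and its argument format are hypothetical, chosen only for illustration, and are not part of the demo program.

```python
# Hypothetical helper (not part of the demo): encode one raw record
# using the scheme described above.

STATES = ["michigan", "nebraska", "oklahoma"]  # one-hot order

def encode_person(sex, age, state, income):
    """Return [sex, age, s0, s1, s2, income] as numeric values."""
    sex_val = -1 if sex == "M" else 1                   # male = -1, female = 1
    one_hot = [1 if state == s else 0 for s in STATES]  # nebraska -> 0 1 0
    return [sex_val, age / 100, *one_hot, income / 100_000]

print(encode_person("M", 35, "nebraska", 55_000))
# prints [-1, 0.35, 0, 1, 0, 0.55]
```

The result matches the input vector the demo uses for its "M 35 Nebraska $55K" prediction.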

One of the characteristics of scikit classes is that many of them have zillions of parameters:

  # MLPClassifier(hidden_layer_sizes=(100,),
  #  activation='relu', *, solver='adam', alpha=0.0001,
  #  batch_size='auto', learning_rate='constant',
  #  learning_rate_init=0.001, power_t=0.5, max_iter=200,
  #  shuffle=True, random_state=None, tol=0.0001,
  #  verbose=False, warm_start=False, momentum=0.9,
  #  nesterovs_momentum=True, early_stopping=False,
  #  validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
  #  epsilon=1e-08, n_iter_no_change=10, max_fun=15000)

From my years of experience with neural networks, I understand what these parameters do, but for someone new to neural networks, parsing through these parameters would take a long, long time.
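
Because the exact parameter list can differ between scikit versions, one way to see the parameters and their defaults on a given machine is the get_params() method that scikit estimators provide. A minimal sketch:

```python
from sklearn.neural_network import MLPClassifier

# List the constructor parameters and defaults of the installed version.
defaults = MLPClassifier().get_params()
for name in sorted(defaults):
    print("%-22s = %r" % (name, defaults[name]))
```

The printed values should line up with the signature shown above, though newer versions may list a few parameters that version 0.22.1 does not have.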

For my demo, I set these parameters and created the network like so:

  import numpy as np 
  from sklearn.neural_network import MLPClassifier

  params = { 'hidden_layer_sizes' : [10,10],
    'activation' : 'tanh',
    'solver' : 'sgd',
    'alpha' : 0.0,
    'batch_size' : 10,
    'random_state' : 1,
    'tol' : 0.0001,
    'nesterovs_momentum' : False,
    'learning_rate' : 'constant',
    'learning_rate_init' : 0.01,
    'max_iter' : 1000,
    'shuffle' : True,
    'n_iter_no_change' : 90,
    'verbose' : False }

  print("\nCreating 6-(10-10)-3 tanh neural network ")
  net = MLPClassifier(**params)

Explaining all these parameters would take pages, so I won’t try. I will point out that n_iter_no_change is very important. With its default value of 10, training stops automatically after about 10 consecutive iterations (epochs) in which the loss fails to improve by at least the tol value. In my experiments, the default often ended training too soon.
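
The effect of n_iter_no_change can be checked with a small experiment. The sketch below is not the demo: it uses made-up random data rather than the people dataset, and a hypothetical helper named train(). It trains the same network twice, once with the default patience of 10 and once with 90, and reports how many epochs each run used. The exact numbers will depend on the scikit version.

```python
import warnings
import numpy as np
from sklearn.neural_network import MLPClassifier

warnings.filterwarnings("ignore")  # suppress convergence warnings

# Made-up data for illustration: 200 items, 6 predictors, 3 noisy classes.
rnd = np.random.RandomState(0)
X = rnd.uniform(-1.0, 1.0, size=(200, 6)).astype(np.float32)
y = np.digitize(X[:, 0] + 0.3 * rnd.normal(size=200), [-0.33, 0.33])

def train(patience):
    # Same settings as the demo except max_iter and n_iter_no_change.
    net = MLPClassifier(hidden_layer_sizes=[10, 10], activation="tanh",
        solver="sgd", alpha=0.0, batch_size=10, random_state=1,
        tol=0.0001, nesterovs_momentum=False, learning_rate="constant",
        learning_rate_init=0.01, max_iter=300,
        n_iter_no_change=patience)
    net.fit(X, y)
    return net

net_short = train(10)  # default patience
net_long = train(90)   # value used in the demo
print("patience 10: epochs = %d  loss = %0.4f" %
    (net_short.n_iter_, net_short.loss_))
print("patience 90: epochs = %d  loss = %0.4f" %
    (net_long.n_iter_, net_long.loss_))
```

Both runs share the same data and seed, so the run with the larger patience should train for at least as many epochs as the other one. Whether the extra epochs help depends on the data.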

It was a fun and interesting experiment.



The parameters of the scikit MLPClassifier are essentially the control panel of the module. Left: Control panel for the Three Mile Island nuclear power plant, near Harrisburg, Pennsylvania. It was shut down in 2019 and is being decommissioned. Right: Control panel for the Watts Bar nuclear plant near Chattanooga, Tennessee. It’s the newest nuclear plant in the U.S.


The demo code is below. The training and test data follow the code and are also available at: https://jamesmccaffrey.wordpress.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.

# people_politics_nn_sckit.py

# predict politics (0 = con, 1 = mod, 2 = lib) 
# from sex, age, state, income

# sex  age    state    income   politics
# -1   0.27   0  1  0   0.7610   2
#  1   0.19   0  0  1   0.6550   0
# sex: -1 = male, 1 = female
# state: michigan = 100, nebraska = 010, oklahoma = 001
# politics: conservative, moderate, liberal

# Anaconda3-2020.02  Python 3.7.6  scikit 0.22.1
# Windows 10/11

import numpy as np 
from sklearn.neural_network import MLPClassifier
import warnings
warnings.filterwarnings('ignore')  # early-stop warnings

# ---------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

# ---------------------------------------------------------

def main():
  # 0. get ready
  print("\nBegin scikit neural network example ")
  print("Predict politics from sex, age, state, income ")
  np.random.seed(1)
  np.set_printoptions(precision=4, suppress=True)

  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   2
  #  1   0.19   0  0  1   0.6550   0

  # 1. load data
  print("\nLoading data into memory ")
  train_file = ".\\Data\\people_train.txt"
  train_xy = np.loadtxt(train_file, usecols=range(0,7),
    delimiter="\t", comments="#",  dtype=np.float32) 
  train_x = train_xy[:,0:6]
  train_y = train_xy[:,6].astype(int)

  test_file = ".\\Data\\people_test.txt"
  test_xy = np.loadtxt(test_file, usecols=range(0,7),
    delimiter="\t", comments="#",  dtype=np.float32) 
  test_x = test_xy[:,0:6]
  test_y = test_xy[:,6].astype(int)

  print("\nTraining data:")
  print(train_x[0:4])
  print(". . . \n")
  print(train_y[0:4])
  print(". . . ")
 
# ---------------------------------------------------------

  # 2. create network 
  # MLPClassifier(hidden_layer_sizes=(100,),
  #  activation='relu', *, solver='adam', alpha=0.0001,
  #  batch_size='auto', learning_rate='constant',
  #  learning_rate_init=0.001, power_t=0.5, max_iter=200,
  #  shuffle=True, random_state=None, tol=0.0001,
  #  verbose=False, warm_start=False, momentum=0.9,
  #  nesterovs_momentum=True, early_stopping=False,
  #  validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
  #  epsilon=1e-08, n_iter_no_change=10, max_fun=15000)

  params = { 'hidden_layer_sizes' : [10,10],
    'activation' : 'tanh',
    'solver' : 'sgd',
    'alpha' : 0.0,
    'batch_size' : 10,
    'random_state' : 1,
    'tol' : 0.0001,
    'nesterovs_momentum' : False,
    'learning_rate' : 'constant',
    'learning_rate_init' : 0.01,
    'max_iter' : 1000,
    'shuffle' : True,
    'n_iter_no_change' : 90,
    'verbose' : False }
       
  print("\nCreating 6-(10-10)-3 tanh neural network ")
  net = MLPClassifier(**params)

# ---------------------------------------------------------

  # 3. train
  print("\nTraining with bat sz = " + \
    str(params['batch_size']) + " lrn rate = " + \
    str(params['learning_rate_init']) + " ")
  print("Stop if no change " + \
    str(params['n_iter_no_change']) + " iterations ")
  net.fit(train_x, train_y)
  print("Done ")

  # 4. evaluate
  acc_train = net.score(train_x, train_y)
  print("\nAccuracy on train = %0.4f " % acc_train)
  acc_test = net.score(test_x, test_y)
  print("Accuracy on test = %0.4f " % acc_test)

  # from sklearn.metrics import confusion_matrix
  # y_predicteds = net.predict(test_x)
  # cm = confusion_matrix(test_y, y_predicteds) 
  # print("\nConfusion matrix raw: ")
  # print(cm)
  # show_confusion(cm)  # with formatted labels

  # 5. use model
  print("\nPredict for: M 35 Nebraska $55K ")
  X = np.array([[-1, 0.35, 0,1,0, 0.5500]],
    dtype=np.float32)

  probs = net.predict_proba(X)
  print("\nPrediction pseudo-probs: ")
  print(probs)

  politic = net.predict(X)  # 0,1,2
  lbls = ["conservative", "moderate", "liberal"]
  print("\nPredicted class: ")
  print(lbls[politic[0]])

  # 6. TODO: save model using pickle
  # import pickle
  # print("Saving trained network ")
  # path = ".\\Models\\network.sav"
  # with open(path, "wb") as f:
  #   pickle.dump(net, f)

  # load and use saved model
  # X = np.array([[-1, 0.35, 0,1,0, 0.5500]],
  #   dtype=np.float32)
  # with open(path, 'rb') as f:
  #   loaded_model = pickle.load(f)
  # pa = loaded_model.predict_proba(X)
  # print(pa)

  print("\nEnd scikit neural network demo ")

if __name__ == "__main__":
  main()
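
The commented-out confusion matrix step in the demo relies on sklearn.metrics.confusion_matrix. To show what that matrix contains, here is a minimal NumPy-only sketch. The function name confusion_counts and the sample labels are made up for illustration.

```python
import numpy as np

def confusion_counts(actual, predicted, n_classes):
    # cm[i][j] = number of items with actual class i predicted as class j
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for a, p in zip(actual, predicted):
        cm[a][p] += 1
    return cm

actual    = np.array([0, 0, 1, 1, 1, 2, 2, 2])  # made-up labels
predicted = np.array([0, 1, 1, 1, 2, 2, 2, 0])  # made-up predictions
cm = confusion_counts(actual, predicted, 3)
print(cm)
print("accuracy = %0.4f" % (np.trace(cm) / np.sum(cm)))
```

The array uses the row = actual, column = predicted layout, so it can be passed to the demo's show_confusion() function in the same way as the scikit result.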

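Training data. The demo program reads tab-delimited files, so either replace the commas with tabs or change the delimiter argument of the np.loadtxt() calls to ",".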

# people_train.txt
# sex (M=-1, F=1) , age (div 100),  state
# (michigan=100, nebraska=010, oklahoma=001) 
# income (div $100,000)
# politics (con=0, mod=1, lib=2)
#
 1,0.24,1,0,0,0.2950,2
-1,0.39,0,0,1,0.5120,1
 1,0.63,0,1,0,0.7580,0
-1,0.36,1,0,0,0.4450,1
 1,0.27,0,1,0,0.2860,2
 1,0.50,0,1,0,0.5650,1
 1,0.50,0,0,1,0.5500,1
-1,0.19,0,0,1,0.3270,0
 1,0.22,0,1,0,0.2770,1
-1,0.39,0,0,1,0.4710,2
 1,0.34,1,0,0,0.3940,1
-1,0.22,1,0,0,0.3350,0
1,0.35,0,0,1,0.3520,2
-1,0.33,0,1,0,0.4640,1
1,0.45,0,1,0,0.5410,1
1,0.42,0,1,0,0.5070,1
-1,0.33,0,1,0,0.4680,1
1,0.25,0,0,1,0.3000,1
-1,0.31,0,1,0,0.4640,0
1,0.27,1,0,0,0.3250,2
1,0.48,1,0,0,0.5400,1
-1,0.64,0,1,0,0.7130,2
1,0.61,0,1,0,0.7240,0
1,0.54,0,0,1,0.6100,0
1,0.29,1,0,0,0.3630,0
1,0.50,0,0,1,0.5500,1
1,0.55,0,0,1,0.6250,0
1,0.40,1,0,0,0.5240,0
1,0.22,1,0,0,0.2360,2
1,0.68,0,1,0,0.7840,0
-1,0.60,1,0,0,0.7170,2
-1,0.34,0,0,1,0.4650,1
-1,0.25,0,0,1,0.3710,0
-1,0.31,0,1,0,0.4890,1
1,0.43,0,0,1,0.4800,1
1,0.58,0,1,0,0.6540,2
-1,0.55,0,1,0,0.6070,2
-1,0.43,0,1,0,0.5110,1
-1,0.43,0,0,1,0.5320,1
-1,0.21,1,0,0,0.3720,0
1,0.55,0,0,1,0.6460,0
1,0.64,0,1,0,0.7480,0
-1,0.41,1,0,0,0.5880,1
1,0.64,0,0,1,0.7270,0
-1,0.56,0,0,1,0.6660,2
1,0.31,0,0,1,0.3600,1
-1,0.65,0,0,1,0.7010,2
1,0.55,0,0,1,0.6430,0
-1,0.25,1,0,0,0.4030,0
1,0.46,0,0,1,0.5100,1
-1,0.36,1,0,0,0.5350,0
1,0.52,0,1,0,0.5810,1
1,0.61,0,0,1,0.6790,0
1,0.57,0,0,1,0.6570,0
-1,0.46,0,1,0,0.5260,1
-1,0.62,1,0,0,0.6680,2
1,0.55,0,0,1,0.6270,0
-1,0.22,0,0,1,0.2770,1
-1,0.50,1,0,0,0.6290,0
-1,0.32,0,1,0,0.4180,1
-1,0.21,0,0,1,0.3560,0
1,0.44,0,1,0,0.5200,1
1,0.46,0,1,0,0.5170,1
1,0.62,0,1,0,0.6970,0
1,0.57,0,1,0,0.6640,0
-1,0.67,0,0,1,0.7580,2
1,0.29,1,0,0,0.3430,2
1,0.53,1,0,0,0.6010,0
-1,0.44,1,0,0,0.5480,1
1,0.46,0,1,0,0.5230,1
-1,0.20,0,1,0,0.3010,1
-1,0.38,1,0,0,0.5350,1
1,0.50,0,1,0,0.5860,1
1,0.33,0,1,0,0.4250,1
-1,0.33,0,1,0,0.3930,1
1,0.26,0,1,0,0.4040,0
1,0.58,1,0,0,0.7070,0
1,0.43,0,0,1,0.4800,1
-1,0.46,1,0,0,0.6440,0
1,0.60,1,0,0,0.7170,0
-1,0.42,1,0,0,0.4890,1
-1,0.56,0,0,1,0.5640,2
-1,0.62,0,1,0,0.6630,2
-1,0.50,1,0,0,0.6480,1
1,0.47,0,0,1,0.5200,1
-1,0.67,0,1,0,0.8040,2
-1,0.40,0,0,1,0.5040,1
1,0.42,0,1,0,0.4840,1
1,0.64,1,0,0,0.7200,0
-1,0.47,1,0,0,0.5870,2
1,0.45,0,1,0,0.5280,1
-1,0.25,0,0,1,0.4090,0
1,0.38,1,0,0,0.4840,0
1,0.55,0,0,1,0.6000,1
-1,0.44,1,0,0,0.6060,1
1,0.33,1,0,0,0.4100,1
1,0.34,0,0,1,0.3900,1
1,0.27,0,1,0,0.3370,2
1,0.32,0,1,0,0.4070,1
1,0.42,0,0,1,0.4700,1
-1,0.24,0,0,1,0.4030,0
1,0.42,0,1,0,0.5030,1
1,0.25,0,0,1,0.2800,2
1,0.51,0,1,0,0.5800,1
-1,0.55,0,1,0,0.6350,2
1,0.44,1,0,0,0.4780,2
-1,0.18,1,0,0,0.3980,0
-1,0.67,0,1,0,0.7160,2
1,0.45,0,0,1,0.5000,1
1,0.48,1,0,0,0.5580,1
-1,0.25,0,1,0,0.3900,1
-1,0.67,1,0,0,0.7830,1
1,0.37,0,0,1,0.4200,1
-1,0.32,1,0,0,0.4270,1
1,0.48,1,0,0,0.5700,1
-1,0.66,0,0,1,0.7500,2
1,0.61,1,0,0,0.7000,0
-1,0.58,0,0,1,0.6890,1
1,0.19,1,0,0,0.2400,2
1,0.38,0,0,1,0.4300,1
-1,0.27,1,0,0,0.3640,1
1,0.42,1,0,0,0.4800,1
1,0.60,1,0,0,0.7130,0
-1,0.27,0,0,1,0.3480,0
1,0.29,0,1,0,0.3710,0
-1,0.43,1,0,0,0.5670,1
1,0.48,1,0,0,0.5670,1
1,0.27,0,0,1,0.2940,2
-1,0.44,1,0,0,0.5520,0
1,0.23,0,1,0,0.2630,2
-1,0.36,0,1,0,0.5300,2
1,0.64,0,0,1,0.7250,0
1,0.29,0,0,1,0.3000,2
-1,0.33,1,0,0,0.4930,1
-1,0.66,0,1,0,0.7500,2
-1,0.21,0,0,1,0.3430,0
1,0.27,1,0,0,0.3270,2
1,0.29,1,0,0,0.3180,2
-1,0.31,1,0,0,0.4860,1
1,0.36,0,0,1,0.4100,1
1,0.49,0,1,0,0.5570,1
-1,0.28,1,0,0,0.3840,0
-1,0.43,0,0,1,0.5660,1
-1,0.46,0,1,0,0.5880,1
1,0.57,1,0,0,0.6980,0
-1,0.52,0,0,1,0.5940,1
-1,0.31,0,0,1,0.4350,1
-1,0.55,1,0,0,0.6200,2
1,0.50,1,0,0,0.5640,1
1,0.48,0,1,0,0.5590,1
-1,0.22,0,0,1,0.3450,0
1,0.59,0,0,1,0.6670,0
1,0.34,1,0,0,0.4280,2
-1,0.64,1,0,0,0.7720,2
1,0.29,0,0,1,0.3350,2
-1,0.34,0,1,0,0.4320,1
-1,0.61,1,0,0,0.7500,2
1,0.64,0,0,1,0.7110,0
-1,0.29,1,0,0,0.4130,0
1,0.63,0,1,0,0.7060,0
-1,0.29,0,1,0,0.4000,0
-1,0.51,1,0,0,0.6270,1
-1,0.24,0,0,1,0.3770,0
1,0.48,0,1,0,0.5750,1
1,0.18,1,0,0,0.2740,0
1,0.18,1,0,0,0.2030,2
1,0.33,0,1,0,0.3820,2
-1,0.20,0,0,1,0.3480,0
1,0.29,0,0,1,0.3300,2
-1,0.44,0,0,1,0.6300,0
-1,0.65,0,0,1,0.8180,0
-1,0.56,1,0,0,0.6370,2
-1,0.52,0,0,1,0.5840,1
-1,0.29,0,1,0,0.4860,0
-1,0.47,0,1,0,0.5890,1
1,0.68,1,0,0,0.7260,2
1,0.31,0,0,1,0.3600,1
1,0.61,0,1,0,0.6250,2
1,0.19,0,1,0,0.2150,2
1,0.38,0,0,1,0.4300,1
-1,0.26,1,0,0,0.4230,0
1,0.61,0,1,0,0.6740,0
1,0.40,1,0,0,0.4650,1
-1,0.49,1,0,0,0.6520,1
1,0.56,1,0,0,0.6750,0
-1,0.48,0,1,0,0.6600,1
1,0.52,1,0,0,0.5630,2
-1,0.18,1,0,0,0.2980,0
-1,0.56,0,0,1,0.5930,2
-1,0.52,0,1,0,0.6440,1
-1,0.18,0,1,0,0.2860,1
-1,0.58,1,0,0,0.6620,2
-1,0.39,0,1,0,0.5510,1
-1,0.46,1,0,0,0.6290,1
-1,0.40,0,1,0,0.4620,1
-1,0.60,1,0,0,0.7270,2
1,0.36,0,1,0,0.4070,2
1,0.44,1,0,0,0.5230,1
1,0.28,1,0,0,0.3130,2
1,0.54,0,0,1,0.6260,0

Test data:

-1,0.51,1,0,0,0.6120,1
-1,0.32,0,1,0,0.4610,1
1,0.55,1,0,0,0.6270,0
1,0.25,0,0,1,0.2620,2
1,0.33,0,0,1,0.3730,2
-1,0.29,0,1,0,0.4620,0
1,0.65,1,0,0,0.7270,0
-1,0.43,0,1,0,0.5140,1
-1,0.54,0,1,0,0.6480,2
1,0.61,0,1,0,0.7270,0
1,0.52,0,1,0,0.6360,0
1,0.30,0,1,0,0.3350,2
1,0.29,1,0,0,0.3140,2
-1,0.47,0,0,1,0.5940,1
1,0.39,0,1,0,0.4780,1
1,0.47,0,0,1,0.5200,1
-1,0.49,1,0,0,0.5860,1
-1,0.63,0,0,1,0.6740,2
-1,0.30,1,0,0,0.3920,0
-1,0.61,0,0,1,0.6960,2
-1,0.47,0,0,1,0.5870,1
1,0.30,0,0,1,0.3450,2
-1,0.51,0,0,1,0.5800,1
-1,0.24,1,0,0,0.3880,1
-1,0.49,1,0,0,0.6450,1
1,0.66,0,0,1,0.7450,0
-1,0.65,1,0,0,0.7690,0
-1,0.46,0,1,0,0.5800,0
-1,0.45,0,0,1,0.5180,1
-1,0.47,1,0,0,0.6360,0
-1,0.29,1,0,0,0.4480,0
-1,0.57,0,0,1,0.6930,2
-1,0.20,1,0,0,0.2870,2
-1,0.35,1,0,0,0.4340,1
-1,0.61,0,0,1,0.6700,2
-1,0.31,0,0,1,0.3730,1
1,0.18,1,0,0,0.2080,2
1,0.26,0,0,1,0.2920,2
-1,0.28,1,0,0,0.3640,2
-1,0.59,0,0,1,0.6940,2
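
The data listings above use commas, while the np.loadtxt() calls in the demo program specify delimiter="\t". If modifying the program is preferred over editing the data, the change amounts to the delimiter argument. A minimal sketch, using a short in-memory excerpt in place of a file:

```python
import io
import numpy as np

# A few lines in the comma-delimited format shown in this post.
sample = """# people data (excerpt)
1,0.24,1,0,0,0.2950,2
-1,0.39,0,0,1,0.5120,1
1,0.63,0,1,0,0.7580,0
"""

# Same call as the demo except for delimiter=","; a file path can be
# used in place of the StringIO object.
xy = np.loadtxt(io.StringIO(sample), usecols=range(0,7),
  delimiter=",", comments="#", dtype=np.float32)
x = xy[:,0:6]
y = xy[:,6].astype(int)
print(x)
print(y)
```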