Neural Binary Classification Using PyTorch

Of the neural network code libraries I use most often (TensorFlow, Keras, CNTK, PyTorch), PyTorch is by far the least mature. The Windows version of PyTorch was released only a few weeks ago.

As a result, there are almost no good PyTorch examples available, and learning PyTorch is a slow process. I took a big step forward recently when I created a binary classifier using PyTorch. The code was surprisingly difficult, with many tricky details.

My demo program uses the Banknote Authentication dataset. The goal is to use four predictor variables from digital images of the banknotes and predict which are authentic and which are forgeries.

Because PyTorch operates at a very low level, there are a huge number of design decisions to make. My demo uses a 4-(8-8)-1 deep neural network: four input nodes, two hidden layers of eight nodes each with tanh activation, and a single output node with sigmoid activation, which is the standard choice for binary classification. I used explicit Glorot initialization on all weights and initialized all biases to zero.
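The architecture described above can be sketched in PyTorch roughly as follows. This is a minimal sketch written against a current PyTorch release, not the demo's actual code; the class name `Net` and the layer names `hid1`, `hid2`, `oupt` are my own illustrative choices. PyTorch calls Glorot initialization "Xavier" initialization, so the sketch uses `torch.nn.init.xavier_uniform_`; choosing the uniform variant over `xavier_normal_` is an assumption.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    """4-(8-8)-1 binary classifier: tanh hidden layers, sigmoid output.

    Hypothetical names; a sketch of the architecture, not the demo code.
    """
    def __init__(self):
        super().__init__()
        self.hid1 = nn.Linear(4, 8)   # 4 predictor values -> 8 hidden nodes
        self.hid2 = nn.Linear(8, 8)   # 8 hidden -> 8 hidden
        self.oupt = nn.Linear(8, 1)   # 8 hidden -> single output node
        # Explicit Glorot (Xavier) uniform init on weights, zeros on biases.
        for layer in (self.hid1, self.hid2, self.oupt):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        z = torch.tanh(self.hid1(x))
        z = torch.tanh(self.hid2(z))
        # Sigmoid squashes the output into (0, 1), read as p(class = 1).
        return torch.sigmoid(self.oupt(z))
```

Because the output node already applies sigmoid, the matching loss is `torch.nn.BCELoss`. If you instead return the raw output of `oupt`, `torch.nn.BCEWithLogitsLoss` folds the sigmoid into the loss and is more numerically stable.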

PyTorch currently doesn’t have any built-in classification accuracy functions, so I wrote my own. And there’s no built-in mechanism to generate training mini-batches, so I wrote a custom class to do that.
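Here is one way the two helpers could look. The names `Batcher` and `accuracy`, the 0.5 decision threshold, and the reshuffle-on-every-pass behavior are assumptions for illustration, not necessarily what the demo does. For comparison, current PyTorch releases also provide `torch.utils.data.DataLoader`, which, combined with a `TensorDataset` and `shuffle=True`, yields shuffled mini-batches without a custom class.

```python
import torch

class Batcher:
    """Yields shuffled mini-batches of item indices as int64 tensors.

    Hypothetical helper: a fresh permutation is drawn on every pass
    (i.e., every epoch), and the last batch may be smaller than batch_size.
    """
    def __init__(self, num_items, batch_size, seed=0):
        self.num_items = num_items
        self.batch_size = batch_size
        self.gen = torch.Generator().manual_seed(seed)  # reproducible shuffles

    def __iter__(self):
        perm = torch.randperm(self.num_items, generator=self.gen)
        for start in range(0, self.num_items, self.batch_size):
            yield perm[start:start + self.batch_size]


def accuracy(model, x, y):
    """Fraction of items whose thresholded output matches the 0/1 target.

    Assumes model(x) returns probabilities of shape (n, 1) and y is a
    float tensor of shape (n, 1); an output of 0.5 or more counts as class 1.
    """
    with torch.no_grad():  # evaluation only, no gradient tracking
        preds = (model(x) >= 0.5).float()
    return (preds == y).float().mean().item()
```

In a training loop you would write `for idx in Batcher(len(train_x), 16):` and select a batch with `train_x[idx]` and `train_y[idx]`. Keeping the targets as an (n, 1) tensor matters for `accuracy`: comparing an (n, 1) prediction tensor against an (n,) target tensor would silently broadcast to an (n, n) result.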

It’s clear that PyTorch is very immature and will change greatly over the next year or so. There’s a strong temptation to just wait until PyTorch stabilizes. But I know from previous experience that it’s better to man up, dive in now, and learn as much as possible, even at the expense of a lot of extra effort.

In a weird way, struggling to get models created using PyTorch is fun in spite of the intellectual pain. If you’re a software guy like me, you know exactly what I’m talking about. And if you’re not a software guy, I’ll bet there’s a similarly difficult activity you’re passionate about, one whose challenge you enjoy.



The definition of “manly” is “having qualities such as strength and courage that are expected in a man”. But a manly approach to software development may not always be optimal.

This entry was posted in Machine Learning, PyTorch.