An Example Of The skorch Library For PyTorch

Writing deep neural network code using the PyTorch library is quite difficult. The skorch (“scikit” + “torch”) library provides wrapper code over PyTorch that is intended to make using PyTorch easier. The skorch library gives an interface like, and access to, the popular scikit-learn library.

I experimented with skorch for a full day. I wanted to like it. But . . . my conclusion is that the convenience gained by using skorch is far outweighed by the loss of control. The skorch library might appeal to hard-core scikit-learn users but for people like me, who need full control over neural networks, skorch has no advantage.




Top: Example of multi-class classification using standard PyTorch. Bottom: Same data, but using the skorch wrapper library over PyTorch.

The only way to grasp the few pros and many cons of the skorch library (for my scenarios) is to closely examine a code example. See below. And this in itself is one of the cons of using skorch: it has a moderately steep learning curve, and the time spent climbing it could be better spent learning the nuances of PyTorch.

The skorch library allows you to skip creating a PyTorch Dataset object, but that's a very small gain. To create a neural network for skorch, you define a PyTorch class as usual. Then you pass the PyTorch net definition to a skorch NeuralNetClassifier class, along with information needed to train the network, such as loss function, max_epochs, learning rate, batch_size, and so on. With a regular PyTorch neural network you must define your own train() method. But using a skorch approach you can just call a fit() method.
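
To make the contrast concrete, below is a minimal sketch of the kind of hand-written training loop that skorch's fit() replaces. It assumes the same 4-7-3 Net class and the X, y training arrays used in the full demo farther down; the SGD optimizer and the log-every-3-epochs choice are just illustrative, not part of the skorch demo.

def train(net, X, y, max_epochs=12, lr=0.05, bat_size=4):
  # wrap the NumPy training arrays in a Dataset and DataLoader
  ds = T.utils.data.TensorDataset(T.tensor(X), T.tensor(y))
  ldr = T.utils.data.DataLoader(ds, batch_size=bat_size,
    shuffle=True)                  # you decide how items are shuffled
  loss_func = T.nn.CrossEntropyLoss()
  opt = T.optim.SGD(net.parameters(), lr=lr)
  for epoch in range(max_epochs):
    ep_loss = 0.0
    for (bx, by) in ldr:
      opt.zero_grad()
      loss = loss_func(net(bx), by)  # forward pass, compute loss
      ep_loss += loss.item()
      loss.backward()                # backward pass
      opt.step()                     # update weights
    if epoch % 3 == 0:               # you decide how often to log
      print("epoch = %4d   loss = %0.4f" % (epoch, ep_loss))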

Over the course of my experimentation, I gained dozens of insights into skorch. It would take many pages to explain all these details so I won't try because the bottom line is that ultimately I was not impressed with skorch. In my scenarios I need complete control over my neural network code. The skorch library takes away the flexibility I need. For example, I wanted to adjust the way the training code in the fit() shuffles data items — nope, no easy way. I wanted to just print training progress every so often instead of on every epoch — nope, no easy way. And so on.
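
For what it's worth, as best I can tell from the skorch documentation, the closest equivalents go through skorch-specific plumbing rather than a loop you can edit: double-underscore parameters that get routed to the training DataLoader, and callback classes that read from the history object. A hedged sketch (the EveryNEpochs name is mine, not a skorch class):

from skorch.callbacks import Callback

class EveryNEpochs(Callback):
  # hypothetical helper: print progress only every n-th epoch
  def __init__(self, n=3):
    self.n = n
  def on_epoch_end(self, net, **kwargs):
    epoch = net.history[-1, 'epoch']
    if epoch % self.n == 0:
      print("epoch = %4d   train loss = %0.4f" % \
        (epoch, net.history[-1, 'train_loss']))

net = NeuralNetClassifier(module=Net, max_epochs=12,
  criterion=T.nn.CrossEntropyLoss, lr=0.05, device=dvc,
  batch_size=4, verbose=0,
  iterator_train__shuffle=True,   # routed to the training DataLoader
  callbacks=[EveryNEpochs(n=3)])

Even if this works, it illustrates my point: you end up learning skorch's parameter-routing and callback conventions instead of simply editing a training loop you wrote yourself.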

One positive use case I can see for skorch is as a transition bridge for scikit users (often beginners) who aren't familiar with neural networks.



Clothes made from bubble wrap. Hmmm, not sure about this idea.


Demo code below.

# iris_skorch.py
# skorch library version of Iris

import numpy as np
import torch as T
from skorch import NeuralNetClassifier
dvc = T.device("cpu")

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(4, 7)  # 4-7-3
    self.oupt = T.nn.Linear(7, 3)

    # skorch documentation recommends not initializing here (!?)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = self.oupt(z)  # no softmax: CrossEntropyLoss() 
    return z

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin Iris with PyTorch 1.9 and skorch 0.10 \n")
  T.manual_seed(1)
  np.random.seed(1)
  # np.set_printoptions(precision=4, suppress=True, sign=" ")

  # 1. load training data
  print("\nLoading Iris training data ")
  X = np.loadtxt(".\\Data\\iris_train.txt", usecols=range(0,4), 
    delimiter=",", dtype=np.float32)
  y = np.loadtxt(".\\Data\\iris_train.txt", usecols=4, 
    delimiter=",", dtype=np.int64)

  # 2. create net
  print("\nCreating skorch style neural net ")
  net = NeuralNetClassifier(module=Net, max_epochs=12,
    criterion=T.nn.CrossEntropyLoss,  lr=0.05, device=dvc,
    batch_size=4, verbose=1)

  # 3. train
  print("\nStarting training ")
  net.fit(X, y)  # scikit style. does fit() shuffle ??
  print("Training complete ")

  # 4. compute prediction accuracy on train data
  acc = net.score(X, y)
  print("\nAccuracy on train data = %0.4f " % acc)

  # 5. make a prediction
  np.set_printoptions(formatter={'float': '{: 0.1f}'.format})
  print("\nPredicting species for [6.1, 3.1, 5.1, 1.1]: ")
  x = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)

  np.set_printoptions(formatter={'float': '{: 0.4f}'.format})
  y_pred = net.predict_proba(x)
  print("\nPrediction: ")
  print(y_pred)

  print("\nEnd Iris with skorch demo")

if __name__ == "__main__":
  main()