PyTorch Not-Fully-Connected Layer Using prune.custom_from_mask()

I ran across an interesting PyTorch function that I hadn’t seen before. The torch.nn.utils.prune.custom_from_mask() function can mask out weights and biases in a neural layer. This allows you to create layers that are not fully connected.

I checked the PyTorch documentation, and sadly, as usual, it wasn’t much help:

Prunes tensor corresponding to parameter called name in module
by applying the pre-computed mask in mask. Modifies module in
place (and also return the modified module) by:

1.) adding a named buffer called name+'_mask' corresponding
to the binary mask applied to the parameter name by the
pruning method.

2.) replacing the parameter name by its pruned version, while
the original (unpruned) parameter is stored in a new parameter
named name+'_orig'.
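
To unpack that: after pruning, the module keeps the original weights in a new parameter called weight_orig, stores the mask in a buffer called weight_mask, and recomputes the effective weight as weight_orig * weight_mask on every forward pass. Here is a minimal sketch (my own, not from the docs) on a tiny Linear layer:

import torch as T
import torch.nn.utils.prune

lin = T.nn.Linear(3, 2)
msk = T.tensor([[1, 1, 0],
                [1, 1, 1]], dtype=T.float32)  # zero out weight [0][2]
T.nn.utils.prune.custom_from_mask(lin, name='weight', mask=msk)

print([n for (n, _) in lin.named_buffers()])     # ['weight_mask']
print([n for (n, _) in lin.named_parameters()])  # ['bias', 'weight_orig']
print(lin.weight)  # weight_orig * weight_mask; entry [0][2] is 0.0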

So I decided to experiment. I started with one of my standard multi-class classification examples. The goal is to predict employee job-type (mgmt, supp, tech) from sex, age, city (one of three), and income. My starting network was 6-(10-10)-3.

I modified the network to delete the weight from input node [0] (the sex node) to hidden1 layer node [1], and also the bias to hidden1 layer node [1]. Note that a Linear(6, 10) layer stores its weights as a (10, 6) matrix, indexed [to, from], so the mask entry to zero out is [1][0].

The key code is:


class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    # prune
    print("\nMasking hid1 from node 0 to node 1 ")
    msk_wts = T.ones((10,6),
      dtype=T.float32).to(device) # [to, from]
    msk_wts[1][0] = 0  # to [1] from [0]
    T.nn.utils.prune.custom_from_mask(self.hid1,
      name='weight', mask=msk_wts)

    msk_bias = T.tensor([1,0,1,1,1,1,1,1,1,1],
      dtype=T.float32).to(device)
    T.nn.utils.prune.custom_from_mask(self.hid1,
      name='bias', mask=msk_bias)
 
    # default init

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = self.oupt(z)  # no softmax: CrossEntropyLoss() 
    return z
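
One point worth emphasizing: the masked connection stays zero during training. Gradient updates are applied to hid1.weight_orig, but a forward pre-hook re-applies the mask on every forward pass, so the effective weight never changes from zero. A quick sanity check (my addition, assuming the Net class above):

net = Net().to(device)
# ... train as usual ...
print(net.hid1.weight[1][0].item())  # 0.0 -- the masked weight
print(net.hid1.bias[1].item())       # 0.0 -- the masked bias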

Like most PyTorch topics, the idea is relatively simple on the surface, but there are many complex ideas lurking under the covers.
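
For example, if you want to make the pruning permanent after training, my understanding is that you can call prune.remove(), which folds weight_orig * weight_mask into an ordinary weight parameter and deletes the mask buffer, the _orig parameter, and the forward pre-hook:

T.nn.utils.prune.remove(net.hid1, 'weight')
T.nn.utils.prune.remove(net.hid1, 'bias')
# hid1.weight and hid1.bias are plain Parameters again; the pruned
# entries are zero, but nothing keeps them zero if training resumes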

Interesting and good fun.



PyTorch masks are used quite often. Masks in fantasy movies are also common. Three of my favorite fantasy movies feature unusual masks with a sort of Asian theme. Left: In “The Fall” (2006), Evelyn (actress Justine Waddell) is saved by the hero. Center: Miao Yin (played by actress Suzee Pai) is menaced by an evil wizard in “Big Trouble in Little China” (1986). Right: Princess Su Lin (actress Ni Ni) in “Enter the Warriors Gate” (2016).


Demo code. The training and test data can be found at jamesmccaffrey.wordpress.com/2022/04/29/predicting-employee-job-type-using-pytorch-1-10-on-windows-11/

# employee_job_prune.py
# predict job type from sex, age, city, income
# PyTorch 1.10.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11 

# explore T.nn.utils.prune.custom_from_mask()

import numpy as np
import time
import torch as T
import torch.nn.utils.prune
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class EmployeeDataset(T.utils.data.Dataset):
  # sex  age    city      income  job-type
  # -1   0.27   0  1  0   0.7610   2
  # +1   0.19   0  0  1   0.6550   0
  # sex: -1 = male, +1 = female
  # city: anaheim, boulder, concord
  # job type: mgmt, supp, tech

  def __init__(self, src_file, num_rows=None):
    all_xy = np.loadtxt(src_file, max_rows=num_rows,
      usecols=range(0,7), delimiter="\t", comments="#",
      dtype=np.float32)
    tmp_x = all_xy[0:num_rows,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[0:num_rows,6]  # 1-D
    
    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    # prune
    print("\nMasking hid1 from node 0 to node 1 ")
    msk_wts = T.ones((10,6),
      dtype=T.float32).to(device) # [to, from]
    msk_wts[1][0] = 0  # to [1] from [0]
    T.nn.utils.prune.custom_from_mask(self.hid1,
      name='weight', mask=msk_wts)

    msk_bias = T.tensor([1,0,1,1,1,1,1,1,1,1],
      dtype=T.float32).to(device)
    T.nn.utils.prune.custom_from_mask(self.hid1,
      name='bias', mask=msk_bias)
 
    # default init

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = self.oupt(z)  # no softmax: CrossEntropyLoss() 
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0]
    Y = ds[i][1]  # 0, 1, or 2
    with T.no_grad():
      oupt = model(X)  # logits form

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin Employee predict job type pruning demo")
  T.manual_seed(1)
  np.random.seed(1)
  
  # 1. create DataLoader objects
  print("\nCreating Employee Datasets ")

  train_file = ".\\Data\\employee_train.txt"
  train_ds = EmployeeDataset(train_file)  # all 200 rows

  test_file = ".\\Data\\employee_test.txt"
  test_ds = EmployeeDataset(test_file)  # all 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

# -----------------------------------------------------------

  # 2. create network
  print("\nCreating 6-(10-10)-3 pruned neural network ")
  net = Net().to(device)

# -----------------------------------------------------------

  # 3. train model
  max_epochs = 1000
  ep_log_interval = 100
  lrn_rate = 0.01

  loss_func = T.nn.CrossEntropyLoss()  # applies log-softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training")
  net.train()  # or net = net.train()
  for epoch in range(0, max_epochs):
    T.manual_seed(epoch+1)  # checkpoint reproducibility
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # inputs
      Y = batch[1]     # correct class/label/job

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %5d  |  loss = %10.4f" % \
        (epoch, epoch_loss))
  print("Done ")

  # print(net.hid1.weight)  # one wt is 0
  # print(net.hid1.bias)    # corresponding bias is 0

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  # 5. make a prediction
  print("\nPredicting job for M  30  concord  $50,000: ")
  X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device) 

  with T.no_grad():
    logits = net(X)  # do not sum to 1.0
  probs = T.softmax(logits, dim=1)  # tensor
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True)
  print(probs)

  # 6. save model (state_dict approach)
  print("\nSaving trained model state")
  # fn = ".\\Models\\employee_model.pth"
  # T.save(net.state_dict(), fn)

  print("\nEnd Employee predict job pruning demo")

if __name__ == "__main__":
  main()