Using a Variational Autoencoder for Dimensionality Reduction to Make a Visualization

One morning, I had just written a PyTorch program that used a neural autoencoder to reduce 28-by-28 MNIST digit images from 784 dimensions down to 2 dimensions, so that each image could be plotted on an xy graph. It was an interesting experiment. Since I had a lot of the data wrangling code handy, I figured I'd try using a variational autoencoder (VAE) for dimensionality reduction and visualization. This was a stretch because VAEs are designed to generate synthetic data, not to reduce dimensionality.

The bottom line is that I don’t think the idea worked very well — a VAE does not appear to be well suited for dimensionality reduction for visualization. But like many things related to deep neural networks, there were as many questions raised as there were questions answered.

For my source data, I used the first 10,000 of the 60,000 MNIST training images, which I had converted to a tab-delimited text file.
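The MNIST_Dataset class in the program expects one image per line: 784 pixel values (0 to 255), then the digit label, all tab-delimited. The conversion from the raw MNIST files isn't shown here. As a minimal sketch of the file layout only, the hypothetical helpers write_mnist_text() and read_mnist_text() below write arrays in that format and read them back with the same np.loadtxt() calls the class uses. The data is random stand-in values, not real MNIST images.

```python
import os
import tempfile
import numpy as np

def write_mnist_text(pixels, labels, dest_file):
  # hypothetical helper: pixels is (n, 784) ints in [0, 255], labels is (n,) ints in [0, 9]
  # writes one image per line: 784 pixel values then the label, tab-delimited
  rows = np.hstack([pixels.astype(np.int64),
    labels.reshape(-1, 1).astype(np.int64)])
  np.savetxt(dest_file, rows, fmt="%d", delimiter="\t")

def read_mnist_text(src_file):
  # hypothetical helper: mirrors the np.loadtxt() calls in MNIST_Dataset
  x = np.loadtxt(src_file, usecols=range(0, 784),
    delimiter="\t", comments="#", dtype=np.float32)
  y = np.loadtxt(src_file, usecols=[784],
    delimiter="\t", comments="#", dtype=np.int64)
  return x / 255.0, y  # normalize pixels, keep labels as ints

rng = np.random.default_rng(0)
fake_pixels = rng.integers(0, 256, size=(5, 784))  # random stand-in "images"
fake_labels = rng.integers(0, 10, size=5)
fn = os.path.join(tempfile.gettempdir(), "mnist_sample.txt")
write_mnist_text(fake_pixels, fake_labels, fn)
x, y = read_mnist_text(fn)
print(x.shape, y.shape)  # (5, 784) (5,)
```

Note that np.loadtxt() returns a 1-D array when a file has only one line, so this layout assumes at least two images per file.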

My demo VAE had a 784-400-[2,2]-2-400-784 architecture. In a preliminary attempt, I reduced each image down to [1,1], where the first value was the mean of the data distribution and the second value was its log-variance. I soon realized this approach didn't make sense because I'd be plotting a mean against a log-variance. So I increased the core representation to [2,2], where the first component is a mean vector with two values and the second component is a log-variance vector with two values. To graph the reduced form of each MNIST image, I used just the two mean values.
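The forward() method in the program turns the [2,2] core into a single 2-value latent vector with the reparameterization trick: z = u + noise * exp(0.5 * logvar). The NumPy sketch below repeats that arithmetic (the function name sample_latent and the example values are illustrative assumptions) to show why the mean is the natural value to plot. A single sample z mixes the mean with random noise, while u is the deterministic center of each image's distribution.

```python
import numpy as np

def sample_latent(u, logvar, rng):
  # same arithmetic as VAE.forward():
  # stdev = exp(0.5 * logvar), z = u + noise * stdev, noise ~ N(0, 1)
  stdev = np.exp(0.5 * logvar)
  noise = rng.standard_normal(size=u.shape)
  return u + noise * stdev

rng = np.random.default_rng(1)
u = np.array([0.8, -1.5])   # hypothetical encoder mean for one image
tight = np.full(2, -10.0)   # log-variance -10 gives stdev of about 0.0067
loose = np.zeros(2)         # log-variance 0 gives stdev of 1

z_tight = np.array([sample_latent(u, tight, rng) for _ in range(1000)])
z_loose = np.array([sample_latent(u, loose, rng) for _ in range(1000)])
print(z_tight.std(axis=0))  # tiny: samples sit almost exactly on u
print(z_loose.std(axis=0))  # close to 1: samples scatter widely around u
```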

VAEs are tricky. They use a custom error/loss function. All the examples I've seen on the Internet use binary cross entropy loss plus Kullback-Leibler divergence. I'm skeptical of the binary cross entropy component, so I used mean squared error loss instead.
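Here is a NumPy sketch of the same two-part loss that cus_loss_func() computes in the program: mean squared error plus the closed-form KL divergence between each image's N(u, exp(logvar)) and a standard normal. The function names are illustrative, not from the original code. One detail worth noticing: PyTorch's mse_loss() defaults to reduction="mean", so the reconstruction term is averaged over every pixel in the batch, while the KLD term is summed over the batch and both latent dimensions. With BETA = 1.0 that weights the KLD term much more heavily than the reconstruction term; BETA is the knob that trades the two off.

```python
import numpy as np

def kl_to_standard_normal(u, logvar):
  # KL( N(u, exp(logvar)) || N(0, I) ), summed over batch and latent dims,
  # matching kld in cus_loss_func(): -0.5 * sum(1 + logvar - u^2 - exp(logvar))
  return -0.5 * np.sum(1.0 + logvar - u**2 - np.exp(logvar))

def mse(recon_x, x):
  # averaged over every element, like mse_loss() with its default reduction="mean"
  return np.mean((recon_x - x) ** 2)

def vae_loss(recon_x, x, u, logvar, beta=1.0):
  # reconstruction term plus beta-weighted KL term
  return mse(recon_x, x) + beta * kl_to_standard_normal(u, logvar)

rng = np.random.default_rng(2)
x = rng.random((10, 784))  # stand-in for a batch of 10 images in [0, 1]
recon = np.clip(x + 0.1 * rng.standard_normal(x.shape), 0.0, 1.0)
u = rng.standard_normal((10, 2))
logvar = 0.1 * rng.standard_normal((10, 2))
print(mse(recon, x), kl_to_standard_normal(u, logvar))
```

The KL term is zero only when every u is 0 and every log-variance is 0, which is what pulls all the encoded images toward the origin.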

The resulting visualization concentrated all of the images in one giant cluster, but it did not produce sub-clusters for the ten digit types. This is good for generating synthetic data, because a random pair of latent values will likely land near some image's representation and so probably produce a realistic-looking synthetic image. But the lack of sub-clusters is bad for visualization because no patterns emerge.
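One rough way to put a number on "no sub-clusters" is to compare how far apart the per-digit centroids are with how spread out each digit's points are. The separation_ratio() helper below is a hypothetical illustration, not part of the original program; it could be applied to the u values and labels the program computes. The synthetic data here only demonstrates the two extremes: well-separated groups give a large ratio, and one overlapping blob gives a small one.

```python
import numpy as np

def separation_ratio(points, labels):
  # mean distance between class centroids divided by
  # mean distance of points to their own class centroid
  classes = np.unique(labels)
  centroids = np.array([points[labels == c].mean(axis=0) for c in classes])
  within = np.mean([np.linalg.norm(points[labels == c] - centroids[i], axis=1).mean()
    for i, c in enumerate(classes)])
  dists = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=2)
  k = len(classes)
  between = dists.sum() / (k * (k - 1))  # average over off-diagonal pairs
  return between / within

rng = np.random.default_rng(3)
labels = np.repeat(np.arange(10), 100)  # ten "digits", 100 points each
angles = 2 * np.pi * labels / 10
centers = 10.0 * np.column_stack([np.cos(angles), np.sin(angles)])
separated = centers + rng.standard_normal((1000, 2))  # distinct sub-clusters
one_blob = rng.standard_normal((1000, 2))             # one giant cluster
print(separation_ratio(separated, labels))  # large
print(separation_ratio(one_blob, labels))   # small
```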

Anyway, it was an interesting experiment. In the back of my mind, I'm thinking about using a VAE for anomaly detection. Regular autoencoders are quite good at anomaly detection if you use reconstruction error. But regular autoencoders tend to overfit data, so you get lots of false positive detections. VAEs tend to underfit, as this experiment showed, so a VAE anomaly detection system would likely give false negatives. My idea is to chain together a regular autoencoder with a VAE for anomaly detection. But I don't have any of the details of the proposed model worked out in my mind.
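The reconstruction-error approach for a regular autoencoder can be sketched without any model: run each item through the trained network, measure how badly it is reconstructed, and flag the worst ones. This minimal NumPy sketch assumes the reconstructions are already available (here they are simulated with noise), and the helper names and the 99th-percentile cutoff are illustrative choices. It does not implement the proposed autoencoder-plus-VAE chain, which is still unspecified.

```python
import numpy as np

def reconstruction_errors(x, recon_x):
  # per-item mean squared error between each input and its reconstruction
  return np.mean((x - recon_x) ** 2, axis=1)

def flag_anomalies(errors, pct=99.0):
  # flag items whose error is above the given percentile of all errors
  threshold = np.percentile(errors, pct)
  return np.where(errors > threshold)[0], threshold

rng = np.random.default_rng(4)
x = rng.random((200, 784))
recon = x + 0.01 * rng.standard_normal(x.shape)  # simulated good reconstructions
recon[[7, 42]] = x[[7, 42]] + 0.5 * rng.standard_normal((2, 784))  # two poor ones
errors = reconstruction_errors(x, recon)
flagged, threshold = flag_anomalies(errors, pct=99.0)
print(flagged)  # flags items 7 and 42
```

Lowering pct flags more items, trading false negatives for false positives, which is the same tension between overfitting and underfitting described above.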

Artist Abdulrahman Eid created an incredibly detailed model diorama of a street in 1950s Jeddah, Saudi Arabia.

Code below (long).


# PyTorch variational autoencoder for MNIST visualization
# compress each 28x28 MNIST digit to 2 values then plot

# use custom generated text MNIST rather than
# the built-in torchvision MNIST

# PyTorch 1.8.0-CPU Anaconda3-2020.02  Python 3.7.6
# CPU, Windows 10

import numpy as np
import torch as T
import matplotlib.pyplot as plt
import torchvision as tv  # to visualize fakes

device = T.device("cpu")

# -----------------------------------------------------------

class MNIST_Dataset(
  # for an Autoencoder (not a classifier)
  # assumes data has been converted to tab-delim text files:
  # 784 pixel values (0-255) (tab) label (0-9)
  # [0] [1] . . [783] [784] 

  def __init__(self, src_file):
    tmp_x = np.loadtxt(src_file, usecols=range(0,784),
      delimiter="\t", comments="#", dtype=np.float32)
    tmp_y = np.loadtxt(src_file, usecols=[784],
      delimiter="\t", comments="#", dtype=np.int64)
    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device) 
    self.x_data /= 255.0  # normalize pixels
    self.y_data = T.tensor(tmp_y, dtype=T.int64).to(device)
    # don't normalize digit labels

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    pixels = self.x_data[idx]
    label = self.y_data[idx]
    return (pixels, label)

# -----------------------------------------------------------

class VAE(T.nn.Module):  # [784-400-[2,2]-2-400-784]
  def __init__(self):
    super(VAE, self).__init__()  
    self.fc1 = T.nn.Linear(784, 400)  # no labels
    self.fc2a = T.nn.Linear(400, 2)   # u
    self.fc2b = T.nn.Linear(400,2)   # log-var
    self.fc3 = T.nn.Linear(2, 400) 
    self.fc4 = T.nn.Linear(400, 784)

  def encode(self, x):              # 784-400-[2,2]  
    z = T.relu(self.fc1(x)) 
    z1 = self.fc2a(z)               # activation here ??
    z2 = self.fc2b(z) 
    return (z1, z2)                 # (u, log-var)

  def decode(self, x):              # 2-400-784
    z = T.relu(self.fc3(x))      
    z = T.sigmoid(self.fc4(z))      # in [0, 1]
    return z 

  def forward(self, x):
    (u, logvar) = self.encode(x)
    stdev = T.exp(0.5 * logvar)
    noise = T.randn_like(stdev)
    z = u + (noise * stdev)         # [2]
    oupt = self.decode(z)
    return (oupt, u, logvar)

# -----------------------------------------------------------

def cus_loss_func(recon_x, x, u, logvar):
  # KLD = -0.5 * sum(1 + log(sigma^2) - u^2 - sigma^2)
  # bce = T.nn.functional.binary_cross_entropy(recon_x, \
  #   x.view(-1, 784), reduction="sum")

  # mse = T.nn.functional.mse_loss(recon_x, x.view(-1, 784))
  mse = T.nn.functional.mse_loss(recon_x, x)

  kld = -0.5 * T.sum(1 + logvar - u.pow(2) - \
    logvar.exp())

  BETA = 1.0
  return mse + (BETA * kld)

# -----------------------------------------------------------

def train(vae, ds, bs, me, lr, le):
  # train autoencoder vae with dataset ds using batch size bs, 
  # with max epochs me, learn rate lr, log_every le
  data_ldr = T.utils.data.DataLoader(ds, batch_size=bs,
    shuffle=True)
  # loss_func = T.nn.MSELoss() # use custom loss
  opt = T.optim.SGD(vae.parameters(), lr=lr)
  print("Starting training")
  for epoch in range(0, me):
    for (b_idx, batch) in enumerate(data_ldr):
      X = batch[0]  # don't use Y labels to train
      opt.zero_grad()
      recon_x, u, logvar = vae(X)
      loss_val = cus_loss_func(recon_x, X, u, logvar)
      loss_val.backward()  # compute gradients
      opt.step()           # update weights

    if epoch != 0 and epoch % le == 0:
      print("epoch = %6d" % epoch, end="")
      print("  curr batch loss = %7.4f" % \
        loss_val.item())

      # save and view sample images as sanity check
      num_images = 64
      rinpt = T.randn(num_images, 2).to(device)
      with T.no_grad():
        fakes = vae.decode(rinpt)
      fakes = fakes.view(num_images, 1, 28, 28)
      tv.utils.save_image(fakes,  # .\Fakes directory must exist
        ".\\Fakes\\fakes_" + str(epoch) + ".jpg",
        padding=4, pad_value=1.0) # no overwrite

  print("Training complete ")

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin MNIST VAE visualization ")

  # 1. create Dataset object
  print("\nCreating MNIST Dataset ")
  fn = ".\\Data\\mnist_train_10000.txt"
  data_ds = MNIST_Dataset(fn)

  # 2. create and train VAE model 
  print("\nCreating VAE  \n")
  vae = VAE()   # 784-400-[2,2]-2-400-784
  vae.train()           # set mode

  bat_size = 10
  max_epochs = 40
  lrn_rate = 0.01
  log_every = int(max_epochs / 10)
  train(vae, data_ds, bat_size, max_epochs, \
    lrn_rate, log_every)

  # 3. TODO: save trained VAE

  # 4. use model encoder to generate (x,y) pairs
  all_pixels = data_ds[0:10000][0]  # all pixel values
  all_labels = data_ds[0:10000][1]

  with T.no_grad():
    u, logvar = vae.encode(all_pixels) # mean logvar

  print("\nImages reduced to 2 values: ")

  # 5. graph the reduced-form digits in 2D
  print("\nPlotting reduced-dim MNIST images")
  plt.scatter(u[:,0].numpy(), u[:,1].numpy(),
            c=all_labels.numpy(), edgecolor='none', alpha=0.9,
            cmap=plt.get_cmap('nipy_spectral', 11),
            s=20)  # s=20 orig, alpha=0.9
  plt.colorbar()  # color key for digit labels
  plt.show()

  print("\nEnd MNIST VAE visualization")

# -----------------------------------------------------------

if __name__ == "__main__":
  main()

This entry was posted in PyTorch.
