The t-Distributed Stochastic Neighbor Embedding Technique

I don’t normally work with data visualizations, but a few days ago I looked at using principal component analysis (PCA) for dimensionality reduction for visualization. At a recent conference, one of my colleagues mentioned t-distributed stochastic neighbor embedding (TSNE or t-SNE), which is another dimensionality reduction technique for visualization, so I thought I’d explore it a bit.

The idea of dimensionality reduction is perhaps best explained by example. Suppose you have a set of 8×8 images where each image is a crude handwritten digit from ‘0’ to ‘9’. This dataset has 64 dimensions and so you can’t visualize the dataset easily. TSNE can be used to reduce the dimensionality from 64 down to 2 or 3 (or any dimension in principle) so the dataset can be graphed (and then examined by a human to see if there are any interesting patterns).

I coded up a short demo using the built-in load_digits() function from the scikit-learn library. The dataset has 1,797 digit images. First I displayed the digit at index [4], which happens to be a ‘4’. Then I used TSNE to reduce the dimensionality of the entire dataset from 64 down to 2 so it could be graphed.

It’s tempting to think that you could use TSNE for clustering, or perhaps for k-NN classification, but TSNE isn’t really suited for those tasks: it doesn’t preserve global distances or cluster sizes, and a fitted embedding can’t map new, unseen data points. However, TSNE is well-suited for visualization.
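A quick illustration of that last point (a sketch, separate from the demo below): unlike PCA, scikit-learn’s TSNE class has no transform() method, only fit_transform(), so there is no way to project new data points into an existing embedding — one reason t-SNE doesn’t slot naturally into a k-NN classification pipeline.

```python
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# PCA is a learned linear map, so it can transform new data;
# TSNE is non-parametric and cannot.
print(hasattr(PCA(), "transform"))   # True
print(hasattr(TSNE(), "transform"))  # False
```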

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.manifold import TSNE

print("Begin t-distributed stochastic neighbor embedding")
print("Loading 1797 8x8 digit images into memory")
digits = load_digits()

print("Displaying digit [4] which is a four")
pixels = digits.data[4]
pixels = pixels.reshape((8,8))
for i in range(8):
  for j in range(8):
    v = int(pixels[i,j])  # np.int was removed in NumPy 1.24
    print("%.2X " % v, end="")
  print("")

# print(digits.target[4])

print("Displaying digit using pyplot")
img = np.array(digits.data[4])   # 64 pixel values as float64
img = img.reshape((8,8))
plt.imshow(img, cmap=plt.get_cmap('gray_r'))
plt.show()  

print("Using TSNE(2) on entire dataset")
tsne = TSNE(n_components=2)  # recent scikit-learn rejects positional args
projected = tsne.fit_transform(digits.data)

plt.scatter(projected[:, 0], projected[:, 1],
            c=digits.target, edgecolor='none', alpha=0.9,
            cmap=plt.get_cmap('nipy_spectral', 10),
            s=80)
plt.xlabel('component 1')
plt.ylabel('component 2')
plt.colorbar()
plt.show()
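One practical note from my experiments: t-SNE is stochastic, so the layout changes from run to run unless you fix the random_state parameter, and the perplexity parameter (default 30.0) noticeably affects the shape of the result. A minimal sketch, run on a subset of the data to keep it fast:

```python
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE

digits = load_digits()
subset = digits.data[:300]  # a subset keeps the run quick

# fixing random_state makes the embedding reproducible;
# perplexity must be less than the number of samples
tsne = TSNE(n_components=2, perplexity=30.0, random_state=1)
proj = tsne.fit_transform(subset)
print(proj.shape)  # (300, 2)
```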