## Nucleus Sampling for Natural Language Processing

I ran into an interesting idea called nucleus sampling, also called top-p sampling. Nucleus sampling is used for natural language processing (NLP) next-word prediction.

Suppose you have a sentence that starts with “I got up and ran to the . . .” and you have a prediction model that emits possible next words and their associated logits (raw output values from the model). Suppose the model predicts these 7 words, ordered from largest logit (most likely) to smallest logit: door, car, store, race, finish, monkey, apple.

The simplest way to pick the next word is to just select the most likely one, “door”. Always picking the most likely next word doesn’t work very well: when you string several words together this way, the generated text tends to be repetitive and doesn’t read like something a human wrote.

A more sophisticated way is to examine the top-k candidates and then randomly select one of those top-k words. For example, if k = 3 the top three next words are door, car, store and you’d use one of these as the next word, selected either uniformly at random or from a multinomial distribution weighted by their probabilities.
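Here is a minimal sketch of top-k selection in plain Python. The function name `top_k_sample` and its `weighted` flag are my own for illustration, and the logit values are borrowed from the table later in this post.

```python
import math
import random

def top_k_sample(words, logits, k, weighted=False, rng=random):
    """Pick one of the k highest-logit words at random (illustrative helper)."""
    # Sort indices by logit, largest first, and keep the first k.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    top = order[:k]
    if weighted:
        # Multinomial-style draw: weights proportional to exp(logit).
        weights = [math.exp(logits[i]) for i in top]
        return words[rng.choices(top, weights=weights, k=1)[0]]
    # Uniform draw among the k candidates.
    return words[rng.choice(top)]

words = ["door", "car", "store", "race", "finish", "monkey", "apple"]
logits = [-0.54, -0.65, -0.79, -0.98, -1.28, -1.43, -1.63]
print(top_k_sample(words, logits, k=3))  # one of: door, car, store
```

Setting k = 1 reduces this to always picking the most likely word.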

The problem with top-k selection is that it’s difficult to pick k. If k is too large, you might include bad candidates. If k is too small, you might exclude good candidates.

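For nucleus sampling, you select the top items whose cumulative probability stays below some specified probability threshold, p. So you convert logits to probabilities using the softmax() function, then compute the cumulative probabilities, then keep the items where the cumulative probability is less than p, then randomly select one of those candidates. (The original paper defines the nucleus slightly differently, as the smallest set of top items whose cumulative probability reaches at least p, so that version also includes the item that crosses the threshold.)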

For example:

```
next    logit   exp      prob     cum prob
------------------------------------------
door    -0.54   0.5827   0.2201   0.2201
car     -0.65   0.5220   0.1972   0.4173
store   -0.79   0.4538   0.1714   0.5888
race    -0.98   0.3753   0.1418   0.7306
finish  -1.28   0.2780   0.1050   0.8356
monkey  -1.43   0.2393   0.0904   0.9260
apple   -1.63   0.1959   0.0740   1.0000
------------------------------------------
sum             2.6472   1.0000
```

I’ve listed the logits from large to small (logits are often, but not always, negative). The exp column is a scratch column for the softmax calculation: each probability is that row’s exp value divided by the sum of the exp column, 2.6472. Notice that the sum of the prob values is 1.0, as it should be.
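The table can be reproduced with a few lines of plain Python. This is just a sketch; the helper names `softmax_columns` and `cumulative` are mine. Production code usually subtracts the largest logit before calling exp() to avoid overflow, but I skip that here so the exp column matches the table.

```python
import math

def softmax_columns(logits):
    """Return (exps, probs): the scratch exp values and the softmax probabilities."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    return exps, [e / total for e in exps]

def cumulative(probs):
    """Running sum of a list of probabilities."""
    running, out = 0.0, []
    for p in probs:
        running += p
        out.append(running)
    return out

words = ["door", "car", "store", "race", "finish", "monkey", "apple"]
logits = [-0.54, -0.65, -0.79, -0.98, -1.28, -1.43, -1.63]

exps, probs = softmax_columns(logits)
cum = cumulative(probs)
for w, z, e, p, c in zip(words, logits, exps, probs, cum):
    print(f"{w:<8}{z:>6.2f}{e:>9.4f}{p:>9.4f}{c:>9.4f}")
```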

With the cumulative probabilities in hand, you can specify a top-p value, such as p = 0.6, and then select those items where the cumulative probability is less than 0.6. Here that gives “door”, “car”, “store” (cumulative probability 0.5888); adding “race” would push the total to 0.7306. You then randomly pick one of these three candidates.
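The selection step can be sketched like this, using the "cumulative probability less than p" rule from this post. The function names are hypothetical, and the rule that at least one candidate is always kept (even if the top word alone exceeds p) is my own assumption, not something the post specifies.

```python
import math
import random

def nucleus_candidates(words, logits, p):
    """Return (word, prob) pairs whose cumulative probability is below p."""
    # Visit words from largest logit to smallest.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    m = max(logits)  # subtracting the max keeps exp() numerically safe
    exps = [math.exp(logits[i] - m) for i in order]
    total = sum(exps)
    kept, running = [], 0.0
    for i, e in zip(order, exps):
        prob = e / total
        running += prob
        # Stop once the running total reaches p, but (assumption) always
        # keep at least the single most likely word.
        if running >= p and kept:
            break
        kept.append((words[i], prob))
    return kept

def nucleus_sample(words, logits, p, rng=random):
    """Draw one candidate, weighted by its probability within the nucleus."""
    kept = nucleus_candidates(words, logits, p)
    names = [w for w, _ in kept]
    weights = [pr for _, pr in kept]
    return rng.choices(names, weights=weights, k=1)[0]

words = ["door", "car", "store", "race", "finish", "monkey", "apple"]
logits = [-0.54, -0.65, -0.79, -0.98, -1.28, -1.43, -1.63]
print([w for w, _ in nucleus_candidates(words, logits, p=0.6)])
print(nucleus_sample(words, logits, p=0.6))
```

Changing the break condition so that the word crossing the threshold is also kept would give the "at least p" variant, which for p = 0.6 adds “race” to the candidates.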

Nucleus sampling isn’t magic because you still have to specify the threshold p value, but it’s easier to pick a good p value for nucleus sampling than it is to pick a good k value for top-k sampling.

The research paper which describes nucleus sampling is “The Curious Case of Neural Text Degeneration” by A. Holtzman et al. The paper presents some evidence that nucleus sampling generates more human-like text than top-k sampling.

I initially ran into nucleus sampling while working with the Hugging Face (HF) neural code library for NLP. The HF library has a weird function top_k_top_p_filtering() that combines top-k and top-p (nucleus) sampling.
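To get a feel for what combining the two filters means, here is a rough sketch in plain Python. This is not the Hugging Face code: `filter_top_k_top_p` is a hypothetical helper of my own, and it uses the "cumulative probability less than p" rule from this post, which may differ from how a real library treats the item that crosses the threshold. The idea is to keep the k largest logits first, then apply the nucleus rule to the softmax over the survivors, and set every filtered-out logit to negative infinity so it can never be sampled.

```python
import math

NEG_INF = float("-inf")

def filter_top_k_top_p(logits, top_k=0, top_p=1.0):
    """Hypothetical helper: mask logits outside the top-k and top-p sets."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]  # step 1: keep only the k largest logits
    keep = set(order)
    if top_p < 1.0:
        # Step 2: softmax over the survivors, then the nucleus rule.
        m = logits[order[0]]
        exps = [math.exp(logits[i] - m) for i in order]
        total = sum(exps)
        keep, running = set(), 0.0
        for i, e in zip(order, exps):
            running += e / total
            # Assumption: always keep at least the most likely item.
            if running >= top_p and keep:
                break
            keep.add(i)
    return [z if i in keep else NEG_INF for i, z in enumerate(logits)]

logits = [-0.54, -0.65, -0.79, -0.98, -1.28, -1.43, -1.63]
print(filter_top_k_top_p(logits, top_k=5, top_p=0.6))
```

Because the softmax in step 2 is taken over only the top-k survivors, the probabilities are renormalized, so the same p can keep fewer items than it would on the full list.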

[Image: three dresses made from coffee filters.] Filtering logits for natural language: tricky, interesting. Three dresses made from coffee filters: tricky, interesting.
