Exploring LSTM Networks

One of the most interesting relatively recent developments in deep machine learning is the success of LSTM networks. I’ve been on a mission to completely understand LSTMs, and I’ve made good progress. In the end, LSTMs aren’t too difficult to understand, but only if you have a complete understanding of basic neural networks.

LSTMs are an enhanced form of recurrent neural networks (RNNs). An ordinary neural network has no memory: each input is processed independently, with no knowledge of earlier inputs. The standard example is predicting the next word in a sentence, such as, “The baseball team that won the championship in 2016 was (predict).” A regular NN would have great difficulty, but an RNN could be trained to make the prediction because the key word “baseball” is nearby.
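The memory in an RNN comes from feeding the previous hidden state back in alongside the current input. Here’s a minimal sketch of that recurrence; the names (Wx, Wh, b) and the sizes are my illustrative choices, not anything from a particular library.

```python
import numpy as np

rng = np.random.default_rng(0)
n_input, n_hidden = 4, 3
Wx = rng.normal(scale=0.1, size=(n_hidden, n_input))   # input-to-hidden weights
Wh = rng.normal(scale=0.1, size=(n_hidden, n_hidden))  # hidden-to-hidden weights
b = np.zeros(n_hidden)

def rnn_step(x, h_prev):
    # The new hidden state depends on the current input AND the
    # previous hidden state -- this feedback is the network's "memory".
    return np.tanh(Wx @ x + Wh @ h_prev + b)

# Process a short sequence; h carries context forward in time.
h = np.zeros(n_hidden)
for x in rng.normal(size=(5, n_input)):  # 5 time steps of dummy input
    h = rnn_step(x, h)
```

Because each hidden state is a squashed mix of the previous one, information from early time steps fades quickly, which is exactly the weakness described next.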

However, RNNs have very short-term memory. For example, an RNN would have difficulty with, “I’m a big fan of baseball. By the way, the team that won the championship in 2016 was (predict).” The key term “baseball” is too far away.

LSTM networks were designed to have a longer short-term memory; in fact, LSTM stands for Long Short-Term Memory.

There are many variations of LSTM. The basic variation is beautiful (from a software engineer’s point of view, anyway). A deep LSTM network can be constructed from a sequence of identical code modules. At first glance, the architecture of an LSTM module looks complex. To be sure, it isn’t simple, but once you know what each symbol means, implementing an LSTM isn’t too difficult.
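To make the “identical code module” idea concrete, here’s a from-scratch sketch of one basic LSTM cell, assuming the standard four-gate formulation (forget, input, candidate, output); the variable names and sizes are mine, not from any particular reference.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, U, b):
    # W, U, b hold four stacked gate parameter sets in order:
    # forget (f), input (i), candidate (g), output (o).
    n = h_prev.size
    z = W @ x + U @ h_prev + b
    f = sigmoid(z[0*n:1*n])   # forget gate: what to erase from the cell state
    i = sigmoid(z[1*n:2*n])   # input gate: how much new info to write
    g = np.tanh(z[2*n:3*n])   # candidate values to write
    o = sigmoid(z[3*n:4*n])   # output gate: what to expose as output
    c = f * c_prev + i * g    # cell state: the "long short-term memory"
    h = o * np.tanh(c)        # hidden state passed to the next time step
    return h, c

rng = np.random.default_rng(1)
n_input, n_hidden = 4, 3
W = rng.normal(scale=0.1, size=(4 * n_hidden, n_input))
U = rng.normal(scale=0.1, size=(4 * n_hidden, n_hidden))
b = np.zeros(4 * n_hidden)

h = np.zeros(n_hidden)
c = np.zeros(n_hidden)
for x in rng.normal(size=(5, n_input)):  # the identical module is applied at every step
    h, c = lstm_step(x, h, c, W, U, b)
```

The key design point is the cell state c: it is updated additively (a gated blend of old and new values) rather than being fully rewritten through a squashing function at each step, which is what lets information survive across longer gaps.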

There are existing code libraries that implement a basic LSTM, but my goal is to code one up from scratch. That way, I’ll be confident I really understand LSTMs, and I’ll have the ability to experiment with custom LSTM variations.

This entry was posted in Machine Learning.