Building a ChatBot, Pt. 2: Building a Conversational TensorFlow Model
TensorFlow is an open-source machine learning library developed by the Google Brain team and released in November 2015. It is built around dataflow graphs, in which nodes represent mathematical operations (ops) and edges represent tensors (multidimensional arrays, which the Python API returns as NumPy ndarrays when they are evaluated). The dataflow graph describes all of the computations, which a TensorFlow session then executes asynchronously and in parallel on a given device (CPU or GPU). The library boasts true portability between devices, meaning that the same code can run models in CPU-only or heterogeneous GPU-accelerated environments.
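To make the build-then-run idea concrete, here is a toy, pure-Python sketch of a dataflow graph. It is not TensorFlow's API: `Node`, `placeholder`, and `run` are hypothetical names, and evaluation here is sequential rather than asynchronous or parallel. It only shows that building the graph performs no computation, and that evaluating a node pulls values along the edges from the ops it depends on.

```python
class Node:
    """A graph node: an op plus the nodes whose outputs (edges) feed into it."""

    def __init__(self, op, *inputs, name=None):
        self.op = op          # None marks an input node fed from outside
        self.inputs = inputs
        self.name = name


def placeholder(name):
    # An input node whose value is supplied at run time.
    return Node(None, name=name)


def run(node, feed, cache=None):
    """Evaluate `node`, recursively evaluating only the nodes it depends on."""
    cache = {} if cache is None else cache
    if node in cache:
        return cache[node]
    if node.op is None:
        value = feed[node.name]
    else:
        value = node.op(*(run(n, feed, cache) for n in node.inputs))
    cache[node] = value
    return value


# Building the graph does not compute anything...
x = placeholder("x")
w = placeholder("w")
y = Node(lambda a, b: a * b, x, w)   # multiply op
z = Node(lambda a: a + 1.0, y)       # add-one op

# ...values only flow when a node is evaluated with concrete inputs,
# loosely analogous to session.run(z, feed_dict=...) in TensorFlow 1.x.
print(run(z, {"x": 3.0, "w": 2.0}))  # 7.0
```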
TensorFlow relies on highly optimized C++ for computation and supports APIs in C and C++ in addition to the Python API used to build this ChatBot.
The ChatBot is an embedding Seq2Seq model initially based on a French-English translation model built by some of the same Google engineers who wrote TensorFlow. A TensorFlow version of this model is available as the Seq2Seq tutorial in the TensorFlow documentation.
Word embedding in this model is a trainable layer that maps each word in the vocabulary to a dense vector, so that semantically similar words end up close together in the same vector space. This is good for models like my ChatBot because there is often no "right" answer for language and conversation problems. For example, if a sentence begins with "I like to eat", both "chocolate" and "pizza" would be equally correct continuations. Without embedding, the model would learn about "pizza" and "chocolate" as independent concepts, but with embedding, it can leverage what it has learned about "chocolate" when processing sentences about "pizza" or similar words like "burrito" or "spaghetti." If you're interested, the TensorFlow documentation includes a tutorial on word embeddings.
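A minimal sketch of why embeddings help, assuming a hypothetical seven-word vocabulary and hand-picked 3-dimensional vectors (in the real model the vectors are learned during training and the vocabulary has 25,000 words). Looking up an embedding is just selecting a row of a matrix by word id. One-hot vectors for any two different words are orthogonal, so "pizza" and "chocolate" look completely unrelated, while embedding vectors can place them close together.

```python
import math

# Hypothetical toy vocabulary.
vocab = ["i", "like", "to", "eat", "pizza", "chocolate", "burrito"]
word_to_id = {w: i for i, w in enumerate(vocab)}

# Hand-picked vectors for illustration only; in a real model each row is a
# trainable parameter updated along with the rest of the network.
embedding_matrix = [
    [0.9, 0.1, 0.0],  # i
    [0.1, 0.8, 0.1],  # like
    [0.0, 0.2, 0.9],  # to
    [0.2, 0.9, 0.3],  # eat
    [0.7, 0.1, 0.6],  # pizza
    [0.6, 0.2, 0.7],  # chocolate
    [0.7, 0.0, 0.5],  # burrito
]


def one_hot(word):
    vec = [0.0] * len(vocab)
    vec[word_to_id[word]] = 1.0
    return vec


def embed(word):
    # An embedding lookup is row selection by word id.
    return embedding_matrix[word_to_id[word]]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


print(cosine(one_hot("pizza"), one_hot("chocolate")))  # 0.0
print(cosine(embed("pizza"), embed("chocolate")))      # close to 1
print(cosine(embed("pizza"), embed("to")))             # noticeably lower
```

Because similar words share nearby vectors, gradient updates that improve predictions for "chocolate" also tend to move the model's behavior for "pizza" in a useful direction.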
Besides the embedding layer, the model is made of LSTM cells, each with an internal cell state that changes as inputs (in this case, the words of a sentence) are fed sequentially into the model. This cell state allows the model to consider the context in which an input is received, so the output for a given input depends partly on the inputs that came before it. To learn more about LSTMs, check out Christopher Olah's blog post "Understanding LSTM Networks."
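The sketch below is a single LSTM cell written out in plain Python using the standard input, forget, and output gate equations, with small random weights and hypothetical 3-dimensional word vectors. It is not the TensorFlow implementation the ChatBot uses, and it is untrained; it only illustrates how the cell state carries context, so the same final word produces a different output depending on the word before it.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class ToyLSTMCell:
    """One LSTM cell with random (untrained) weights, for illustration only."""

    def __init__(self, input_size, hidden_size, seed=0):
        rng = random.Random(seed)

        def mat(rows, cols):
            return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

        self.hidden_size = hidden_size
        # Per gate: input weights W, recurrent weights U, bias b.
        # Gates: i (input), f (forget), o (output), g (candidate values).
        self.W = {k: mat(hidden_size, input_size) for k in "ifog"}
        self.U = {k: mat(hidden_size, hidden_size) for k in "ifog"}
        self.b = {k: [0.0] * hidden_size for k in "ifog"}

    def step(self, x, h, c):
        def pre(k):
            return [wx + uh + b for wx, uh, b in
                    zip(matvec(self.W[k], x), matvec(self.U[k], h), self.b[k])]

        i = [sigmoid(z) for z in pre("i")]
        f = [sigmoid(z) for z in pre("f")]
        o = [sigmoid(z) for z in pre("o")]
        g = [math.tanh(z) for z in pre("g")]
        # The new cell state keeps part of the old state (via f) and
        # adds part of the new candidate information (via i).
        c_new = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c, i, g)]
        h_new = [oj * math.tanh(cj) for oj, cj in zip(o, c_new)]
        return h_new, c_new


def run(cell, inputs):
    """Feed a sequence of input vectors through the cell; return the last output."""
    h = [0.0] * cell.hidden_size
    c = [0.0] * cell.hidden_size
    for x in inputs:
        h, c = cell.step(x, h, c)
    return h


# Hypothetical word vectors.
I, TO, PIZZA = [0.9, 0.1, 0.0], [0.0, 0.2, 0.9], [0.7, 0.1, 0.6]

cell = ToyLSTMCell(input_size=3, hidden_size=4)
h_a = run(cell, [I, PIZZA])
h_b = run(cell, [TO, PIZZA])
print(h_a != h_b)  # True: same last word, different context, different output
```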
This model has 25,000 input and output nodes (one for each word in the vocabulary) and 3 hidden layers of 768 nodes each. The vocabulary consists of the 25,000 most common words in the Reddit dataset, excluding numbers, and contractions are split into their separate parts (so "I'll" is represented as "I", "'", "ll"). The other aspects of the model's architecture were determined experimentally to minimize training time without sacrificing too much pattern-recognition capability.
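A sketch of the vocabulary preprocessing described above, under stated assumptions: the regular expression, the `_UNK` token for out-of-vocabulary words, and reserving id 0 for it are illustrative choices, not necessarily what the ChatBot's code does. Apostrophes become separate tokens, so contractions are split, and tokens made only of digits are never added to the vocabulary.

```python
import re
from collections import Counter

UNK = "_UNK"  # hypothetical placeholder for out-of-vocabulary tokens


def tokenize(text):
    # Letter runs, digit runs, and single punctuation marks (including the
    # apostrophe) become separate tokens, so "I'll" -> "I", "'", "ll".
    return re.findall(r"[A-Za-z]+|\d+|[^\sA-Za-z\d]", text)


def build_vocab(sentences, max_size=25_000):
    # Count every non-numeric token, then keep the max_size most common.
    counts = Counter(tok for s in sentences for tok in tokenize(s) if not tok.isdigit())
    vocab = {UNK: 0}  # id 0 reserved for unknown tokens (assumption)
    for i, (word, _) in enumerate(counts.most_common(max_size), start=1):
        vocab[word] = i
    return vocab


def encode(sentence, vocab):
    # Map each token to its id, falling back to the unknown-token id.
    return [vocab.get(tok, vocab[UNK]) for tok in tokenize(sentence)]


corpus = ["I'll eat pizza", "I like pizza", "I ate 2 burritos"]
print(tokenize("I'll eat 2 pizzas"))
vocab = build_vocab(corpus, max_size=2)
print(encode("I like pizza", vocab))
```

With the default `max_size`, `build_vocab` keeps the 25,000 most frequent non-numeric tokens; every other token, including numbers, is encoded as the unknown-token id.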
Perplexity and Adjusting the Learning Rate
I trained this model on a dataset of 2,970,943 prompt/response pairs with a batch size of 64, which means that one epoch requires 46,421 steps (2,970,943 / 64, rounded up).
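The epoch arithmetic as a quick check. This assumes the final partial batch counts as one step, which is why the quotient is rounded up; `steps_per_epoch` and `pairs_seen` are hypothetical helper names.

```python
import math

NUM_PAIRS = 2_970_943  # prompt/response pairs in the training set
BATCH_SIZE = 64


def steps_per_epoch(num_examples, batch_size):
    # The last, partial batch still costs one step, so round up.
    return math.ceil(num_examples / batch_size)


def pairs_seen(steps, batch_size=BATCH_SIZE):
    # Number of training pairs consumed after a given number of steps.
    return steps * batch_size


print(steps_per_epoch(NUM_PAIRS, BATCH_SIZE))  # 46421
print(pairs_seen(600))                         # 38400
```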
The quality of the model was evaluated by measuring its perplexity on a separate validation dataset (10% of the size of the training dataset). Perplexity is a standard metric for evaluating language models: it is the exponential of the cross-entropy between the empirical distribution of next words (in the validation dataset) and the distribution your model predicts over all possible next words, given the words that came before in the sentence. Lower is better; a perplexity of k means the model is, on average, as uncertain as if it were choosing uniformly among k words. Word embedding plays a large role in decreasing perplexity quickly, because it lets the model share what it learns across semantically similar words, so plausible continuations receive reasonable probability even when they appear less often in training. For a more in-depth explanation of perplexity, you can watch a short video from Dan Jurafsky.
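Written out, perplexity over N target words is the exponential of the average negative log-probability the model assigns to each actual next word. The sketch below computes it from a list of those probabilities; producing the probabilities (running the model over validation text) is left out, and the example lists are made-up numbers. This matches the common practice of reporting perplexity as the exponential of the average cross-entropy loss.

```python
import math


def perplexity(target_probs):
    """Perplexity from the probability the model assigned to each actual
    next word in held-out text (one probability per target token)."""
    avg_nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(avg_nll)


# A model that gives every correct word probability 0.25 is as uncertain
# as a uniform 4-way guess, so its perplexity is about 4.
print(perplexity([0.25, 0.25, 0.25, 0.25]))
# Putting more probability on the words that actually occur lowers perplexity.
print(perplexity([0.9, 0.8, 0.95, 0.7]))
```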
In early versions of the model, the learning rate started at 0.5 and was decreased by 1% whenever the perplexity did not improve over 600 steps (38,400 training pairs). The problem with this approach was that the learning rate began to drop well before the model had completed an entire epoch. To ensure that the model was trained equally on the entire dataset, I changed the schedule to keep the learning rate at 0.5 for the first 5 epochs and then halve it every epoch, for a total of 7.5 epochs.
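A sketch of the two schedules, using the numbers given above. `PlateauDecay` is a hypothetical reconstruction of the early scheme; whether perplexity was compared against the best value so far or against the previous check is an assumption. `epoch_schedule` assumes the first halving happens at the start of the sixth epoch. The first print shows why the early scheme was a problem: with a check every 600 steps, the rate can be cut up to 77 times before the first epoch ends.

```python
STEPS_PER_EPOCH = 46_421  # 2,970,943 pairs at batch size 64
CHECK_EVERY = 600         # steps between perplexity checks in the early scheme


class PlateauDecay:
    """Early scheme: start at 0.5 and multiply by 0.99 (a 1% cut) whenever a
    check fails to beat the best perplexity so far (comparison rule assumed)."""

    def __init__(self, lr=0.5, factor=0.99):
        self.lr = lr
        self.factor = factor
        self.best = float("inf")

    def update(self, perplexity):
        if perplexity < self.best:
            self.best = perplexity
        else:
            self.lr *= self.factor
        return self.lr


def epoch_schedule(step, base_lr=0.5, constant_epochs=5):
    """Later scheme: base_lr for the first 5 epochs, then halved each epoch.
    Assumption: the first halving happens at the start of epoch 6."""
    epoch = step // STEPS_PER_EPOCH  # 0-based epoch index
    if epoch < constant_epochs:
        return base_lr
    return base_lr * 0.5 ** (epoch - constant_epochs + 1)


# Number of plateau checks that fall inside the first epoch.
print(STEPS_PER_EPOCH // CHECK_EVERY)  # 77

# 7.5 epochs touch epochs 1 through 8 (the last one only halfway).
for epoch in range(8):
    print(epoch + 1, epoch_schedule(epoch * STEPS_PER_EPOCH))
```

Tying the decay to epoch boundaries rather than to short-term perplexity fluctuations means every training pair is seen at the same learning rate within an epoch.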