In a naive encoder-decoder model, one RNN (the encoder) reads the input sentence and a second RNN (the decoder) outputs a sentence, as in machine translation.
But what can be done to improve this model’s performance? Here, we’ll explore a modification to this encoder-decoder mechanism, commonly known as an attention model.
The role of an attention model in long sequences
In machine translation, we’re feeding our input into the encoder (green part) of the network, with the output coming from the decoder (purple part) of the network, as depicted above.
The network's approach is to memorize the whole input and store it in its activation units as a single context vector. This context vector is expected to be a good summary of the input sentence, and the sentence is then translated all at once from it. This differs from how a human translates, which is by taking a few input words at a time.
BLEU (Bilingual Evaluation Understudy) is a score for comparing a candidate translation of a text to one or more reference translations. The graph above shows that the encoder-decoder unit fails to memorize a whole long sentence: it works well for shorter sentences (high BLEU score), but not for longer ones.
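To get a feel for what a BLEU-style score measures, here is a minimal sketch in plain Python. It is a deliberate simplification: it only uses clipped unigram precision multiplied by the brevity penalty, whereas full BLEU combines several n-gram orders. The function name `unigram_bleu` and the example sentences are made up for illustration, and the candidate is assumed to be non-empty.

```python
import math
from collections import Counter


def unigram_bleu(candidate, references):
    """Clipped unigram precision times the brevity penalty (unigrams only)."""
    cand = candidate.split()
    refs = [r.split() for r in references]

    # Clip each candidate word count by its highest count in any one reference.
    max_ref_counts = Counter()
    for ref in refs:
        for word, count in Counter(ref).items():
            max_ref_counts[word] = max(max_ref_counts[word], count)
    clipped = sum(min(count, max_ref_counts[word])
                  for word, count in Counter(cand).items())
    precision = clipped / len(cand)

    # Brevity penalty: compare with the reference length closest to the candidate.
    c = len(cand)
    r = min((len(ref) for ref in refs), key=lambda n: (abs(n - c), n))
    brevity_penalty = 1.0 if c > r else math.exp(1 - r / c)
    return brevity_penalty * precision


reference = ["the cat is on the mat"]
good = unigram_bleu("the cat sat on the mat", reference)
bad = unigram_bleu("the the the the the the the", reference)
print("reasonable candidate:", round(good, 3))
print("degenerate candidate:", round(bad, 3))
```

The clipping step is what stops the degenerate candidate from scoring well just by repeating a common word.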
If the encoder makes a bad summary, the translation will also be bad. And indeed, it has been observed that the encoder creates a bad summary when it tries to understand longer sentences. This is called the long-range dependency problem of RNNs/LSTMs.
This is where the attention mechanism comes in for such long sequences.
Understanding the attention mechanism
In this case, we’ve used a bidirectional RNN to compute a set of features for each of the input words. Each of the vectors h1, h2, …, hT is the concatenation of the forward and backward hidden states of the encoder at that position.
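The concatenation of forward and backward states can be sketched with a toy example. This is not the encoder used in this article: it assumes a scalar RNN of the form h_t = tanh(w_x * x_t + w_h * h_{t-1}), and the weights, inputs, and function names are all hypothetical.

```python
import math


def run_rnn(inputs, w_x, w_h):
    """Toy scalar RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1}), starting from 0."""
    h, states = 0.0, []
    for x in inputs:
        h = math.tanh(w_x * x + w_h * h)
        states.append(h)
    return states


def bidirectional_annotations(inputs):
    """Pair the forward state and the backward state at every position."""
    forward = run_rnn(inputs, w_x=0.5, w_h=0.8)
    # Read the sequence right to left, then flip so index j lines up again.
    backward = run_rnn(inputs[::-1], w_x=-0.3, w_h=0.6)[::-1]
    return [[f, b] for f, b in zip(forward, backward)]


x = [1.0, -2.0, 0.5, 3.0]  # hypothetical scalar "word" features
h = bidirectional_annotations(x)
for j, h_j in enumerate(h, start=1):
    print(f"h{j} =", [round(v, 3) for v in h_j])
```

Each annotation therefore carries information from the words before it (forward half) and the words after it (backward half).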
Now, which parts of the input sentence x1, x2, x3, …, xT are used to generate the output yi? Not all the inputs are used in generating a given output, so the attention model computes a set of attention weights α(i,1), …, α(i,T), one per input position. The context vector ci for the output word yi is the weighted sum of the annotations hj:
$$c_i = \sum_{j=1}^{T} \alpha(i,j)\, h_j$$
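The attention weights are calculated by normalizing, with a softmax, the output scores of a feed-forward neural network described by the function a, which captures the alignment between the input at position j and the output at position i. Writing e(i,j) for the score that a assigns to the pair (i, j), the weights are:
$$\alpha(i,j) = \frac{\exp(e(i,j))}{\sum_{k=1}^{T} \exp(e(i,k))}$$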
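The two steps above (normalize the alignment scores, then take a weighted sum of the annotations) can be sketched in a few lines of plain Python. The scores and annotation vectors are made-up numbers standing in for the outputs of the alignment network and the encoder, and the helper names are hypothetical.

```python
import math


def softmax(scores):
    """Normalize raw alignment scores into weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def context_vector(weights, annotations):
    """Weighted sum of the annotation vectors h_j."""
    dim = len(annotations[0])
    return [sum(w * h[d] for w, h in zip(weights, annotations))
            for d in range(dim)]


# Hypothetical alignment scores for one output step i over T = 3 inputs.
scores = [2.0, 0.5, -1.0]
# Hypothetical annotations h_1..h_3 (2-dimensional to keep it readable).
annotations = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

alphas = softmax(scores)
c_i = context_vector(alphas, annotations)
print("attention weights:", [round(a, 3) for a in alphas])
print("context vector:", [round(c, 3) for c in c_i])
```

The input with the highest score dominates the context vector, but every input still contributes a little.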
Types of attention
Depending on how many source states contribute to the attention vector (α), there are three types of attention mechanisms:
- Global (Soft) Attention: Attention is placed on all source states. Global attention requires as many weights as there are words in the source sentence.
- Local Attention: Attention is placed on only a few source states.
- Hard Attention: Attention is placed on only one source state.
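The difference between the three types is easiest to see by looking at the weights each one produces for the same scores. The sketch below makes some simplifying assumptions: local attention is modeled as a softmax over a fixed window around a given center, and hard attention picks the highest-scoring state, whereas in practice the selected state is often sampled. All names and numbers are hypothetical.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def global_attention(scores):
    """Soft attention: every source state gets a nonzero weight."""
    return softmax(scores)


def local_attention(scores, center, half_width):
    """Softmax over a window around `center`; zero weight elsewhere."""
    lo = max(0, center - half_width)
    hi = min(len(scores), center + half_width + 1)
    weights = [0.0] * len(scores)
    weights[lo:hi] = softmax(scores[lo:hi])
    return weights


def hard_attention(scores):
    """All weight on a single source state (here, the best-scoring one)."""
    best = max(range(len(scores)), key=lambda j: scores[j])
    return [1.0 if j == best else 0.0 for j in range(len(scores))]


scores = [0.1, 2.0, 1.5, -0.5, 0.3]  # hypothetical alignment scores
w_global = global_attention(scores)
w_local = local_attention(scores, center=1, half_width=1)
w_hard = hard_attention(scores)
print("global:", [round(w, 3) for w in w_global])
print("local: ", [round(w, 3) for w in w_local])
print("hard:  ", w_hard)
```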
Implementation of an attention model on the IMDB dataset using Keras
You can see for yourself that using attention yields a higher accuracy on the IMDB dataset. Here, we consider two LSTM networks: one with the attention layer and the other one with a fully connected layer. Each network has the same number of parameters (250K in my example).
The attention model needs an attention vector to be calculated, which can be done using the code snippet below:

    from keras.layers import Activation, Dense, Lambda, concatenate, dot


    def attention_3d_block(hidden_states):
        """
        @param hidden_states: 3D tensor with shape (batch_size, time_steps, input_dim).
        @return: 2D tensor with shape (batch_size, 128)
        """
        # Size of the feature axis (input_dim) of the incoming hidden states.
        hidden_size = int(hidden_states.shape[2])
        # Project every hidden state with a trainable matrix (no bias).
        score_first_part = Dense(hidden_size, use_bias=False,
                                 name='attention_score_vec')(hidden_states)
        # Use the last hidden state as the query: (batch_size, hidden_size).
        h_t = Lambda(lambda x: x[:, -1, :], output_shape=(hidden_size,),
                     name='last_hidden_state')(hidden_states)
        # One score per time step: (batch_size, time_steps).
        score = dot([score_first_part, h_t], [2, 1], name='attention_score')
        attention_weights = Activation('softmax', name='attention_weight')(score)
        # Weighted sum of the hidden states: (batch_size, hidden_size).
        context_vector = dot([hidden_states, attention_weights], [1, 1],
                             name='context_vector')
        pre_activation = concatenate([context_vector, h_t], name='attention_output')
        attention_vector = Dense(128, use_bias=False, activation='tanh',
                                 name='attention_vector')(pre_activation)
        return attention_vector
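To see what `attention_3d_block` computes, it can help to trace the same steps by hand for a single sequence, without Keras. The following is only a rough plain-Python mirror under stated assumptions: there is no batch dimension, the weight matrices `w_score` and `w_out` are fixed made-up numbers rather than trained parameters, they are applied as matrix-times-vector (which differs from a Keras kernel only by a transpose), and the output has 3 units instead of 128.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(matrix, vec):
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def attention_single_sequence(hidden_states, w_score, w_out):
    """Many-to-one attention for one sequence of hidden states."""
    h_t = hidden_states[-1]                                  # last hidden state
    projected = [matvec(w_score, h) for h in hidden_states]  # dense, no bias
    scores = [sum(p * q for p, q in zip(proj, h_t)) for proj in projected]
    weights = softmax(scores)                                # one per time step
    dim = len(h_t)
    context = [sum(w * h[d] for w, h in zip(weights, hidden_states))
               for d in range(dim)]
    pre_activation = context + h_t                           # concatenation
    output = [math.tanh(v) for v in matvec(w_out, pre_activation)]
    return output, weights


# Hypothetical values: 3 time steps, hidden size 2, output size 3.
hidden_states = [[0.1, 0.2], [0.4, -0.3], [0.5, 0.5]]
w_score = [[1.0, 0.0], [0.0, 1.0]]
w_out = [[0.2, -0.1, 0.4, 0.3],
         [-0.5, 0.3, 0.1, 0.2],
         [0.1, 0.1, -0.2, 0.6]]

attention_vector, weights = attention_single_sequence(hidden_states, w_score, w_out)
print("attention weights:", [round(w, 3) for w in weights])
print("attention vector:", [round(v, 3) for v in attention_vector])
```

Note that the score here is a product between a projected hidden state and the last hidden state, rather than the output of a separate feed-forward alignment network.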
Here’s a comparative study of the accuracy obtained with and without an attention mechanism in just 10 epochs.
This shows the positive effect of an attention model in the overall LSTM network. To view the complete code of this tutorial, check out my GitHub repository.