Source: floydhub
For our first step, we’ll be using an RNN or any of its variants (e.g. LSTM, GRU) to encode the input sequence. After the input sequence passes through the encoder RNN, a hidden state (output) is produced for every element of the input. Instead of using only the hidden state at the final time step, we’ll carry forward all the hidden states produced by the encoder to the next step.
```python
import torch
import torch.nn as nn

# Device used when creating the initial hidden state
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EncoderLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1, drop_prob=0):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers, dropout=drop_prob, batch_first=True)

    def forward(self, inputs, hidden):
        # Embed input words: (batch, seq_len) -> (batch, seq_len, hidden_size)
        embedded = self.embedding(inputs)
        # Pass the embedded word vectors into LSTM and return all outputs
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def init_hidden(self, batch_size=1):
        # Zero-initialised (h_0, c_0), each of shape (n_layers, batch, hidden_size)
        return (torch.zeros(self.n_layers, batch_size, self.hidden_size, device=device),
                torch.zeros(self.n_layers, batch_size, self.hidden_size, device=device))
```
In the encoder implementation above, we first embed the input words into word vectors (assuming a language task) and then pass them through an LSTM. This encoder is identical to the one in a standard encoder-decoder model without Attention.
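To make the "one output per input" point concrete, here is a minimal sketch that calls nn.Embedding and nn.LSTM directly rather than the EncoderLSTM class. The vocabulary size, hidden size and token ids are arbitrary values chosen for illustration. It shows that the encoder returns one output vector per input token, and that for a single-layer, unidirectional LSTM the last of those outputs equals the final hidden state, which is all a decoder without Attention would receive.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not values from the article)
vocab_size, hidden_size = 10, 3

embedding = nn.Embedding(vocab_size, hidden_size)
lstm = nn.LSTM(hidden_size, hidden_size, num_layers=1, batch_first=True)

# One sentence of 2 token ids: shape (batch=1, seq_len=2)
tokens = torch.tensor([[4, 7]])

# Zero initial state (h_0, c_0), each of shape (num_layers, batch, hidden_size)
h0 = torch.zeros(1, 1, hidden_size)
c0 = torch.zeros(1, 1, hidden_size)

encoder_outputs, (h_n, c_n) = lstm(embedding(tokens), (h0, c0))

# One hidden state per input token: (batch, seq_len, hidden_size)
print(encoder_outputs.shape)  # torch.Size([1, 2, 3])

# The last output is the same vector as the final hidden state h_n
print(torch.allclose(encoder_outputs[:, -1], h_n[-1]))  # True
```

With Attention, the decoder keeps the whole encoder_outputs tensor instead of only h_n.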
For the next 3 steps, we will go through the processes that happen in the Attention Decoder and discuss how the Attention mechanism is utilised. The class BahdanauDecoder defined below encompasses these 3 steps in its forward function.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauDecoder(nn.Module):
    # As written, forward() assumes batch size 1 and a single-layer LSTM
    def __init__(self, hidden_size, output_size, n_layers=1, drop_prob=0.1):
        super(BahdanauDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.drop_prob = drop_prob

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)

        # Separate trainable projections for the decoder hidden state and the encoder outputs
        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        # W_combined: maps each tanh output to a scalar score. Initialised explicitly,
        # since torch.FloatTensor(1, hidden_size) would contain uninitialised memory
        self.weight = nn.Parameter(torch.randn(1, hidden_size) * 0.1)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)  # not used in forward()
        self.dropout = nn.Dropout(self.drop_prob)
        self.lstm = nn.LSTM(self.hidden_size * 2, self.hidden_size, batch_first=True)
        self.classifier = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, inputs, hidden, encoder_outputs):
        # (1, seq_len, hidden_size) -> (seq_len, hidden_size); squeeze(0) stays safe when seq_len == 1
        encoder_outputs = encoder_outputs.squeeze(0)
        # Embed input words: (1, hidden_size)
        embedded = self.embedding(inputs).view(1, -1)
        embedded = self.dropout(embedded)

        # Calculating Alignment Scores: (1, 1, H) + (seq_len, H) broadcasts to (1, seq_len, H)
        x = torch.tanh(self.fc_hidden(hidden[0]) + self.fc_encoder(encoder_outputs))
        alignment_scores = x.bmm(self.weight.unsqueeze(2))  # (1, seq_len, 1)

        # Softmaxing alignment scores to get Attention weights: (1, seq_len)
        attn_weights = F.softmax(alignment_scores.view(1, -1), dim=1)

        # Multiplying the Attention weights with encoder outputs to get the context vector: (1, 1, H)
        context_vector = torch.bmm(attn_weights.unsqueeze(0),
                                   encoder_outputs.unsqueeze(0))

        # Concatenating context vector with embedded input word: (1, 1, 2H)
        output = torch.cat((embedded, context_vector[0]), 1).unsqueeze(0)
        # Passing the concatenated vector as input to the LSTM cell
        output, hidden = self.lstm(output, hidden)
        # Passing the LSTM output through a Linear layer acting as a classifier
        output = F.log_softmax(self.classifier(output[0]), dim=1)
        return output, hidden, attn_weights
```
After obtaining all of our encoder outputs, we can start using the decoder to produce outputs. At each time step of the decoder, we have to calculate the alignment score of each encoder output with respect to the decoder hidden state at that time step. The alignment score is the essence of the Attention mechanism, as it quantifies the amount of “Attention” the decoder will place on each of the encoder outputs when producing the next output.
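The role of the alignment scores can be sketched in isolation: a softmax turns them into Attention weights that are positive and sum to 1, and the context vector is the weighted sum of the encoder outputs. The score and encoder-output values below are made up for illustration; in BahdanauDecoder they come from the alignment function and the encoder.

```python
import torch
import torch.nn.functional as F

# Made-up example: 2 encoder outputs with hidden size 3
encoder_outputs = torch.tensor([[1.0, 0.0, 2.0],
                                [0.0, 1.0, 0.0]])   # (seq_len, hidden_size)
alignment_scores = torch.tensor([2.0, 0.0])          # one score per encoder output

# Softmax: a higher score gives a larger share of the decoder's "Attention"
attn_weights = F.softmax(alignment_scores, dim=0)    # (seq_len,)

# Context vector: Attention-weighted sum of the encoder outputs
context_vector = attn_weights @ encoder_outputs      # (hidden_size,)

print(attn_weights)    # approximately [0.8808, 0.1192]
print(context_vector)
```

Because the first encoder output has the higher score, it dominates the context vector.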
The alignment scores for Bahdanau Attention are calculated using the hidden state produced by the decoder in the previous time step and the encoder outputs with the following equation:
$$ \text{score}_{\text{alignment}} = W_{\text{combined}} \cdot \tanh\left(W_{\text{decoder}} \cdot H_{\text{decoder}} + W_{\text{encoder}} \cdot H_{\text{encoder}}\right) $$
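Read as tensor operations, the equation can be evaluated for all encoder outputs at once. The sketch below mirrors the fc_hidden, fc_encoder and weight members of BahdanauDecoder; the variable names W_decoder, W_encoder and W_combined are illustrative, W_combined is represented as a (1, hidden_size) parameter, and the random inputs are placeholders. The sizes (hidden size 3, two encoder outputs) match the illustration that follows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, seq_len = 3, 2

W_decoder = nn.Linear(hidden_size, hidden_size, bias=False)   # W_decoder
W_encoder = nn.Linear(hidden_size, hidden_size, bias=False)   # W_encoder
W_combined = nn.Parameter(torch.randn(1, hidden_size))        # W_combined, a single row

H_decoder = torch.randn(1, hidden_size)          # previous decoder hidden state
H_encoder = torch.randn(seq_len, hidden_size)    # all encoder outputs

# tanh(W_decoder·H_decoder + W_encoder·H_encoder); the decoder term broadcasts over seq_len
x = torch.tanh(W_decoder(H_decoder) + W_encoder(H_encoder))   # (seq_len, hidden_size)

# W_combined · x for every encoder output gives one scalar score each
scores = x @ W_combined.t()                                    # (seq_len, 1)
print(scores.squeeze(1))  # one alignment score per encoder output
```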
The decoder hidden state and the encoder outputs are passed through separate Linear layers, each with its own trainable weights.
[Figure: Linear layers for encoder outputs and decoder hidden states]
In the illustration above, the hidden size is 3 and the number of encoder outputs is 2.
The two projected results are then added together and passed through a tanh activation function. Since there is a single decoder hidden state but several encoder outputs, the decoder hidden state is added to each encoder output.
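Adding the decoder hidden state to each encoder output is ordinary broadcasting: the single projected decoder state of shape (1, hidden_size) is added row-wise to the (seq_len, hidden_size) projected encoder outputs. The sketch below uses the illustration's sizes with arbitrary numbers standing in for the Linear-layer outputs, and checks that the broadcast sum matches adding the decoder vector to each encoder output one at a time.

```python
import torch

# Arbitrary stand-ins for the Linear-layer outputs (hidden size 3, 2 encoder outputs)
dec = torch.tensor([[0.5, -1.0, 0.0]])      # W_decoder·H_decoder, shape (1, 3)
enc = torch.tensor([[1.0, 2.0, 3.0],
                    [-1.0, 0.0, 1.0]])      # W_encoder·H_encoder, shape (2, 3)

# Broadcasting repeats the decoder row for every encoder output
broadcast = torch.tanh(dec + enc)                               # (2, 3)

# Explicit version: add the same decoder vector to each encoder output in turn
looped = torch.stack([torch.tanh(dec[0] + e) for e in enc])     # (2, 3)

print(torch.allclose(broadcast, looped))  # True
```

The broadcast form is what BahdanauDecoder relies on, since it avoids a Python loop over encoder outputs.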