matchzoo.modules.stacked_brnn
Module Contents
Classes
StackedBRNN: Stacked Bi-directional RNNs.
class matchzoo.modules.stacked_brnn.StackedBRNN(input_size, hidden_size, num_layers, dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM, concat_layers=False)
Bases: torch.nn.Module
Stacked Bi-directional RNNs.
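Differs from the standard PyTorch RNN modules in that it can save the hidden states of every layer and concatenate them. Each layer is bidirectional and emits 2 * hidden_size features per time step, so with concat_layers=True the output hidden size for each sequence input is num_layers * 2 * hidden_size; with concat_layers=False only the last layer's 2 * hidden_size features are returned.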
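To make the stacking and concatenation concrete, here is a toy, dependency-free sketch in plain Python. It is not MatchZoo's implementation: the scalar tanh cell, its fixed weights, and the helper names `rnn_direction`, `birnn_layer`, and `stacked_birnn` are hypothetical, and it encodes one unbatched sequence (a list of feature vectors) rather than a batch-first tensor. It only mirrors the shape logic: every layer after the first reads the previous layer's 2 * hidden_size outputs, and `concat_layers` chooses between the last layer alone and all layers joined per time step.

```python
import math
from typing import List

Vector = List[float]


def rnn_direction(seq: List[Vector], hidden_size: int, weight: float) -> List[Vector]:
    """Hypothetical toy RNN run left-to-right; it sums its inputs, so any width works."""
    h = [0.0] * hidden_size
    outputs = []
    for x in seq:
        drive = weight * (sum(x) + sum(h))
        # Offset each unit slightly so the hidden units are not all identical.
        h = [math.tanh(drive + 0.01 * k) for k in range(hidden_size)]
        outputs.append(h)
    return outputs


def birnn_layer(seq: List[Vector], hidden_size: int) -> List[Vector]:
    """One bidirectional layer: 2 * hidden_size features per time step."""
    forward = rnn_direction(seq, hidden_size, weight=0.1)
    # Run over the reversed sequence, then flip back so index t lines up.
    backward = rnn_direction(seq[::-1], hidden_size, weight=0.2)[::-1]
    return [f + b for f, b in zip(forward, backward)]


def stacked_birnn(seq: List[Vector], hidden_size: int, num_layers: int,
                  concat_layers: bool = False) -> List[Vector]:
    """Hypothetical stack of toy bidirectional layers."""
    outputs = [seq]
    for _ in range(num_layers):
        # Layers after the first consume the previous 2 * hidden_size outputs.
        outputs.append(birnn_layer(outputs[-1], hidden_size))
    if concat_layers:
        # Join every layer's output at each time step (outputs[0] is the raw input).
        return [sum((layer[t] for layer in outputs[1:]), []) for t in range(len(seq))]
    return outputs[-1]


seq = [[0.5] * 10 for _ in range(5)]  # 5 time steps, input_size = 10
last = stacked_birnn(seq, hidden_size=10, num_layers=2)
both = stacked_birnn(seq, hidden_size=10, num_layers=2, concat_layers=True)
print(len(last), len(last[0]))  # 5 20
print(len(both), len(both[0]))  # 5 40
```

With hidden_size=10 and num_layers=2 the toy yields 20 features per step, or 40 with concat_layers=True, matching the 2 * hidden_size and num_layers * 2 * hidden_size sizes.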
Examples
>>> import torch
>>> from matchzoo.modules.stacked_brnn import StackedBRNN
>>> rnn = StackedBRNN(
...     input_size=10,
...     hidden_size=10,
...     num_layers=2,
...     dropout_rate=0.2,
...     dropout_output=True,
...     concat_layers=False
... )
>>> x = torch.randn(2, 5, 10)  # (batch, seq_len, input_size)
>>> x.size()
torch.Size([2, 5, 10])
>>> x_mask = (torch.ones(2, 5) == 1)
>>> rnn(x, x_mask).shape  # last dim is 2 * hidden_size
torch.Size([2, 5, 20])
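forward(self, x, x_mask)
Encode either padded or non-padded sequences. x is a batch-first tensor of shape (batch, seq_len, input_size) and x_mask has shape (batch, seq_len), as in the example above. The result has shape (batch, seq_len, 2 * hidden_size), or (batch, seq_len, num_layers * 2 * hidden_size) when concat_layers=True.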
_forward_unpadded(self, x, x_mask)
Faster encoding that ignores any padding.
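As a rough illustration of what ignoring padding can mean, the toy sketch below (plain Python, a hypothetical scalar tanh RNN with made-up weights, not MatchZoo code) encodes a three-token sequence with and without two trailing pad positions. Because the left-to-right pass reaches the pads only after the real tokens, its states for those tokens are unchanged; the right-to-left pass reads the pads first, so its states for the real tokens differ from the unpadded encoding whenever the cell has a nonzero bias, as PyTorch's nn.LSTM does by default (bias=True).

```python
import math
from typing import List


def rnn_states(xs: List[float], reverse: bool = False,
               w: float = 0.8, u: float = 0.5, b: float = 0.3) -> List[float]:
    """Hypothetical scalar tanh RNN: h_t = tanh(w * x_t + u * h_prev + b).

    With reverse=True it runs right-to-left, like the backward half of a BiRNN.
    states[t] always holds the state produced at position t.
    """
    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    h = 0.0
    states = [0.0] * len(xs)
    for t in order:
        h = math.tanh(w * xs[t] + u * h + b)
        states[t] = h
    return states


tokens = [1.0, 2.0, 3.0]
padded = tokens + [0.0, 0.0]  # two trailing pad positions
n = len(tokens)

# Forward direction: pads come after the real tokens, so nothing changes.
print(rnn_states(padded)[:n] == rnn_states(tokens))  # True
# Backward direction: pads are read first and leak into the real-token states.
print(rnn_states(padded, reverse=True)[:n] == rnn_states(tokens, reverse=True))  # False
```

Whether this matters in practice depends on how the encoded positions are masked or pooled downstream; the sketch only shows why a padding-aware path and a padding-ignoring path can produce different backward states.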