matchzoo.modules.stacked_brnn

Module Contents

class matchzoo.modules.stacked_brnn.StackedBRNN(input_size, hidden_size, num_layers, dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM, concat_layers=False)

Bases: torch.nn.Module

Stacked Bi-directional RNNs.

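Differs from the standard PyTorch RNN modules in that it can optionally keep the hidden states of every layer and concatenate them. Each bidirectional layer emits 2 * hidden_size features per time step, so with concat_layers=True the output feature size for each sequence position is num_layers * 2 * hidden_size; with concat_layers=False only the top layer is returned and the size is 2 * hidden_size.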
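
The layer-concatenation option can be illustrated with a minimal sketch built only from standard PyTorch calls. The class name StackedBiLSTMSketch is hypothetical and this is not the MatchZoo implementation: it omits dropout, the x_mask argument, and the rnn_type option, and assumes batch-first LSTM layers purely to show how the output feature size arises.

```python
import torch
import torch.nn as nn


class StackedBiLSTMSketch(nn.Module):
    """Hypothetical illustration; not the MatchZoo implementation."""

    def __init__(self, input_size, hidden_size, num_layers, concat_layers=False):
        super().__init__()
        self.concat_layers = concat_layers
        self.rnns = nn.ModuleList()
        for i in range(num_layers):
            # Layers after the first consume the previous layer's
            # bidirectional output, which has 2 * hidden_size features.
            in_size = input_size if i == 0 else 2 * hidden_size
            self.rnns.append(
                nn.LSTM(in_size, hidden_size, num_layers=1,
                        bidirectional=True, batch_first=True)
            )

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        outputs = []
        layer_input = x
        for rnn in self.rnns:
            layer_output, _ = rnn(layer_input)
            outputs.append(layer_output)
            layer_input = layer_output
        if self.concat_layers:
            # Keep every layer's states: num_layers * 2 * hidden_size features.
            return torch.cat(outputs, dim=2)
        # Keep only the top layer: 2 * hidden_size features.
        return outputs[-1]


x = torch.randn(2, 5, 10)
top_only = StackedBiLSTMSketch(10, 10, num_layers=2, concat_layers=False)
all_layers = StackedBiLSTMSketch(10, 10, num_layers=2, concat_layers=True)
print(top_only(x).shape)    # torch.Size([2, 5, 20])
print(all_layers(x).shape)  # torch.Size([2, 5, 40])
```

Building the stack from single-layer modules, rather than one nn.LSTM with num_layers=2, is what makes the intermediate layer outputs available for concatenation.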

Examples

>>> import torch
>>> rnn = StackedBRNN(
...     input_size=10,
...     hidden_size=10,
...     num_layers=2,
...     dropout_rate=0.2,
...     dropout_output=True,
...     concat_layers=False
... )
>>> x = torch.randn(2, 5, 10)
>>> x.size()
torch.Size([2, 5, 10])
>>> x_mask = (torch.ones(2, 5) == 1)
>>> rnn(x, x_mask).shape
torch.Size([2, 5, 20])

forward(self, x, x_mask)

Encode either padded or non-padded sequences.

_forward_unpadded(self, x, x_mask)

Faster encoding that ignores any padding.