matchzoo.modules.spatial_gru

Spatial GRU module.

Module Contents

class matchzoo.modules.spatial_gru.SpatialGRU(channels: int = 4, units: int = 10, activation: typing.Union[str, typing.Type[nn.Module], nn.Module] = 'tanh', recurrent_activation: typing.Union[str, typing.Type[nn.Module], nn.Module] = 'sigmoid', direction: str = 'lt')

Bases: torch.nn.Module

Spatial GRU Module.

Parameters:
  • channels – Number of word interaction tensor channels.
  • units – Number of SpatialGRU units.
  • activation

    Activation function to use, one of:

    • String: name of an activation
    • Torch Module subclass
    • Torch Module instance

    Default: hyperbolic tangent (tanh).

  • recurrent_activation

    Activation function to use for the recurrent step, one of:

    • String: name of an activation
    • Torch Module subclass
    • Torch Module instance

    Default: sigmoid activation (sigmoid).

  • direction – Scanning direction. lt (left top) scans from the top-left corner to the bottom-right corner; rb (right bottom) scans from the bottom-right corner to the top-left corner.
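
The two scanning directions can be illustrated with a small, library-independent sketch. The helper `scan_order` below is hypothetical and not part of MatchZoo. It assumes that the scan proceeds row by row, which this page does not state explicitly, and that rb visits the same cells in reverse.

```python
from typing import List, Tuple


def scan_order(rows: int, cols: int, direction: str = 'lt') -> List[Tuple[int, int]]:
    """Return the order in which cells of a rows x cols matrix are visited.

    Hypothetical helper for illustration only; assumes a row-by-row scan.
    """
    if direction not in ('lt', 'rb'):
        raise ValueError(f"Invalid direction: {direction!r}, expected 'lt' or 'rb'.")
    # 'lt': start at the top-left cell and finish at the bottom-right cell.
    order = [(i, j) for i in range(rows) for j in range(cols)]
    # 'rb': the same path walked backwards, from bottom-right to top-left.
    if direction == 'rb':
        order.reverse()
    return order


print(scan_order(2, 3, 'lt'))
print(scan_order(2, 3, 'rb'))
```

Under these assumptions, the last cell visited is the bottom-right one for lt and the top-left one for rb.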

Examples

>>> import matchzoo as mz
>>> channels, units = 4, 10
>>> spatial_gru = mz.modules.SpatialGRU(channels, units)

reset_parameters(self)

Initialize parameters.

softmax_by_row(self, z: torch.tensor)

Apply softmax to each dimension across the four gates.
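
As a rough illustration of a softmax taken across four gates, the sketch below normalizes four gate vectors so that, for every position, the four gate weights sum to 1. It is a plain-Python stand-in, not the library implementation: the function name `softmax_across_gates` is hypothetical, and reading "each dimension" as "each hidden unit" is an assumption, since this page does not name the gates or the tensor layout.

```python
import math
from typing import List


def softmax_across_gates(gates: List[List[float]]) -> List[List[float]]:
    """Normalize four gate vectors so that, per unit, the gate weights sum to 1.

    `gates` holds four lists of equal length (one value per unit).
    Hypothetical stand-in for illustration; not the library implementation.
    """
    if len(gates) != 4:
        raise ValueError("Expected exactly four gate vectors.")
    units = len(gates[0])
    result = [[0.0] * units for _ in range(4)]
    for u in range(units):
        column = [gate[u] for gate in gates]
        # Subtract the maximum for numerical stability before exponentiating.
        peak = max(column)
        exps = [math.exp(v - peak) for v in column]
        total = sum(exps)
        for g in range(4):
            result[g][u] = exps[g] / total
    return result


weights = softmax_across_gates([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 3.0]])
print(weights[0][0])  # 0.25, since all four gates tie for the first unit
```

For the second unit, the fourth gate receives the largest weight because its raw value (3.0) dominates the other three.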

calculate_recurrent_unit(self, inputs: torch.tensor, states: list, i: int, j: int)

Calculate recurrent unit.

Parameters:
  • inputs – A tensor containing the interactions between the left text and the right text.
  • states – A list of tensors that stores the hidden state of every step.
  • i – Recurrent row index.
  • j – Recurrent column index.
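
The role of the row and column indices can be sketched as follows. In a spatial (2D) recurrence, the state at cell (i, j) is typically computed from the states of its top, left, and top-left neighbours. The helper `neighbor_states` is hypothetical, uses plain floats in place of tensors, and assumes the state grid carries an extra zero row and column at index 0 so that recurrent indices start at 1. None of this layout is confirmed by this page.

```python
from typing import List, Tuple


def neighbor_states(states: List[List[float]], i: int, j: int) -> Tuple[float, float, float]:
    """Return the (top, left, diagonal) states that cell (i, j) depends on.

    Assumes `states` has one extra zero row and column at index 0, so
    valid recurrent indices start at 1. Hypothetical helper.
    """
    if i < 1 or j < 1:
        raise IndexError("i and j must be >= 1 when the grid has a zero border.")
    top = states[i - 1][j]
    left = states[i][j - 1]
    diagonal = states[i - 1][j - 1]
    return top, left, diagonal


# A 2 x 2 interaction matrix needs a 3 x 3 state grid with a zero border.
states = [[0.0, 0.0, 0.0],
          [0.0, 0.5, 0.0],
          [0.0, 0.0, 0.0]]
print(neighbor_states(states, 1, 1))  # (0.0, 0.0, 0.0): only border states
print(neighbor_states(states, 2, 2))  # (0.0, 0.0, 0.5): diagonal is cell (1, 1)
```

The zero border means the first row and column need no special-casing: their missing neighbours simply contribute zero states.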

forward(self, inputs)

Perform SpatialGRU on the word interaction matrix.

Parameters:
  • inputs – Input tensors.