matchzoo.modules.spatial_gru

Spatial GRU module.

Module Contents

Classes

SpatialGRU

Spatial GRU Module.

class matchzoo.modules.spatial_gru.SpatialGRU(channels: int = 4, units: int = 10, activation: typing.Union[str, typing.Type[nn.Module], nn.Module] = 'tanh', recurrent_activation: typing.Union[str, typing.Type[nn.Module], nn.Module] = 'sigmoid', direction: str = 'lt')

Bases: torch.nn.Module

Spatial GRU Module.

Parameters
  • channels – Number of word interaction tensor channels.

  • units – Number of SpatialGRU units.

  • activation

    Activation function to use, one of:

    • String: name of an activation

    • Torch Module subclass

    • Torch Module instance

    Default: hyperbolic tangent (tanh).

  • recurrent_activation

    Activation function to use for the recurrent step, one of:

    • String: name of an activation

    • Torch Module subclass

    • Torch Module instance

    Default: sigmoid activation (sigmoid).

  • direction – Scanning direction. lt (left top) scans from the top-left corner of the interaction matrix to the bottom-right corner, and rb (right bottom) scans from the bottom-right corner to the top-left corner.
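The activation and recurrent_activation parameters accept three kinds of value. The dispatch can be sketched as below; resolve_activation and its name table are hypothetical, assumed for illustration, and the set of names the library actually accepts may differ.

```python
from torch import nn

# Hypothetical name-to-class table; the library's own list of names may differ.
_ACTIVATIONS = {"tanh": nn.Tanh, "sigmoid": nn.Sigmoid, "relu": nn.ReLU}


def resolve_activation(spec):
    """Turn a string, an nn.Module subclass, or an nn.Module instance into a module."""
    if isinstance(spec, str):
        try:
            return _ACTIVATIONS[spec]()  # look up the name, then instantiate
        except KeyError:
            raise ValueError(f"unknown activation name: {spec!r}") from None
    if isinstance(spec, type) and issubclass(spec, nn.Module):
        return spec()  # a class: instantiate it with default arguments
    if isinstance(spec, nn.Module):
        return spec  # already an instance: use it as-is
    raise TypeError(f"unsupported activation spec: {spec!r}")


for spec in ("tanh", nn.Sigmoid, nn.ReLU()):
    print(type(resolve_activation(spec)).__name__)  # Tanh, then Sigmoid, then ReLU
```

Passing an instance keeps any state or custom arguments the caller configured, while a string or class yields a fresh module with default arguments.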

Examples

>>> import matchzoo as mz
>>> channels, units = 4, 10
>>> spatial_gru = mz.modules.SpatialGRU(channels, units)

reset_parameters(self)

Initialize parameters.

softmax_by_row(self, z: torch.tensor) → tuple

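Split z into the four gates and apply softmax across the gates separately for each unit, so that the four gate values at every unit sum to 1. Returns the four gate tensors.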
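A minimal NumPy sketch of this step, assuming z has shape [B, 4*U] with the four gates packed one after another and returned in the order input, left, top, diagonal (the ordering is an assumption). It illustrates the math and is not the module's own code.

```python
import numpy as np


def softmax_by_row(z: np.ndarray, units: int):
    """Split z of shape [B, 4*U] into four gates and softmax across them per unit."""
    # [B, 4*U] -> [B, 4, U]: axis 1 indexes the gate, axis 2 the unit
    z = z.reshape(-1, 4, units)
    # subtract the per-unit maximum for numerical stability
    e = np.exp(z - z.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)
    # four [B, U] arrays, assumed order: input, left, top, diagonal
    return probs[:, 0], probs[:, 1], probs[:, 2], probs[:, 3]


rng = np.random.default_rng(0)
zi, zl, zt, zd = softmax_by_row(rng.normal(size=(2, 4 * 3)), units=3)
print(zi.shape, np.allclose(zi + zl + zt + zd, 1.0))  # (2, 3) True
```

Because the gates compete through one softmax, they act as mixing weights whose sum is fixed, rather than as four independent sigmoid gates.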

calculate_recurrent_unit(self, inputs: torch.tensor, states: list, i: int, j: int)

Calculate recurrent unit.

Parameters
  • inputs – A tensor containing the interactions between the left text and the right text.

  • states – A nested list of tensors storing the hidden state computed at every step.

  • i – Recurrent row index.

  • j – Recurrent column index.
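The update for one cell (i, j) combines the interaction s_ij with the hidden states of its left, top and diagonal neighbours. The NumPy sketch below is written under stated assumptions: inputs laid out as [L, R, B, C], states as an (L+1) x (R+1) zero-padded grid, bias-free weights, and sigmoid and tanh standing in for the default recurrent_activation and activation. init_params, gate_softmax and recurrent_unit are hypothetical helpers, not the module's API.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gate_softmax(z, units):
    """Softmax across the four gates per unit; returns four [B, U] arrays."""
    z = z.reshape(-1, 4, units)
    e = np.exp(z - z.max(axis=1, keepdims=True))
    p = e / e.sum(axis=1, keepdims=True)
    return p[:, 0], p[:, 1], p[:, 2], p[:, 3]


def init_params(channels, units, rng):
    """Hypothetical bias-free weights; q is laid out as [h_left, h_top, h_diag, s_ij]."""
    d = 3 * units + channels
    return {
        "Wr": rng.normal(scale=0.1, size=(d, 3 * units)),     # reset gate
        "Wz": rng.normal(scale=0.1, size=(d, 4 * units)),     # four mixing gates
        "W": rng.normal(scale=0.1, size=(channels, units)),   # s_ij -> candidate
        "U": rng.normal(scale=0.1, size=(3 * units, units)),  # states -> candidate
    }


def recurrent_unit(inputs, states, i, j, params, units):
    """Compute h_ij from s_ij and the left, top and diagonal neighbour states."""
    h_diag, h_top, h_left = states[i][j], states[i][j + 1], states[i + 1][j]
    s_ij = inputs[i][j]                                        # [B, C]
    h_prev = np.concatenate([h_left, h_top, h_diag], axis=1)   # [B, 3U]
    q = np.concatenate([h_prev, s_ij], axis=1)                 # [B, 3U + C]
    r = sigmoid(q @ params["Wr"])                              # reset gate, [B, 3U]
    zi, zl, zt, zd = gate_softmax(q @ params["Wz"], units)     # each [B, U]
    h_cand = np.tanh(s_ij @ params["W"] + (r * h_prev) @ params["U"])
    # weighted mix of the three neighbours and the candidate state
    return zl * h_left + zt * h_top + zd * h_diag + zi * h_cand


rng = np.random.default_rng(0)
L, R, B, C, U = 3, 4, 2, 4, 10
inputs = rng.normal(size=(L, R, B, C))
states = [[np.zeros((B, U)) for _ in range(R + 1)] for _ in range(L + 1)]
params = init_params(C, U, rng)
h = recurrent_unit(inputs, states, 0, 0, params, U)
print(h.shape)  # (2, 10)
```

Since the four gate weights sum to 1, h_ij is a convex combination of values in (-1, 1) whenever the neighbour states are, which keeps the hidden state bounded across the scan.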

forward(self, inputs)

Run the spatial GRU over the word interaction matrix.

Parameters
  • inputs – Word interaction tensor between the left text and the right text.
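The scan order behind forward can be sketched as follows. This assumes inputs of shape [B, C, L, R], a zero-padded (L+1) x (R+1) grid of hidden states, and that the state of the last cell is returned; spatial_scan and its cell callback are hypothetical names. For rb, both text axes are flipped so that the same top-left-to-bottom-right loop effectively starts at the bottom-right corner.

```python
import numpy as np


def spatial_scan(inputs, units, cell, direction="lt"):
    """Scan a [B, C, L, R] interaction tensor cell by cell.

    cell(grid, states, i, j) must return the [B, U] hidden state h_ij.
    """
    batch, _, left_len, right_len = inputs.shape
    grid = inputs.transpose(2, 3, 0, 1)  # [L, R, B, C]
    if direction == "rb":
        # flipping both text axes turns a right-bottom scan into a left-top one
        grid = grid[::-1, ::-1]
    elif direction != "lt":
        raise ValueError(f"unknown direction: {direction!r}")
    # (L+1) x (R+1) grid of hidden states; row 0 and column 0 stay zero
    states = [[np.zeros((batch, units)) for _ in range(right_len + 1)]
              for _ in range(left_len + 1)]
    for i in range(left_len):
        for j in range(right_len):
            states[i + 1][j + 1] = cell(grid, states, i, j)
    return states[left_len][right_len]


x = np.arange(12, dtype=float).reshape(1, 1, 3, 4)  # B=1, C=1, L=3, R=4
take_input = lambda g, s, i, j: g[i][j]  # toy cell: return the current input
print(spatial_scan(x, 1, take_input, "lt")[0, 0])  # 11.0 (bottom-right input)
print(spatial_scan(x, 1, take_input, "rb")[0, 0])  # 0.0 (top-left input)
```

Each cell depends only on its left, top and diagonal neighbours, so the double loop in row-major order always has those states ready; a toy cell computing a 2D prefix sum returns the total of all inputs in either direction.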