matchzoo.modules.spatial_gru
Spatial GRU module.
Module Contents
Classes
SpatialGRU – Spatial GRU Module.
class matchzoo.modules.spatial_gru.SpatialGRU(channels: int = 4, units: int = 10, activation: typing.Union[str, typing.Type[nn.Module], nn.Module] = 'tanh', recurrent_activation: typing.Union[str, typing.Type[nn.Module], nn.Module] = 'sigmoid', direction: str = 'lt')
Bases: torch.nn.Module
Spatial GRU Module.
Parameters
channels – Number of word interaction tensor channels.
units – Number of SpatialGRU units.
activation –
Activation function to use, one of:
String: name of an activation
Torch Module subclass
Torch Module instance
Default: hyperbolic tangent (tanh).
recurrent_activation –
Activation function to use for the recurrent step, one of:
String: name of an activation
Torch Module subclass
Torch Module instance
Default: sigmoid activation (sigmoid).
direction – Scanning direction. lt (left top) scans from the top-left corner to the bottom-right corner, and rb (right bottom) scans from the bottom-right corner to the top-left corner.
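To make the two scanning directions concrete, the plain-Python sketch below lists the order in which the cells (i, j) of a left_length × right_length interaction matrix would be visited. The helper `scan_order` is hypothetical and only illustrates the traversal; it assumes that `rb` is equivalent to walking the `lt` order in reverse, which is the same as flipping the matrix along both axes and scanning from the top-left.

```python
from typing import List, Tuple


def scan_order(left_length: int, right_length: int, direction: str = "lt") -> List[Tuple[int, int]]:
    """Return the (row, column) visiting order of a spatial scan (hypothetical helper)."""
    if direction not in ("lt", "rb"):
        raise ValueError(f"direction must be 'lt' or 'rb', got {direction!r}")
    # Row-major walk from the top-left cell to the bottom-right cell.
    order = [(i, j) for i in range(left_length) for j in range(right_length)]
    if direction == "rb":
        # Starting at the bottom-right corner is the same walk taken backwards.
        order.reverse()
    return order


print(scan_order(2, 3, "lt"))  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
print(scan_order(2, 3, "rb"))  # [(1, 2), (1, 1), (1, 0), (0, 2), (0, 1), (0, 0)]
```

In either order, by the time a cell is reached its three predecessors in the scanning direction (top, left, and diagonal for `lt`; bottom, right, and diagonal for `rb`) have already been visited, which is what lets each step reuse their hidden states.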
Examples
>>> import matchzoo as mz
>>> channels, units = 4, 10
>>> spatial_gru = mz.modules.SpatialGRU(channels, units)
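The constructor in the example above accepts `activation` and `recurrent_activation` as a string, a Torch Module subclass, or a Torch Module instance. The sketch below shows one way such a value could be normalized into a module instance. `resolve_activation` and its small name table are hypothetical and are not MatchZoo's own helper.

```python
from typing import Type, Union

import torch
from torch import nn

# Hypothetical name-to-class table; only a few common activations are listed.
_ACTIVATIONS = {
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "relu": nn.ReLU,
}


def resolve_activation(identifier: Union[str, Type[nn.Module], nn.Module]) -> nn.Module:
    """Turn a string, Module subclass, or Module instance into a Module instance."""
    if isinstance(identifier, str):
        try:
            return _ACTIVATIONS[identifier.lower()]()
        except KeyError:
            raise ValueError(f"unknown activation name: {identifier!r}") from None
    if isinstance(identifier, type) and issubclass(identifier, nn.Module):
        # A subclass is instantiated with its default arguments.
        return identifier()
    if isinstance(identifier, nn.Module):
        # An instance is used as-is.
        return identifier
    raise TypeError(f"unsupported activation: {identifier!r}")


x = torch.tensor([0.0])
print(resolve_activation("sigmoid")(x))  # tensor([0.5000])
```

Accepting all three forms lets callers pass a plain name for the common cases while still being able to supply a pre-configured module.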
reset_parameters(self)
Initialize parameters.
softmax_by_row(self, z: torch.tensor) → tuple
For each unit, apply a softmax across the four gates, and return the four normalized gate tensors as a tuple.
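A minimal sketch of this normalization, assuming the pre-activation tensor `z` has shape [batch, 4 * units] with the four gates stored as consecutive blocks of `units` values. The standalone function below is illustrative and takes `units` explicitly instead of reading it from the module.

```python
import torch
import torch.nn.functional as F


def softmax_by_row(z: torch.Tensor, units: int) -> tuple:
    """Split z of shape [B, 4 * units] into four gates normalized against each other."""
    # Put the four gates on their own axis: [B, 4, units].
    z = z.reshape(-1, 4, units)
    # For every unit, softmax over the 4 gates, then split into four [B, units] tensors.
    return F.softmax(z, dim=1).unbind(dim=1)


z = torch.randn(2, 4 * 3)
zi, zl, zt, zd = softmax_by_row(z, units=3)
print(zi.shape)  # torch.Size([2, 3])
# The four gate values at each position sum to one.
print(torch.allclose(zi + zl + zt + zd, torch.ones(2, 3)))  # True
```

Because the softmax runs over the gate axis rather than over the units, each unit gets its own convex mixture of the four gates.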
calculate_recurrent_unit(self, inputs: torch.tensor, states: list, i: int, j: int)
Compute the hidden state of the recurrent unit at row i, column j.
Parameters
inputs – A tensor containing the interaction between the left text and the right text.
states – A list of tensors storing the hidden state of every step.
i – Recurrent row index.
j – Recurrent column index.
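One recurrent step can be sketched as follows, loosely following the spatial GRU used in Match-SRNN: the new state at (i, j) mixes the top, left, and diagonal neighbour states with a candidate state through four softmax-normalized update gates, while reset gates scale the neighbour states inside the candidate. The class `SpatialGRUCellSketch`, its layer names, and the fixed sigmoid/tanh choices are assumptions for illustration, not MatchZoo's exact implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SpatialGRUCellSketch(nn.Module):
    """Hypothetical cell computing h_ij from s_ij and the top, left, diagonal states."""

    def __init__(self, channels: int, units: int):
        super().__init__()
        self.units = units
        in_dim = 3 * units + channels
        self.w_r = nn.Linear(in_dim, 3 * units)           # reset gates for the 3 neighbours
        self.w_z = nn.Linear(in_dim, 4 * units)           # update gates (i, l, t, d)
        self.w_ij = nn.Linear(channels, units)            # input-to-candidate weights
        self.u = nn.Linear(3 * units, units, bias=False)  # neighbours-to-candidate weights

    def forward(self, s_ij, h_top, h_left, h_diag):
        q = torch.cat([h_top, h_left, h_diag, s_ij], dim=1)          # [B, 3U + C]
        r = torch.sigmoid(self.w_r(q))                                # [B, 3U]
        z = F.softmax(self.w_z(q).reshape(-1, 4, self.units), dim=1)  # [B, 4, U]
        zi, zl, zt, zd = z.unbind(dim=1)                              # each [B, U]
        # Candidate state: the reset gates scale the neighbour states.
        neighbours = torch.cat([h_left, h_top, h_diag], dim=1)
        h_cand = torch.tanh(self.w_ij(s_ij) + self.u(r * neighbours))
        # New state: gated mix of the three neighbours and the candidate.
        return zl * h_left + zt * h_top + zd * h_diag + zi * h_cand


cell = SpatialGRUCellSketch(channels=4, units=10)
h = torch.zeros(2, 10)
out = cell(torch.randn(2, 4), h, h, h)
print(out.shape)  # torch.Size([2, 10])
```

Since the four update gates sum to one per unit, the new state is a convex combination of the neighbours and the candidate, which keeps it bounded when the candidate is squashed by tanh.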
forward(self, inputs)
Perform SpatialGRU on the word interaction matrix.
Parameters
inputs – The word interaction tensor between the left text and the right text.
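The scan performed by `forward` can be sketched as a dynamic program over an (L + 1) × (R + 1) grid of hidden states whose first row and column are zero padding, so every cell can read its top, left, and diagonal neighbours. The function `spatial_scan`, the assumed input layout [batch, channels, left_length, right_length], and the toy `prefix_sum_cell` are illustrative assumptions; the toy cell is chosen only because its final state (the sum of all interactions) is easy to check.

```python
from typing import Callable

import torch


def spatial_scan(inputs: torch.Tensor, units: int, cell: Callable, direction: str = "lt") -> torch.Tensor:
    """Run a cell over a [B, C, L, R] interaction tensor and return the final [B, units] state."""
    batch_size, _, left_length, right_length = inputs.shape
    # Move the grid axes to the front: [L, R, B, C].
    grid = inputs.permute(2, 3, 0, 1)
    if direction == "rb":
        # Scanning from the bottom-right equals scanning the flipped grid from the top-left.
        grid = torch.flip(grid, [0, 1])
    # (L + 1) x (R + 1) states; row 0 and column 0 are zero padding.
    states = [[inputs.new_zeros(batch_size, units) for _ in range(right_length + 1)]
              for _ in range(left_length + 1)]
    for i in range(left_length):
        for j in range(right_length):
            # Arguments: s_ij, h_top, h_left, h_diag.
            states[i + 1][j + 1] = cell(grid[i][j], states[i][j + 1], states[i + 1][j], states[i][j])
    return states[left_length][right_length]


def prefix_sum_cell(s_ij, h_top, h_left, h_diag):
    # Toy cell: 2-D prefix sum of the channel-summed interactions.
    return h_top + h_left - h_diag + s_ij.sum(dim=1, keepdim=True).expand_as(h_top)


x = torch.ones(2, 4, 3, 5)  # B=2, C=4, L=3, R=5
final = spatial_scan(x, units=10, cell=prefix_sum_cell)
print(final[0, 0].item())  # 60.0
```

Any cell with the signature (s_ij, h_top, h_left, h_diag), such as a spatial GRU step, can be dropped in place of the toy cell; only the bottom-right state is returned.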