matchzoo.modules.spatial_gru
Spatial GRU module.
Module Contents
class matchzoo.modules.spatial_gru.SpatialGRU(channels: int = 4, units: int = 10, activation: typing.Union[str, typing.Type[nn.Module], nn.Module] = 'tanh', recurrent_activation: typing.Union[str, typing.Type[nn.Module], nn.Module] = 'sigmoid', direction: str = 'lt')
Bases: torch.nn.Module
Spatial GRU Module.
Parameters:
- channels – Number of channels in the word interaction tensor.
- units – Number of SpatialGRU units (the size of the hidden state).
- activation – Activation function to use, one of:
  - String: name of an activation
  - Torch Module subclass
  - Torch Module instance
  Default: hyperbolic tangent (tanh).
- recurrent_activation – Activation function to use for the recurrent step, one of:
  - String: name of an activation
  - Torch Module subclass
  - Torch Module instance
  Default: sigmoid activation (sigmoid).
- direction – Scanning direction. lt (left top) scans from the top-left corner to the bottom-right corner; rb (right bottom) scans from the bottom-right corner to the top-left corner.
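The two values of direction differ only in the order in which cells of the interaction matrix are visited. The sketch below lists that order for a small grid. It is an illustration, not MatchZoo code: the helper name scan_order is hypothetical, and it assumes rows are the outer loop and columns the inner loop in both directions.

```python
def scan_order(n_rows: int, n_cols: int, direction: str = "lt") -> list:
    """Return the (row, col) visiting order of a spatial scan.

    Hypothetical helper; assumes row-major scanning for both directions.
    """
    if direction not in ("lt", "rb"):
        raise ValueError(f"unknown direction: {direction!r}")
    # 'lt': top-left cell first, bottom-right cell last.
    cells = [(i, j) for i in range(n_rows) for j in range(n_cols)]
    # 'rb': the exact reverse, so each cell comes after its
    # right, bottom and bottom-right neighbours.
    return cells if direction == "lt" else cells[::-1]


print(scan_order(2, 3, "lt"))
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
print(scan_order(2, 3, "rb"))
# [(1, 2), (1, 1), (1, 0), (0, 2), (0, 1), (0, 0)]
```

In either direction, every cell is processed only after the three neighbours it depends on.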
Examples
>>> import matchzoo as mz
>>> channels, units = 4, 10
>>> spatial_gru = mz.modules.SpatialGRU(channels, units)
reset_parameters(self)
Initialize parameters.
softmax_by_row(self, z: torch.tensor)
Apply a softmax across the four gates, separately for each unit (dimension), so that the four gate values of every unit sum to 1.
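A plain-Python sketch of this row-wise softmax for a single example follows. It is illustrative only: it assumes the gate logits arrive as one flat vector of length 4 * units laid out gate by gate (as a reshape to [4, units] would produce), and it names the gates zi, zl, zt, zd (input/candidate, left, top, diagonal) after the Spatial GRU formulation. The real method works on batched torch tensors.

```python
import math


def softmax_by_row(z: list, units: int) -> tuple:
    """Split z (length 4 * units) into four gates and softmax each unit.

    Assumed layout: z[g * units + u] is gate g for unit u.
    """
    if len(z) != 4 * units:
        raise ValueError("expected a vector of length 4 * units")
    gates = [z[g * units:(g + 1) * units] for g in range(4)]
    out = [[0.0] * units for _ in range(4)]
    for u in range(units):
        column = [gates[g][u] for g in range(4)]
        m = max(column)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in column]
        total = sum(exps)
        for g in range(4):
            out[g][u] = exps[g] / total
    zi, zl, zt, zd = out
    return zi, zl, zt, zd


zi, zl, zt, zd = softmax_by_row([0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0], units=2)
print(zi[0])  # 0.25 (unit 0 has four equal logits)
```

Normalizing across gates rather than across units is what lets the four gates act as mixing weights for the new hidden state.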
calculate_recurrent_unit(self, inputs: torch.tensor, states: list, i: int, j: int)
Compute the hidden state of the recurrent unit at position (i, j).
Parameters:
- inputs – A tensor containing the interactions between the left text and the right text.
- states – A list of tensors storing the hidden state of every step.
- i – Row index of the recurrent step.
- j – Column index of the recurrent step.
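The step at (i, j) combines the interaction s_ij with the hidden states of its top, left and diagonal neighbours. Below is a minimal sketch for one unit and one channel, following the Spatial GRU equations with tanh and sigmoid (the documented defaults). The parameter dictionary and its keys (w_r, w_z, w, u) are hypothetical, biases are omitted, and this is not MatchZoo's implementation.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def dot(w: list, v: list) -> float:
    return sum(a * b for a, b in zip(w, v))


def recurrent_unit(s_ij: float, h_top: float, h_left: float,
                   h_diag: float, params: dict) -> float:
    """One Spatial GRU step for a single unit and channel (hypothetical params).

    params["w_r"]: 3 rows of 4 weights (reset gates for left, top, diagonal)
    params["w_z"]: 4 rows of 4 weights (gates zi, zl, zt, zd)
    params["w"]:   weight on s_ij in the candidate state
    params["u"]:   3 weights on the reset-scaled neighbour states
    """
    # Gate input: the three neighbour states plus the current interaction.
    q = [h_top, h_left, h_diag, s_ij]
    # Reset gates, one per neighbour.
    r = [sigmoid(dot(row, q)) for row in params["w_r"]]
    # Update gates: softmax across the four gates.
    logits = [dot(row, q) for row in params["w_z"]]
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    zi, zl, zt, zd = (e / total for e in exps)
    # Candidate state from the interaction and the reset-scaled neighbours.
    reset = [ri * h for ri, h in zip(r, [h_left, h_top, h_diag])]
    h_cand = math.tanh(params["w"] * s_ij + dot(params["u"], reset))
    # New state: gated mix of left, top, diagonal and candidate states.
    return zl * h_left + zt * h_top + zd * h_diag + zi * h_cand


zero = {"w_r": [[0.0] * 4] * 3, "w_z": [[0.0] * 4] * 4, "w": 0.0, "u": [0.0] * 3}
print(recurrent_unit(0.5, 1.0, 1.0, 1.0, zero))  # 0.75
```

With all weights at zero, each gate equals 0.25 and the candidate is tanh(0) = 0, so the result is the average contribution of the three neighbours, 0.75.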
forward(self, inputs)
Perform SpatialGRU on the word interaction matrix.
Parameters: inputs – Input tensors.
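The overall scan can be sketched as a dynamic-programming pass over the grid: a table of states padded with one boundary row and column, filled cell by cell, with the bottom-right state returned. This sketch assumes one scalar state per cell, zero-initialized boundary states, and rb handled by flipping the matrix along both axes. The function spatial_scan and the toy cell are hypothetical stand-ins; the real module works on batched, multi-channel tensors with the learned recurrent unit.

```python
def spatial_scan(matrix, cell, direction="lt", init=0.0):
    """Scan a 2-D interaction matrix and return the final hidden state.

    cell(s_ij, h_top, h_left, h_diag) computes one step (hypothetical API).
    """
    if direction == "rb":
        # Flip both axes so a top-left scan covers the original
        # matrix from its bottom-right corner.
        matrix = [row[::-1] for row in matrix[::-1]]
    n_rows, n_cols = len(matrix), len(matrix[0])
    # states[i + 1][j + 1] holds h_ij; row 0 and column 0 are the boundary.
    states = [[init] * (n_cols + 1) for _ in range(n_rows + 1)]
    for i in range(n_rows):
        for j in range(n_cols):
            states[i + 1][j + 1] = cell(
                matrix[i][j],
                states[i][j + 1],  # top neighbour h_{i-1,j}
                states[i + 1][j],  # left neighbour h_{i,j-1}
                states[i][j],      # diagonal neighbour h_{i-1,j-1}
            )
    return states[n_rows][n_cols]


# Toy cell (not a GRU): best cumulative score over monotone paths.
def best_path(s, top, left, diag):
    return s + max(top, left, diag)


print(spatial_scan([[1, 2], [3, 4]], best_path))  # 8
```

Because each state depends only on already-computed neighbours, the final state summarizes the whole matrix, which is why it can serve as the matching representation.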