matchzoo.modules.matching

Matching module.

Module Contents

class matchzoo.modules.matching.Matching(normalize: bool = False, matching_type: str = 'dot')

Bases: torch.nn.Module

Module that computes a matching matrix between samples in two tensors.

Parameters:
  • normalize – Whether to L2-normalize samples along the dot-product axis before taking the dot product. If set to True, the output of the dot product is the cosine similarity between the two samples.
  • matching_type – The similarity function used for matching. Defaults to 'dot'.
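The normalize flag relies on a simple identity: once both vectors are scaled to unit L2 norm, their dot product equals their cosine similarity. The sketch below checks this in plain Python so it runs without PyTorch. The helper names (l2_normalize, dot, cosine_similarity) are hypothetical, and the eps guard mirrors the documented behaviour of torch.nn.functional.normalize; it is an assumption that Matching normalizes in an equivalent way.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector, eps: float = 1e-12) -> Vector:
    """Scale v to unit L2 norm; eps guards against division by zero (hypothetical helper)."""
    norm = math.sqrt(sum(c * c for c in v))
    return [c / max(norm, eps) for c in v]


def dot(a: Vector, b: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(p * q for p, q in zip(a, b))


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Textbook definition: dot(a, b) / (|a| * |b|)."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a, b = [3.0, 4.0], [4.0, 3.0]
# Normalizing first and then taking the dot product gives the cosine similarity.
print(round(dot(l2_normalize(a), l2_normalize(b)), 6))  # 0.96
print(round(cosine_similarity(a, b), 6))                # 0.96
```

Without normalization the same pair scores 24.0, so normalize=True trades raw magnitude for a bounded score in [-1, 1].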

Examples

>>> import torch
>>> matching = Matching(matching_type='dot', normalize=True)
>>> x = torch.randn(2, 3, 2)
>>> y = torch.randn(2, 4, 2)
>>> matching(x, y).shape
torch.Size([2, 3, 4])

classmethod _validate_matching_type(cls, matching_type: str = 'dot')

Check that matching_type names a supported similarity function.
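A classmethod validator like this one is typically called from the constructor so that an unsupported matching_type fails immediately rather than inside forward. The sketch below shows that pattern only. MatchingSketch and VALID_MATCHING_TYPES are hypothetical names, the error type (ValueError) is an assumption, and the accepted set contains only 'dot' because that is the only value this page confirms; the real module's list should be taken from its source.

```python
# Hypothetical placeholder: only 'dot' is confirmed by this page.
# Extend with the other names the real module accepts.
VALID_MATCHING_TYPES = ('dot',)


class MatchingSketch:
    """Minimal stand-in showing the validate-in-constructor pattern."""

    def __init__(self, normalize: bool = False, matching_type: str = 'dot'):
        self._validate_matching_type(matching_type)  # fail fast on bad input
        self._normalize = normalize
        self._matching_type = matching_type

    @classmethod
    def _validate_matching_type(cls, matching_type: str = 'dot') -> None:
        # Assumed behaviour: reject unknown names with a ValueError.
        if matching_type not in VALID_MATCHING_TYPES:
            raise ValueError(
                f"{matching_type!r} is not a valid matching type; "
                f"expected one of {VALID_MATCHING_TYPES}."
            )
```

Making the check a classmethod lets it run before any instance state is set up, and it can also be called directly as MatchingSketch._validate_matching_type('dot').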
forward(self, x, y)

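Compute the matching matrix between x and y. With matching_type='dot', inputs of shape (batch, left_len, dim) and (batch, right_len, dim) yield a matrix of shape (batch, left_len, right_len).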
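To make the 'dot' case concrete, here is a dependency-free reference sketch of the computation: entry [b][i][j] is the dot product of the i-th left sample and the j-th right sample in batch b, which is the contraction torch.einsum('bld,brd->blr', x, y) expresses. The function names dot_matching and shape are hypothetical, and nested lists stand in for tensors so the block runs without PyTorch; it illustrates the math, not the module's actual code. It reproduces the shapes of the doctest above.

```python
import math
import random
from typing import List

Matrix = List[List[float]]  # one batch element: (length, dim)


def _unit(v: List[float], eps: float = 1e-12) -> List[float]:
    """Scale v to unit L2 norm (used when normalize=True)."""
    norm = math.sqrt(sum(c * c for c in v))
    return [c / max(norm, eps) for c in v]


def dot_matching(x: List[Matrix], y: List[Matrix], normalize: bool = False) -> List[Matrix]:
    """Reference 'dot' matching: out[b][i][j] = <x[b][i], y[b][j]>.

    x is (B, L, D) and y is (B, R, D); the result is (B, L, R).
    """
    if len(x) != len(y):
        raise ValueError("x and y must share the batch dimension")
    out = []
    for xb, yb in zip(x, y):
        if normalize:
            xb = [_unit(v) for v in xb]
            yb = [_unit(v) for v in yb]
        # Row i holds the similarities of left sample i to every right sample.
        out.append([[sum(p * q for p, q in zip(u, w)) for w in yb] for u in xb])
    return out


def shape(t) -> tuple:
    """Shape of a regular nested list, e.g. (2, 3, 4)."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0] if t else None
    return tuple(dims)


random.seed(0)
x = [[[random.gauss(0, 1) for _ in range(2)] for _ in range(3)] for _ in range(2)]
y = [[[random.gauss(0, 1) for _ in range(2)] for _ in range(4)] for _ in range(2)]
m = dot_matching(x, y, normalize=True)
print(shape(m))  # (2, 3, 4), mirroring torch.Size([2, 3, 4]) in the example
```

Because normalize=True is set, every entry of m lies in [-1, 1]; with normalize=False the entries are raw dot products.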