matchzoo.dataloader.dataset

A basic class representing a Dataset.

Module Contents

class matchzoo.dataloader.dataset.Dataset(data_pack: mz.DataPack, mode='point', num_dup: int = 1, num_neg: int = 1, batch_size: int = 32, resample: bool = False, shuffle: bool = True, sort: bool = False, callbacks: typing.List[BaseCallback] = None)

Bases: torch.utils.data.IterableDataset

Dataset that is built from a data pack.

Parameters:
  • data_pack – DataPack from which the dataset is built.
  • mode – One of “point”, “pair”, and “list”. (default: “point”)
  • num_dup – Number of duplications per instance, only effective when mode is “pair”. (default: 1)
  • num_neg – Number of negative samples per instance, only effective when mode is “pair”. (default: 1)
  • batch_size – Batch size. (default: 32)
  • resample – Whether to resample for each epoch, only effective when mode is “pair”. (default: False)
  • shuffle – Whether to shuffle the samples/instances. (default: True)
  • sort – Whether to sort data according to length_right. (default: False)
  • callbacks – List of callbacks. See matchzoo.dataloader.callbacks for more details. (default: None)
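The private helpers _handle_callbacks_on_batch_data_pack and _handle_callbacks_on_batch_unpacked documented on this page suggest that callbacks are invoked at two points per batch. The sketch below illustrates that dispatch pattern under assumptions: the hook names on_batch_data_pack and on_batch_unpacked are chosen to mirror the helper names, and SketchCallback, AddTextLength and the handle_* functions are hypothetical, not MatchZoo's API.

```python
from typing import Any, Dict, List


class SketchCallback:
    """Hypothetical base class with two per-batch hooks (no-ops by default)."""

    def on_batch_data_pack(self, data_pack: Any) -> None:
        pass

    def on_batch_unpacked(self, x: Dict[str, Any], y: Any) -> None:
        pass


class AddTextLength(SketchCallback):
    """Hypothetical callback: add the length of each right-hand text to x."""

    def on_batch_unpacked(self, x: Dict[str, Any], y: Any) -> None:
        x["length_right"] = [len(text) for text in x["text_right"]]


def handle_on_batch_data_pack(callbacks: List[SketchCallback],
                              data_pack: Any) -> None:
    # Give every callback a chance to inspect the batch before unpacking.
    for callback in callbacks:
        callback.on_batch_data_pack(data_pack)


def handle_on_batch_unpacked(callbacks: List[SketchCallback],
                             x: Dict[str, Any], y: Any) -> None:
    # Call the callbacks in list order; each may modify x and y in place.
    for callback in callbacks:
        callback.on_batch_unpacked(x, y)


x = {"text_right": [[3, 1, 4], [1, 5]]}
y = [1, 0]
handle_on_batch_unpacked([AddTextLength()], x, y)
print(x["length_right"])  # [3, 2]
```

Because callbacks run in list order and mutate the batch in place, a later callback can rely on fields added by an earlier one.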

Examples

>>> import matchzoo as mz
>>> data_pack = mz.datasets.toy.load_data(stage='train')
>>> preprocessor = mz.preprocessors.BasicPreprocessor()
>>> data_processed = preprocessor.fit_transform(data_pack)
>>> dataset_point = mz.dataloader.Dataset(
...     data_processed, mode='point', batch_size=32)
>>> len(dataset_point)
4
>>> dataset_pair = mz.dataloader.Dataset(
...     data_processed, mode='pair', num_dup=2, num_neg=2, batch_size=32)
>>> len(dataset_pair)
1

callbacks

callbacks getter.

num_neg

num_neg getter.

num_dup

num_dup getter.

mode

mode getter.

batch_size

batch_size getter.

shuffle

shuffle getter.

sort

sort getter.

resample

resample getter.

batch_indices

batch_indices getter.

__getitem__(self, item)

Get the batch at index item.

Parameters:
  • item – The index of the batch.

__len__(self)

Get the total number of batches.
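The batch count is a ceiling division, since the last batch may be smaller than batch_size. The sketch below assumes, as a labeled assumption, that in “pair” mode one instance is a group of one positive row followed by num_neg negative rows, so batch_size counts groups rather than rows; num_batches is a hypothetical helper, not part of MatchZoo.

```python
import math


def num_batches(num_rows: int, batch_size: int, mode: str = "point",
                num_neg: int = 1) -> int:
    """Hypothetical batch count.

    Assumes pair-mode rows come in groups of one positive followed by
    num_neg negatives, and that batch_size counts such groups.
    """
    if mode == "point":
        num_instances = num_rows
    elif mode == "pair":
        num_instances = num_rows // (num_neg + 1)
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    # The last batch may be partial, hence the ceiling.
    return math.ceil(num_instances / batch_size)


print(num_batches(100, 32))                         # 4
print(num_batches(90, 32, mode="pair", num_neg=2))  # 1
```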

__iter__(self)

Create a generator that iterates over the batches.
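Together, __len__, __getitem__ and __iter__ amount to a lookup table of batches. The sketch below shows that relationship on a hypothetical ToyBatches class, assuming batch_indices[i] lists the row indices of batch i; it is plain Python rather than a torch.utils.data.IterableDataset subclass, to stay self-contained.

```python
from typing import Iterator, List, Sequence


class ToyBatches:
    """Hypothetical minimal container: batch_indices[i] lists the row
    indices that make up batch i."""

    def __init__(self, rows: Sequence, batch_indices: List[List[int]]):
        self.rows = rows
        self.batch_indices = batch_indices

    def __len__(self) -> int:
        # One entry in batch_indices per batch.
        return len(self.batch_indices)

    def __getitem__(self, item: int) -> list:
        # Look up the row indices of batch `item`, then gather those rows.
        return [self.rows[i] for i in self.batch_indices[item]]

    def __iter__(self) -> Iterator[list]:
        # Yield batches lazily, in batch_indices order.
        for i in range(len(self)):
            yield self[i]


batches = ToyBatches(rows=["a", "b", "c", "d", "e"],
                     batch_indices=[[0, 1], [2, 3], [4]])
print(list(batches))  # [['a', 'b'], ['c', 'd'], ['e']]
```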

on_epoch_end(self)

Reorganize the index array if needed.
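One plausible reading of “if needed” is: resample the data when resample is enabled, then rebuild (and reshuffle) the batch index. The hypothetical EpochSketch class below encodes that reading; its resample_data is only a counter standing in for regenerating pair-wise rows, and none of it should be taken as MatchZoo's exact implementation.

```python
import random
from typing import List


class EpochSketch:
    """Hypothetical stand-in for a dataset's end-of-epoch logic."""

    def __init__(self, num_instances: int, batch_size: int,
                 resample: bool = False, seed: int = 0):
        self.num_instances = num_instances
        self.batch_size = batch_size
        self.resample = resample
        self.rng = random.Random(seed)
        self.times_resampled = 0
        self.batch_indices: List[List[int]] = []
        self.reset_index()

    def resample_data(self) -> None:
        # Placeholder: a real dataset would regenerate its pair-wise rows.
        self.times_resampled += 1

    def reset_index(self) -> None:
        # Reshuffle all instance indices and cut them into batches.
        pool = list(range(self.num_instances))
        self.rng.shuffle(pool)
        self.batch_indices = [pool[i:i + self.batch_size]
                              for i in range(0, len(pool), self.batch_size)]

    def on_epoch_end(self) -> None:
        # Resample only when requested, then always rebuild the batch index.
        if self.resample:
            self.resample_data()
        self.reset_index()


epochs = EpochSketch(num_instances=5, batch_size=2, resample=True)
epochs.on_epoch_end()
print(epochs.times_resampled)  # 1
print(sorted(i for batch in epochs.batch_indices for i in batch))
# [0, 1, 2, 3, 4]
```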

resample_data(self)

Reorganize data.

reset_index(self)

Set _batch_indices.

_batch_indices records the indices of all instances, grouped into batches.
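The index construction can be sketched as: group rows into instances, optionally shuffle, optionally sort, then slice into batches of batch_size instances and flatten each batch to row indices. The helper build_batch_indices below is hypothetical; it assumes pair-mode rows are stored as consecutive groups of 1 + num_neg rows, and uses sort_keys (one length per row) as a stand-in for length_right.

```python
import math
import random
from typing import List, Optional, Sequence


def build_batch_indices(num_rows: int, batch_size: int, mode: str = "point",
                        num_neg: int = 1, shuffle: bool = True,
                        sort_keys: Optional[Sequence[int]] = None,
                        seed: int = 0) -> List[List[int]]:
    """Hypothetical sketch of grouping row indices into batches."""
    if mode == "point":
        pool = [[i] for i in range(num_rows)]
    elif mode == "pair":
        # Assumed layout: one positive followed by num_neg negatives.
        step = num_neg + 1
        pool = [list(range(i * step, (i + 1) * step))
                for i in range(num_rows // step)]
    else:
        raise ValueError(f"unsupported mode: {mode!r}")

    if shuffle:
        random.Random(seed).shuffle(pool)
    if sort_keys is not None:
        # Order instances by the longest right-hand text they contain.
        pool.sort(key=lambda group: max(sort_keys[i] for i in group))

    batches = []
    for b in range(math.ceil(len(pool) / batch_size)):
        groups = pool[b * batch_size:(b + 1) * batch_size]
        # Flatten the groups so each batch is a flat list of row indices.
        batches.append([i for group in groups for i in group])
    return batches


print(build_batch_indices(5, 2, shuffle=False))
# [[0, 1], [2, 3], [4]]
print(build_batch_indices(6, 2, mode="pair", num_neg=2, shuffle=False))
# [[0, 1, 2, 3, 4, 5]]
print(build_batch_indices(4, 2, shuffle=False, sort_keys=[9, 3, 7, 1]))
# [[3, 1], [2, 0]]
```

Sorting after shuffling keeps batches of similar length together, which can reduce padding; since Python's sort is stable, ties keep their shuffled order.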

_handle_callbacks_on_batch_data_pack(self, batch_data_pack)

_handle_callbacks_on_batch_unpacked(self, x, y)

classmethod _reorganize_pair_wise(cls, relation: pd.DataFrame, num_dup: int = 1, num_neg: int = 1)

Reorganize the relation into pair-wise format.
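A pair-wise layout can be sketched as: for each query (id_left), treat every label except the lowest as “positive”, repeat each positive row num_dup times, and follow each copy with num_neg negatives drawn from that query's rows with a strictly lower label. The sketch below is an assumption-laden illustration, not MatchZoo's code: it uses plain (id_left, id_right, label) tuples instead of a pd.DataFrame, samples negatives with replacement, and reorganize_pair_wise is a hypothetical name.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Tuple[str, str, int]  # (id_left, id_right, label)


def reorganize_pair_wise(relation: List[Row], num_dup: int = 1,
                         num_neg: int = 1, seed: int = 0) -> List[Row]:
    """Hypothetical sketch: each positive row is followed by num_neg
    negatives sampled (with replacement) from lower-labeled rows."""
    rng = random.Random(seed)
    groups: Dict[str, List[Row]] = defaultdict(list)
    for row in relation:
        groups[row[0]].append(row)

    pairs: List[Row] = []
    for id_left in sorted(groups):
        rows = groups[id_left]
        labels = sorted({label for _, _, label in rows}, reverse=True)
        for label in labels[:-1]:  # the lowest label has no lower negatives
            positives = [r for r in rows if r[2] == label] * num_dup
            negatives = [r for r in rows if r[2] < label]
            for pos in positives:
                pairs.append(pos)
                pairs.extend(rng.choices(negatives, k=num_neg))
    return pairs


relation = [("q1", "d1", 1), ("q1", "d2", 0), ("q1", "d3", 0),
            ("q2", "d4", 0)]
pairs = reorganize_pair_wise(relation, num_dup=2, num_neg=2)
print(len(pairs))  # 6
```

Each emitted group has exactly 1 + num_neg rows, and a query whose rows all share one label (q2 here) contributes nothing, since it has no lower-labeled negatives to sample.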