matchzoo.trainers.trainer

Base Trainer.

Module Contents

Classes

Trainer

MatchZoo trainer.

class matchzoo.trainers.trainer.Trainer(model: BaseModel, optimizer: optim.Optimizer, trainloader: DataLoader, validloader: DataLoader, device: typing.Union[torch.device, int, list, None] = None, start_epoch: int = 1, epochs: int = 10, validate_interval: typing.Optional[int] = None, scheduler: typing.Any = None, clip_norm: typing.Union[float, int] = None, patience: typing.Optional[int] = None, key: typing.Any = None, checkpoint: typing.Union[str, Path] = None, save_dir: typing.Union[str, Path] = None, save_all: bool = False, verbose: int = 1, **kwargs)

MatchZoo trainer.

Parameters
  • model – A BaseModel instance.

  • optimizer – An optim.Optimizer instance.

  • trainloader – A DataLoader instance. The dataloader is used for training the model.

  • validloader – A DataLoader instance. The dataloader is used for validating the model.

  • device – The device on which to run the model. If None, the current device is used. If a torch.device or int, the device specified by the user is used. If a list, data parallel is used.

  • start_epoch – Int. The epoch number to start from. Defaults to 1.

  • epochs – The maximum number of epochs for training. Defaults to 10.

  • validate_interval – Int. Interval of validation.

  • scheduler – LR scheduler used to adjust the learning rate based on the number of epochs.

  • clip_norm – Max norm of the gradients to be clipped.

  • patience – Number of events to wait without improvement before stopping the training.

  • key – Key of metric to be compared.

  • checkpoint – A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None. Should be a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.

  • save_dir – Directory to save trainer.

  • save_all – Bool. If True, save the Trainer instance; if False, save only the model. Defaults to False.

  • verbose – 0, 1, or 2. Verbosity mode. 0 = silent, 1 = verbose, 2 = one log line per epoch.
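How patience and key interact can be illustrated with a minimal sketch. EarlyStoppingSketch below is a hypothetical stand-in written for this example, not MatchZoo's own early-stopping class. It assumes that each evaluation yields a dict of metric scores and that a larger score is better.

```python
class EarlyStoppingSketch:
    """Hypothetical helper: stop after `patience` evaluations without improvement."""

    def __init__(self, patience, key):
        self.patience = patience
        self.key = key
        self.best_so_far = None
        self.events_without_improvement = 0

    def update(self, result):
        """Record one evaluation result (a dict mapping metric key to score)."""
        score = result[self.key]
        if self.best_so_far is None or score > self.best_so_far:
            # Improvement: remember the new best and reset the counter.
            self.best_so_far = score
            self.events_without_improvement = 0
        else:
            self.events_without_improvement += 1

    @property
    def should_stop(self):
        return self.events_without_improvement >= self.patience


stopper = EarlyStoppingSketch(patience=2, key="map")
history = []
for score in [0.50, 0.55, 0.54, 0.53, 0.60]:
    stopper.update({"map": score})
    history.append(stopper.should_stop)
    if stopper.should_stop:
        break  # the last score (0.60) is never seen
```

With patience=2, the loop stops after the second consecutive evaluation that fails to beat the best score, so a later improvement is not observed. A larger patience trades longer training for a lower risk of stopping too soon.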

_load_dataloader(self, trainloader: DataLoader, validloader: DataLoader, validate_interval: typing.Optional[int] = None)

Load trainloader and determine validate interval.

Parameters
  • trainloader – A DataLoader instance. The dataloader is used to train the model.

  • validloader – A DataLoader instance. The dataloader is used to validate the model.

  • validate_interval – int. Interval of validation.

_load_model(self, model: BaseModel, device: typing.Union[torch.device, int, list, None] = None)

Load model.

Parameters
  • model – A BaseModel instance.

  • device – The device on which to run the model. If None, the current device is used. If a torch.device or int, the device specified by the user is used. If a list, data parallel is used.

_load_path(self, checkpoint: typing.Union[str, Path], save_dir: typing.Union[str, Path])

Load save_dir and restore from checkpoint.

Parameters
  • checkpoint – A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None. Should be a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.

  • save_dir – Directory to save trainer.

_backward(self, loss)

Compute the gradient of the current loss with respect to the graph leaves.

Parameters

loss – Tensor. Loss of model.

_run_scheduler(self)

Run scheduler.

run(self)

Train model.

The processes:

Run each epoch -> Run scheduler -> Should stop early?
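The control flow above can be sketched as a plain loop. This is an outline under stated assumptions, not the library's code: run_sketch and its three callables are hypothetical stand-ins for the trainer's internals, and the inclusive epoch range from start_epoch to epochs is assumed rather than confirmed by this page.

```python
def run_sketch(epochs, run_epoch, run_scheduler, should_stop, start_epoch=1):
    """Hypothetical outline: run epoch -> run scheduler -> check early stopping."""
    completed = []
    for epoch in range(start_epoch, epochs + 1):  # assumed inclusive of `epochs`
        run_epoch(epoch)
        run_scheduler()
        completed.append(epoch)
        if should_stop():
            break
    return completed


log = []
completed = run_sketch(
    epochs=10,
    run_epoch=lambda epoch: log.append(f"epoch {epoch}"),
    run_scheduler=lambda: log.append("scheduler"),
    should_stop=lambda: len(log) >= 6,  # toy stopping rule for the demo
)
```

Here the toy stopping rule fires after the third epoch, so only three of the ten permitted epochs run.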

_run_epoch(self)

Run each epoch.

The training steps:
  • Get a batch and feed it into the model

  • Get outputs. Calculate all losses and sum them up

  • Backpropagate the loss and step the optimizer

  • Evaluation

  • Update and output result

evaluate(self, dataloader: DataLoader)

Evaluate the model.

Parameters

dataloader – A DataLoader object to iterate over the data.

classmethod _eval_metric_on_data_frame(cls, metric: BaseMetric, id_left: typing.Any, y_true: typing.Union[list, np.array], y_pred: typing.Union[list, np.array])

Evaluate a metric on a data frame.

This function is used to evaluate metrics for ranking tasks.

Parameters
  • metric – Metric for the ranking task.

  • id_left – ID of the left input. Samples with the same id_left should be grouped for evaluation.

  • y_true – Labels of dataset.

  • y_pred – Outputs of model.

Returns

Evaluation result.
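The grouping by id_left can be sketched in plain Python. This is a minimal illustration, not the library's implementation: eval_metric_grouped and precision_at_1 are hypothetical names, and averaging the per-group values into one result is an assumption that this page does not state.

```python
from collections import defaultdict


def eval_metric_grouped(metric, id_left, y_true, y_pred):
    """Apply `metric` to each id_left group and average the per-group values."""
    groups = defaultdict(lambda: ([], []))
    for qid, label, score in zip(id_left, y_true, y_pred):
        groups[qid][0].append(label)
        groups[qid][1].append(score)
    values = [metric(labels, scores) for labels, scores in groups.values()]
    return sum(values) / len(values)  # assumed aggregation: mean over groups


def precision_at_1(labels, scores):
    """Hypothetical ranking metric: 1.0 if the top-scored item is relevant."""
    best = max(range(len(scores)), key=scores.__getitem__)
    return 1.0 if labels[best] > 0 else 0.0


result = eval_metric_grouped(
    precision_at_1,
    id_left=["q1", "q1", "q2", "q2"],
    y_true=[1, 0, 0, 1],
    y_pred=[0.9, 0.2, 0.8, 0.3],
)
```

Grouping matters because a ranking metric is only meaningful among candidates for the same left input: here q1 ranks its relevant item first and q2 does not, so the averaged score is 0.5.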

predict(self, dataloader: DataLoader) → np.array

Generate output predictions for the input samples.

Parameters

dataloader – input DataLoader

Returns

predictions

_save(self)

Save.

save_model(self)

Save the model.

save(self)

Save the trainer.

Trainer parameters such as epoch, best_so_far, model, optimizer and early_stopping will be saved to a specific file path.

Parameters

path – Path to save trainer.

restore_model(self, checkpoint: typing.Union[str, Path])

Restore model.

Parameters

checkpoint – A checkpoint from which to continue training.

restore(self, checkpoint: typing.Union[str, Path] = None)

Restore trainer.

Parameters

checkpoint – A checkpoint from which to continue training.
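A save and restore round trip can be sketched with the standard library. This is a simplified illustration under stated assumptions: save_state, restore_state, and the file name trainer_sketch.pkl are hypothetical, pickle stands in for whatever serialization the trainer actually uses, and resuming at the epoch after the saved one is assumed.

```python
import pickle
import tempfile
from pathlib import Path


def save_state(state, save_dir, name="trainer_sketch.pkl"):
    """Write a state dict into save_dir and return the checkpoint path."""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    checkpoint = save_dir / name
    with open(checkpoint, "wb") as f:
        pickle.dump(state, f)
    return checkpoint


def restore_state(checkpoint):
    """Read a state dict back from a checkpoint path."""
    with open(checkpoint, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    state = {"epoch": 3, "best_so_far": 0.55, "model": {"w": 2.0}}
    checkpoint = save_state(state, Path(tmp) / "save")
    restored = restore_state(checkpoint)
    next_epoch = restored["epoch"] + 1  # assumed: resume after the saved epoch
```

Keeping the epoch counter and best score alongside the model weights is what allows training to continue where it stopped instead of restarting from the first epoch.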